quiche/recovery/gcongestion/bbr/
bandwidth_sampler.rs

1// Copyright (c) 2016 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Copyright (C) 2023, Cloudflare, Inc.
6// All rights reserved.
7//
8// Redistribution and use in source and binary forms, with or without
9// modification, are permitted provided that the following conditions are
10// met:
11//
12//     * Redistributions of source code must retain the above copyright notice,
13//       this list of conditions and the following disclaimer.
14//
15//     * Redistributions in binary form must reproduce the above copyright
16//       notice, this list of conditions and the following disclaimer in the
17//       documentation and/or other materials provided with the distribution.
18//
19// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
20// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
21// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
22// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
23// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31use std::collections::VecDeque;
32use std::time::Duration;
33use std::time::Instant;
34
35use super::Acked;
36use crate::recovery::gcongestion::bandwidth::Bandwidth;
37use crate::recovery::gcongestion::Lost;
38
39use super::windowed_filter::WindowedFilter;
40
/// Per-packet state keyed by packet number.
///
/// Entries are stored in strictly increasing packet-number order (enforced by
/// [`ConnectionStateMap::insert`]), which allows binary search. Taking an
/// entry that is not at the front leaves a `None` tombstone in its slot so the
/// order is preserved; tombstones are discarded once they reach the front of
/// the queue.
#[derive(Debug)]
struct ConnectionStateMap<T> {
    packet_map: VecDeque<(u64, Option<T>)>,
}

// Implemented by hand so that `T` is not required to implement `Default`.
impl<T> Default for ConnectionStateMap<T> {
    fn default() -> Self {
        ConnectionStateMap {
            packet_map: VecDeque::new(),
        }
    }
}

impl<T> ConnectionStateMap<T> {
    /// Starts tracking `val` for `pkt_num`.
    ///
    /// # Panics
    ///
    /// Panics if `pkt_num` is not strictly greater than the packet number of
    /// the last entry in the map.
    fn insert(&mut self, pkt_num: u64, val: T) {
        if let Some((last_pkt, _)) = self.packet_map.back() {
            assert!(pkt_num > *last_pkt, "{} > {}", pkt_num, *last_pkt);
        }

        self.packet_map.push_back((pkt_num, Some(val)));
    }

    /// Removes and returns the state tracked for `pkt_num`.
    ///
    /// Returns `None` if the packet was never inserted, was already taken, or
    /// was discarded by [`ConnectionStateMap::remove_obsolete`].
    fn take(&mut self, pkt_num: u64) -> Option<T> {
        let first = self.packet_map.front()?.0;

        let ret = if first == pkt_num {
            // Fast path: packets are usually acked or lost in order.
            self.packet_map.pop_front().and_then(|(_, v)| v)
        } else {
            match self.packet_map.binary_search_by_key(&pkt_num, |&(n, _)| n) {
                // Leave a tombstone in place so the deque stays sorted.
                Ok(found) =>
                    self.packet_map.get_mut(found).and_then(|(_, v)| v.take()),
                Err(_) => None,
            }
        };

        // Must run on both paths: popping the front can expose tombstones
        // left by earlier out-of-order takes, which would otherwise linger.
        self.drop_leading_tombstones();

        ret
    }

    #[cfg(test)]
    fn peek(&self, pkt_num: u64) -> Option<&T> {
        // Use binary search
        match self.packet_map.binary_search_by_key(&pkt_num, |&(n, _)| n) {
            Ok(found) => self.packet_map.get(found).and_then(|(_, v)| v.as_ref()),
            Err(_) => None,
        }
    }

    /// Discards every entry whose packet number is below `least_acked`, plus
    /// any tombstones that end up at the front afterwards.
    fn remove_obsolete(&mut self, least_acked: u64) {
        while matches!(self.packet_map.front(), Some(&(p, _)) if p < least_acked)
        {
            self.packet_map.pop_front();
        }

        self.drop_leading_tombstones();
    }

    /// Pops already-taken (`None`) entries off the front of the queue.
    fn drop_leading_tombstones(&mut self) {
        while let Some((_, None)) = self.packet_map.front() {
            self.packet_map.pop_front();
        }
    }
}
103
/// Produces bandwidth, RTT and ack-aggregation samples from the sent, acked,
/// lost and neutered packets of a connection.
#[derive(Debug)]
pub struct BandwidthSampler {
    /// The total number of congestion controlled bytes sent during the
    /// connection.
    total_bytes_sent: usize,
    /// Total size of tracked packets that have been acknowledged.
    total_bytes_acked: usize,
    /// Total bytes reported lost to the sampler.
    total_bytes_lost: usize,
    /// Total size of tracked packets that were neutered.
    total_bytes_neutered: usize,
    /// Number of the most recently sent packet, retransmittable or not.
    last_sent_packet: u64,
    /// Number of the most recently acknowledged packet; updated even when the
    /// packet was not tracked.
    last_acked_packet: u64,
    /// Whether the sampler is currently in an app-limited phase.
    is_app_limited: bool,
    /// Ack time of the last acked packet, or the send time of a packet that
    /// was sent while nothing was in flight.
    last_acked_packet_ack_time: Instant,
    /// `total_bytes_sent` as recorded when the last acked packet was sent.
    total_bytes_sent_at_last_acked_packet: usize,
    /// Send time of the last acked packet, or of a packet that was sent while
    /// nothing was in flight.
    last_acked_packet_sent_time: Instant,
    /// The two most recent ack points; only maintained when
    /// `overestimate_avoidance` is enabled.
    recent_ack_points: RecentAckPoints,
    /// Ack points that may serve as the A0 point of a bandwidth sample; only
    /// maintained when `overestimate_avoidance` is enabled.
    a0_candidates: VecDeque<AckPoint>,
    /// Per-packet state of the tracked (retransmittable) packets.
    connection_state_map: ConnectionStateMap<ConnectionStateOnSentPacket>,
    /// Tracks the degree of ack aggregation ("ack height").
    max_ack_height_tracker: MaxAckHeightTracker,
    /// The packet that will be acknowledged after this one will cause the
    /// sampler to exit the app-limited phase.
    end_of_app_limited_phase: Option<u64>,
    /// Enables the A0-candidate logic used when computing ack rates.
    overestimate_avoidance: bool,
    /// When set, the largest send rate seen in an ack event may raise the
    /// bandwidth handed to the ack-height tracker.
    limit_max_ack_height_tracker_by_send_rate: bool,

    /// `total_bytes_acked` as of the end of the previous ack event.
    total_bytes_acked_after_last_ack_event: usize,
}
130
/// A subset of [`ConnectionStateOnSentPacket`] which is returned
/// to the caller when the packet is acked or lost.
#[derive(Debug, Default, Clone, Copy)]
pub struct SendTimeState {
    /// Whether the other fields of this object are valid.
    pub is_valid: bool,
    /// Whether the sender was app-limited at the time the packet was sent.
    /// An app-limited bandwidth sample might be artificially low because the
    /// sender did not have enough data to send in order to saturate the
    /// link.
    pub is_app_limited: bool,
    /// Total number of sent bytes at the time the packet was sent.
    /// Includes the packet itself.
    pub total_bytes_sent: usize,
    /// Total number of acked bytes at the time the packet was sent.
    pub total_bytes_acked: usize,
    /// Total number of lost bytes at the time the packet was sent.
    #[allow(dead_code)]
    pub total_bytes_lost: usize,
    /// Total number of inflight bytes at the time the packet was sent.
    /// Includes the packet itself.
    /// It should be equal to `total_bytes_sent` minus the sum of
    /// `total_bytes_acked`, `total_bytes_lost` and total neutered bytes.
    pub bytes_in_flight: usize,
}
156
/// A single ack-aggregation measurement kept by [`MaxAckHeightTracker`].
///
/// The derived `Ord` compares fields in declaration order, so events are
/// ranked primarily by `extra_acked`; keep that field first.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
struct ExtraAckedEvent {
    /// The excess bytes acknowledged in the time delta for this event.
    extra_acked: usize,
    /// The bytes acknowledged and time delta from the event.
    bytes_acked: usize,
    time_delta: Duration,
    /// The round trip of the event.
    round: usize,
}
167
/// The result of sampling a single acknowledged packet.
struct BandwidthSample {
    /// The bandwidth at that particular sample.
    bandwidth: Bandwidth,
    /// The RTT measurement at this particular sample.  Does not correct for
    /// delayed ack time.
    rtt: Duration,
    /// `send_rate` is computed from the current packet being acked ('P') and
    /// an earlier packet that is acked before P was sent. `None` means the
    /// send rate could not be computed and only the ack rate is used.
    send_rate: Option<Bandwidth>,
    /// States captured when the packet was sent.
    state_at_send: SendTimeState,
}
180
/// [`AckPoint`] represents a point on the ack line: the cumulative number of
/// bytes acked as of `ack_time`.
#[derive(Debug, Clone, Copy)]
struct AckPoint {
    ack_time: Instant,
    total_bytes_acked: usize,
}
187
/// [`RecentAckPoints`] maintains the most recent 2 ack points:
/// `ack_points[1]` is the most recent one and `ack_points[0]` the one before
/// it. The points are intended to be at distinct ack times; see
/// `RecentAckPoints::update` for how points are shifted.
#[derive(Debug, Default)]
struct RecentAckPoints {
    ack_points: [Option<AckPoint>; 2],
}
194
/// [`ConnectionStateOnSentPacket`] represents the information about a sent
/// packet and the state of the connection at the moment the packet was sent,
/// specifically the information about the most recently acknowledged packet at
/// that moment.
#[derive(Debug)]
struct ConnectionStateOnSentPacket {
    /// Time at which the packet is sent.
    sent_time: Instant,
    /// Size of the packet.
    size: usize,
    /// The value of `BandwidthSampler::total_bytes_sent_at_last_acked_packet`
    /// at the time the packet was sent.
    total_bytes_sent_at_last_acked_packet: usize,
    /// The value of `BandwidthSampler::last_acked_packet_sent_time` at the
    /// time the packet was sent.
    last_acked_packet_sent_time: Instant,
    /// The value of `BandwidthSampler::last_acked_packet_ack_time` at the
    /// time the packet was sent.
    last_acked_packet_ack_time: Instant,
    /// Send time states that are returned to the congestion controller when the
    /// packet is acked or lost.
    send_time_state: SendTimeState,
}
218
/// [`MaxAckHeightTracker`] is part of the [`BandwidthSampler`]. It is called
/// after every ack event to keep track of the degree of ack
/// aggregation (a.k.a. "ack height").
#[derive(Debug)]
struct MaxAckHeightTracker {
    /// Tracks the maximum number of bytes acked faster than the estimated
    /// bandwidth. The filter's time axis is the round-trip count.
    max_ack_height_filter: WindowedFilter<ExtraAckedEvent, usize, usize>,
    /// The time this aggregation started and the number of bytes acked during
    /// it.
    aggregation_epoch_start_time: Option<Instant>,
    aggregation_epoch_bytes: usize,
    /// The last sent packet number before the current aggregation epoch
    /// started.
    last_sent_packet_number_before_epoch: u64,
    /// The number of ack aggregation epochs ever started, including the ongoing
    /// one. Stats only.
    num_ack_aggregation_epochs: u64,
    /// An epoch is restarted once its acked bytes are at most this multiple
    /// of the bytes expected at the estimated bandwidth.
    ack_aggregation_bandwidth_threshold: f64,
    /// Whether to force a new epoch once a packet sent after the start of the
    /// current epoch has been acked.
    start_new_aggregation_epoch_after_full_round: bool,
    /// Whether to recompute the retained events when a new max bandwidth is
    /// observed.
    reduce_extra_acked_on_bandwidth_increase: bool,
}
241
/// Summary of one congestion event, as returned by
/// `BandwidthSampler::on_congestion_event`.
#[derive(Default)]
pub(crate) struct CongestionEventSample {
    /// The maximum bandwidth sample from all acked packets.
    pub sample_max_bandwidth: Option<Bandwidth>,
    /// Whether `sample_max_bandwidth` is from an app-limited sample.
    pub sample_is_app_limited: bool,
    /// The minimum rtt sample from all acked packets.
    pub sample_rtt: Option<Duration>,
    /// For each packet p in acked packets, this is the max value of
    /// INFLIGHT(p), where INFLIGHT(p) is the number of bytes acked while p
    /// is inflight.
    pub sample_max_inflight: usize,
    /// The send state of the largest packet in acked_packets, unless it is
    /// empty. If acked_packets is empty, it's the send state of the largest
    /// packet in lost_packets.
    pub last_packet_send_state: SendTimeState,
    /// The number of extra bytes acked from this ack event, compared to what is
    /// expected from the flow's bandwidth. Larger value means more ack
    /// aggregation.
    pub extra_acked: usize,
}
263
264impl MaxAckHeightTracker {
265    pub(crate) fn new(window: usize, overestimate_avoidance: bool) -> Self {
266        MaxAckHeightTracker {
267            max_ack_height_filter: WindowedFilter::new(window),
268            aggregation_epoch_start_time: None,
269            aggregation_epoch_bytes: 0,
270            last_sent_packet_number_before_epoch: 0,
271            num_ack_aggregation_epochs: 0,
272            ack_aggregation_bandwidth_threshold: if overestimate_avoidance {
273                2.0
274            } else {
275                1.0
276            },
277            start_new_aggregation_epoch_after_full_round: true,
278            reduce_extra_acked_on_bandwidth_increase: true,
279        }
280    }
281
282    #[allow(dead_code)]
283    fn reset(&mut self, new_height: usize, new_time: usize) {
284        self.max_ack_height_filter.reset(
285            ExtraAckedEvent {
286                extra_acked: new_height,
287                bytes_acked: 0,
288                time_delta: Duration::ZERO,
289                round: new_time,
290            },
291            new_time,
292        );
293    }
294
295    #[allow(clippy::too_many_arguments)]
296    fn update(
297        &mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
298        round_trip_count: usize, last_sent_packet_number: u64,
299        last_acked_packet_number: u64, ack_time: Instant, bytes_acked: usize,
300    ) -> usize {
301        let mut force_new_epoch = false;
302
303        if self.reduce_extra_acked_on_bandwidth_increase && is_new_max_bandwidth {
304            // Save and clear existing entries.
305            let mut best =
306                self.max_ack_height_filter.get_best().unwrap_or_default();
307            let mut second_best = self
308                .max_ack_height_filter
309                .get_second_best()
310                .unwrap_or_default();
311            let mut third_best = self
312                .max_ack_height_filter
313                .get_third_best()
314                .unwrap_or_default();
315            self.max_ack_height_filter.clear();
316
317            // Reinsert the heights into the filter after recalculating.
318            let expected_bytes_acked =
319                bandwidth_estimate.to_bytes_per_period(best.time_delta) as usize;
320            if expected_bytes_acked < best.bytes_acked {
321                best.extra_acked = best.bytes_acked - expected_bytes_acked;
322                self.max_ack_height_filter.update(best, best.round);
323            }
324
325            let expected_bytes_acked = bandwidth_estimate
326                .to_bytes_per_period(second_best.time_delta)
327                as usize;
328            if expected_bytes_acked < second_best.bytes_acked {
329                second_best.extra_acked =
330                    second_best.bytes_acked - expected_bytes_acked;
331                self.max_ack_height_filter
332                    .update(second_best, second_best.round);
333            }
334
335            let expected_bytes_acked = bandwidth_estimate
336                .to_bytes_per_period(third_best.time_delta)
337                as usize;
338            if expected_bytes_acked < third_best.bytes_acked {
339                third_best.extra_acked =
340                    third_best.bytes_acked - expected_bytes_acked;
341                self.max_ack_height_filter
342                    .update(third_best, third_best.round);
343            }
344        }
345
346        // If any packet sent after the start of the epoch has been acked, start a
347        // new epoch.
348        if self.start_new_aggregation_epoch_after_full_round &&
349            last_acked_packet_number >
350                self.last_sent_packet_number_before_epoch
351        {
352            force_new_epoch = true;
353        }
354
355        let epoch_start_time = match self.aggregation_epoch_start_time {
356            Some(time) if !force_new_epoch => time,
357            _ => {
358                self.aggregation_epoch_bytes = bytes_acked;
359                self.aggregation_epoch_start_time = Some(ack_time);
360                self.last_sent_packet_number_before_epoch =
361                    last_sent_packet_number;
362                self.num_ack_aggregation_epochs += 1;
363                return 0;
364            },
365        };
366
367        // Compute how many bytes are expected to be delivered, assuming max
368        // bandwidth is correct.
369        let aggregation_delta = ack_time.duration_since(epoch_start_time);
370        let expected_bytes_acked =
371            bandwidth_estimate.to_bytes_per_period(aggregation_delta) as usize;
372        // Reset the current aggregation epoch as soon as the ack arrival rate is
373        // less than or equal to the max bandwidth.
374        if self.aggregation_epoch_bytes <=
375            (self.ack_aggregation_bandwidth_threshold *
376                expected_bytes_acked as f64) as usize
377        {
378            // Reset to start measuring a new aggregation epoch.
379            self.aggregation_epoch_bytes = bytes_acked;
380            self.aggregation_epoch_start_time = Some(ack_time);
381            self.last_sent_packet_number_before_epoch = last_sent_packet_number;
382            self.num_ack_aggregation_epochs += 1;
383            return 0;
384        }
385
386        self.aggregation_epoch_bytes += bytes_acked;
387
388        // Compute how many extra bytes were delivered vs max bandwidth.
389        let extra_bytes_acked =
390            self.aggregation_epoch_bytes - expected_bytes_acked;
391
392        let new_event = ExtraAckedEvent {
393            extra_acked: extra_bytes_acked,
394            bytes_acked: self.aggregation_epoch_bytes,
395            time_delta: aggregation_delta,
396            round: 0,
397        };
398
399        self.max_ack_height_filter
400            .update(new_event, round_trip_count);
401        extra_bytes_acked
402    }
403}
404
405impl From<(Instant, usize, usize, &BandwidthSampler)>
406    for ConnectionStateOnSentPacket
407{
408    fn from(
409        (sent_time, size, bytes_in_flight, sampler): (
410            Instant,
411            usize,
412            usize,
413            &BandwidthSampler,
414        ),
415    ) -> Self {
416        ConnectionStateOnSentPacket {
417            sent_time,
418            size,
419            total_bytes_sent_at_last_acked_packet: sampler
420                .total_bytes_sent_at_last_acked_packet,
421            last_acked_packet_sent_time: sampler.last_acked_packet_sent_time,
422            last_acked_packet_ack_time: sampler.last_acked_packet_ack_time,
423            send_time_state: SendTimeState {
424                is_valid: true,
425                is_app_limited: sampler.is_app_limited,
426                total_bytes_sent: sampler.total_bytes_sent,
427                total_bytes_acked: sampler.total_bytes_acked,
428                total_bytes_lost: sampler.total_bytes_lost,
429                bytes_in_flight,
430            },
431        }
432    }
433}
434
435impl RecentAckPoints {
436    fn update(&mut self, ack_time: Instant, total_bytes_acked: usize) {
437        assert!(
438            total_bytes_acked >=
439                self.ack_points[1].map(|p| p.total_bytes_acked).unwrap_or(0)
440        );
441
442        self.ack_points[0] = self.ack_points[1];
443        self.ack_points[1] = Some(AckPoint {
444            ack_time,
445            total_bytes_acked,
446        });
447    }
448
449    fn clear(&mut self) {
450        self.ack_points = Default::default();
451    }
452
453    fn most_recent(&self) -> Option<AckPoint> {
454        self.ack_points[1]
455    }
456
457    fn less_recent_point(&self) -> Option<AckPoint> {
458        self.ack_points[0].or(self.ack_points[1])
459    }
460}
461
462impl BandwidthSampler {
463    pub(crate) fn new(
464        max_height_tracker_window_length: usize, overestimate_avoidance: bool,
465    ) -> Self {
466        BandwidthSampler {
467            total_bytes_sent: 0,
468            total_bytes_acked: 0,
469            total_bytes_lost: 0,
470            total_bytes_neutered: 0,
471            total_bytes_sent_at_last_acked_packet: 0,
472            last_acked_packet_sent_time: Instant::now(),
473            last_acked_packet_ack_time: Instant::now(),
474            is_app_limited: true,
475            connection_state_map: ConnectionStateMap::default(),
476            max_ack_height_tracker: MaxAckHeightTracker::new(
477                max_height_tracker_window_length,
478                overestimate_avoidance,
479            ),
480            total_bytes_acked_after_last_ack_event: 0,
481            overestimate_avoidance,
482            limit_max_ack_height_tracker_by_send_rate: false,
483
484            last_sent_packet: 0,
485            last_acked_packet: 0,
486            recent_ack_points: RecentAckPoints::default(),
487            a0_candidates: VecDeque::new(),
488            end_of_app_limited_phase: None,
489        }
490    }
491
492    #[allow(dead_code)]
493    pub(crate) fn is_app_limited(&self) -> bool {
494        self.is_app_limited
495    }
496
497    pub(crate) fn on_packet_sent(
498        &mut self, sent_time: Instant, packet_number: u64, bytes: usize,
499        bytes_in_flight: usize, has_retransmittable_data: bool,
500    ) {
501        self.last_sent_packet = packet_number;
502
503        if !has_retransmittable_data {
504            return;
505        }
506
507        self.total_bytes_sent += bytes;
508
509        // If there are no packets in flight, the time at which the new
510        // transmission opens can be treated as the A_0 point for the
511        // purpose of bandwidth sampling. This underestimates bandwidth to
512        // some extent, and produces some artificially low samples for
513        // most packets in flight, but it provides with samples at
514        // important points where we would not have them otherwise, most
515        // importantly at the beginning of the connection.
516        if bytes_in_flight == 0 {
517            self.last_acked_packet_ack_time = sent_time;
518            if self.overestimate_avoidance {
519                self.recent_ack_points.clear();
520                self.recent_ack_points
521                    .update(sent_time, self.total_bytes_acked);
522                self.a0_candidates.clear();
523                self.a0_candidates
524                    .push_back(self.recent_ack_points.most_recent().unwrap());
525            }
526
527            self.total_bytes_sent_at_last_acked_packet = self.total_bytes_sent;
528
529            // In this situation ack compression is not a concern, set send rate
530            // to effectively infinite.
531            self.last_acked_packet_sent_time = sent_time;
532        }
533
534        self.connection_state_map.insert(
535            packet_number,
536            (sent_time, bytes, bytes_in_flight + bytes, &*self).into(),
537        );
538    }
539
540    pub(crate) fn on_packet_neutered(&mut self, packet_number: u64) {
541        if let Some(pkt) = self.connection_state_map.take(packet_number) {
542            self.total_bytes_neutered += pkt.size;
543        }
544    }
545
546    pub(crate) fn on_congestion_event(
547        &mut self, ack_time: Instant, acked_packets: &[Acked],
548        lost_packets: &[Lost], mut max_bandwidth: Option<Bandwidth>,
549        est_bandwidth_upper_bound: Bandwidth, round_trip_count: usize,
550    ) -> CongestionEventSample {
551        let mut last_lost_packet_send_state = SendTimeState::default();
552        let mut last_acked_packet_send_state = SendTimeState::default();
553        let mut last_lost_packet_num = 0u64;
554        let mut last_acked_packet_num = 0u64;
555
556        for packet in lost_packets {
557            let send_state =
558                self.on_packet_lost(packet.packet_number, packet.bytes_lost);
559            if send_state.is_valid {
560                last_lost_packet_send_state = send_state;
561                last_lost_packet_num = packet.packet_number;
562            }
563        }
564
565        if acked_packets.is_empty() {
566            // Only populate send state for a loss-only event.
567            return CongestionEventSample {
568                last_packet_send_state: last_lost_packet_send_state,
569                ..Default::default()
570            };
571        }
572
573        let mut event_sample = CongestionEventSample::default();
574
575        let mut max_send_rate = None;
576        for packet in acked_packets {
577            let sample =
578                match self.on_packet_acknowledged(ack_time, packet.pkt_num) {
579                    Some(sample) if sample.state_at_send.is_valid => sample,
580                    _ => continue,
581                };
582
583            last_acked_packet_send_state = sample.state_at_send;
584            last_acked_packet_num = packet.pkt_num;
585
586            event_sample.sample_rtt = Some(
587                sample
588                    .rtt
589                    .min(*event_sample.sample_rtt.get_or_insert(sample.rtt)),
590            );
591
592            if Some(sample.bandwidth) > event_sample.sample_max_bandwidth {
593                event_sample.sample_max_bandwidth = Some(sample.bandwidth);
594                event_sample.sample_is_app_limited =
595                    sample.state_at_send.is_app_limited;
596            }
597            max_send_rate = max_send_rate.max(sample.send_rate);
598
599            let inflight_sample = self.total_bytes_acked -
600                last_acked_packet_send_state.total_bytes_acked;
601            if inflight_sample > event_sample.sample_max_inflight {
602                event_sample.sample_max_inflight = inflight_sample;
603            }
604        }
605
606        if !last_lost_packet_send_state.is_valid {
607            event_sample.last_packet_send_state = last_acked_packet_send_state;
608        } else if !last_acked_packet_send_state.is_valid {
609            event_sample.last_packet_send_state = last_lost_packet_send_state;
610        } else {
611            // If two packets are inflight and an alarm is armed to lose a packet
612            // and it wakes up late, then the first of two in flight packets could
613            // have been acknowledged before the wakeup, which re-evaluates loss
614            // detection, and could declare the later of the two lost.
615            event_sample.last_packet_send_state =
616                if last_acked_packet_num > last_lost_packet_num {
617                    last_acked_packet_send_state
618                } else {
619                    last_lost_packet_send_state
620                };
621        }
622
623        let is_new_max_bandwidth =
624            event_sample.sample_max_bandwidth > max_bandwidth;
625        max_bandwidth = event_sample.sample_max_bandwidth.max(max_bandwidth);
626
627        if self.limit_max_ack_height_tracker_by_send_rate {
628            max_bandwidth = max_bandwidth.max(max_send_rate);
629        }
630
631        let bandwidth_estimate = if let Some(max_bandwidth) = max_bandwidth {
632            max_bandwidth.min(est_bandwidth_upper_bound)
633        } else {
634            est_bandwidth_upper_bound
635        };
636
637        event_sample.extra_acked = self.on_ack_event_end(
638            bandwidth_estimate,
639            is_new_max_bandwidth,
640            round_trip_count,
641        );
642
643        event_sample
644    }
645
646    fn on_packet_lost(
647        &mut self, packet_number: u64, bytes_lost: usize,
648    ) -> SendTimeState {
649        let mut send_time_state = SendTimeState::default();
650
651        self.total_bytes_lost += bytes_lost;
652        if let Some(state) = self.connection_state_map.take(packet_number) {
653            send_time_state = state.send_time_state;
654            send_time_state.is_valid = true;
655        }
656
657        send_time_state
658    }
659
660    fn on_ack_event_end(
661        &mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
662        round_trip_count: usize,
663    ) -> usize {
664        let newly_acked_bytes =
665            self.total_bytes_acked - self.total_bytes_acked_after_last_ack_event;
666
667        if newly_acked_bytes == 0 {
668            return 0;
669        }
670
671        self.total_bytes_acked_after_last_ack_event = self.total_bytes_acked;
672        let extra_acked = self.max_ack_height_tracker.update(
673            bandwidth_estimate,
674            is_new_max_bandwidth,
675            round_trip_count,
676            self.last_sent_packet,
677            self.last_acked_packet,
678            self.last_acked_packet_ack_time,
679            newly_acked_bytes,
680        );
681        // If `extra_acked` is zero, i.e. this ack event marks the start of a new
682        // ack aggregation epoch, save `less_recent_point`, which is the
683        // last ack point of the previous epoch, as a A0 candidate.
684        if self.overestimate_avoidance && extra_acked == 0 {
685            self.a0_candidates
686                .push_back(self.recent_ack_points.less_recent_point().unwrap());
687        }
688
689        extra_acked
690    }
691
692    fn on_packet_acknowledged(
693        &mut self, ack_time: Instant, packet_number: u64,
694    ) -> Option<BandwidthSample> {
695        self.last_acked_packet = packet_number;
696        let sent_packet = self.connection_state_map.take(packet_number)?;
697
698        self.total_bytes_acked += sent_packet.size;
699        self.total_bytes_sent_at_last_acked_packet =
700            sent_packet.send_time_state.total_bytes_sent;
701        self.last_acked_packet_sent_time = sent_packet.sent_time;
702        self.last_acked_packet_ack_time = ack_time;
703        if self.overestimate_avoidance {
704            self.recent_ack_points
705                .update(ack_time, self.total_bytes_acked);
706        }
707
708        if self.is_app_limited {
709            // Exit app-limited phase in two cases:
710            // (1) end_of_app_limited_phase is not initialized, i.e., so far all
711            // packets are sent while there are buffered packets or pending data.
712            // (2) The current acked packet is after the sent packet marked as the
713            // end of the app limit phase.
714            if self.end_of_app_limited_phase.is_none() ||
715                Some(packet_number) > self.end_of_app_limited_phase
716            {
717                self.is_app_limited = false;
718            }
719        }
720
721        // No send rate indicates that the sampler is supposed to discard the
722        // current send rate sample and use only the ack rate.
723        let send_rate = if sent_packet.sent_time >
724            sent_packet.last_acked_packet_sent_time
725        {
726            Some(Bandwidth::from_bytes_and_time_delta(
727                sent_packet.send_time_state.total_bytes_sent -
728                    sent_packet.total_bytes_sent_at_last_acked_packet,
729                sent_packet.sent_time - sent_packet.last_acked_packet_sent_time,
730            ))
731        } else {
732            None
733        };
734
735        let a0 = if self.overestimate_avoidance {
736            Self::choose_a0_point(
737                &mut self.a0_candidates,
738                sent_packet.send_time_state.total_bytes_acked,
739            )
740        } else {
741            None
742        };
743
744        let a0 = a0.unwrap_or(AckPoint {
745            ack_time: sent_packet.last_acked_packet_ack_time,
746            total_bytes_acked: sent_packet.send_time_state.total_bytes_acked,
747        });
748
749        // During the slope calculation, ensure that ack time of the current
750        // packet is always larger than the time of the previous packet,
751        // otherwise division by zero or integer underflow can occur.
752        if ack_time <= a0.ack_time {
753            return None;
754        }
755
756        let ack_rate = Bandwidth::from_bytes_and_time_delta(
757            self.total_bytes_acked - a0.total_bytes_acked,
758            ack_time.duration_since(a0.ack_time),
759        );
760
761        let bandwidth = if let Some(send_rate) = send_rate {
762            send_rate.min(ack_rate)
763        } else {
764            ack_rate
765        };
766
767        // Note: this sample does not account for delayed acknowledgement time.
768        // This means that the RTT measurements here can be artificially
769        // high, especially on low bandwidth connections.
770        let rtt = ack_time.duration_since(sent_packet.sent_time);
771
772        Some(BandwidthSample {
773            bandwidth,
774            rtt,
775            send_rate,
776            state_at_send: SendTimeState {
777                is_valid: true,
778                ..sent_packet.send_time_state
779            },
780        })
781    }
782
783    fn choose_a0_point(
784        a0_candidates: &mut VecDeque<AckPoint>, total_bytes_acked: usize,
785    ) -> Option<AckPoint> {
786        if a0_candidates.is_empty() {
787            return None;
788        }
789
790        while let Some(candidate) = a0_candidates.get(1) {
791            if candidate.total_bytes_acked > total_bytes_acked {
792                return Some(*candidate);
793            }
794            a0_candidates.pop_front();
795        }
796
797        Some(a0_candidates[0])
798    }
799
800    pub(crate) fn total_bytes_acked(&self) -> usize {
801        self.total_bytes_acked
802    }
803
804    pub(crate) fn total_bytes_lost(&self) -> usize {
805        self.total_bytes_lost
806    }
807
808    #[allow(dead_code)]
809    pub(crate) fn reset_max_ack_height_tracker(
810        &mut self, new_height: usize, new_time: usize,
811    ) {
812        self.max_ack_height_tracker.reset(new_height, new_time);
813    }
814
815    pub(crate) fn max_ack_height(&self) -> Option<usize> {
816        self.max_ack_height_tracker
817            .max_ack_height_filter
818            .get_best()
819            .map(|b| b.extra_acked)
820    }
821
822    pub(crate) fn on_app_limited(&mut self) {
823        self.is_app_limited = true;
824        self.end_of_app_limited_phase = Some(self.last_sent_packet);
825    }
826
827    pub(crate) fn remove_obsolete_packets(&mut self, least_acked: u64) {
828        // A packet can become obsolete when it is removed from
829        // QuicUnackedPacketMap's view of inflight before it is acked or
830        // marked as lost. For example, when
831        // QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto
832        // packet, the packet is removed from QuicUnackedPacketMap's
833        // inflight, but is not marked as acked or lost in the
834        // BandwidthSampler.
835        self.connection_state_map.remove_obsolete(least_acked);
836    }
837}
838
839#[cfg(test)]
840mod bandwidth_sampler_tests {
841    use super::*;
842
    /// Packet size, in bytes, used for the packets sent by the tests below.
    const REGULAR_PACKET_SIZE: usize = 1280;
844
    /// Test harness that drives a `BandwidthSampler` with a simulated clock
    /// and keeps its own count of bytes in flight.
    struct TestSender {
        /// The sampler under test.
        sampler: BandwidthSampler,
        /// Value of `sampler.is_app_limited()` right after construction.
        sampler_app_limited_at_start: bool,
        /// Bytes of retransmittable packets sent and not yet acked or lost.
        bytes_in_flight: usize,
        /// Simulated current time, advanced by `advance_time`.
        clock: Instant,
        /// Max bandwidth handed to the sampler on every congestion event.
        max_bandwidth: Bandwidth,
        /// Bandwidth upper bound handed to the sampler on every congestion
        /// event.
        est_bandwidth_upper_bound: Bandwidth,
        /// Round-trip count handed to the sampler on every congestion event.
        round_trip_count: usize,
    }
854
855    impl TestSender {
856        fn new() -> Self {
857            let sampler = BandwidthSampler::new(0, false);
858            TestSender {
859                sampler_app_limited_at_start: sampler.is_app_limited(),
860                sampler,
861                bytes_in_flight: 0,
862                clock: Instant::now(),
863                max_bandwidth: Bandwidth::zero(),
864                est_bandwidth_upper_bound: Bandwidth::infinite(),
865                round_trip_count: 0,
866            }
867        }
868
869        fn get_packet_size(&self, pkt_num: u64) -> usize {
870            self.sampler
871                .connection_state_map
872                .peek(pkt_num)
873                .unwrap()
874                .size
875        }
876
877        fn get_packet_time(&self, pkt_num: u64) -> Instant {
878            self.sampler
879                .connection_state_map
880                .peek(pkt_num)
881                .unwrap()
882                .sent_time
883        }
884
885        fn number_of_tracked_packets(&self) -> usize {
886            self.sampler.connection_state_map.packet_map.len()
887        }
888
889        fn make_acked_packet(&self, pkt_num: u64) -> Acked {
890            let time_sent = self.get_packet_time(pkt_num);
891
892            Acked { pkt_num, time_sent }
893        }
894
895        fn make_lost_packet(&self, pkt_num: u64) -> Lost {
896            let size = self.get_packet_size(pkt_num);
897            Lost {
898                packet_number: pkt_num,
899                bytes_lost: size,
900            }
901        }
902
903        fn ack_packet(&mut self, pkt_num: u64) -> BandwidthSample {
904            let size = self.get_packet_size(pkt_num);
905            self.bytes_in_flight -= size;
906
907            let sample = self.sampler.on_congestion_event(
908                self.clock,
909                &[self.make_acked_packet(pkt_num)],
910                &[],
911                Some(self.max_bandwidth),
912                self.est_bandwidth_upper_bound,
913                self.round_trip_count,
914            );
915
916            let max_bandwidth =
917                self.max_bandwidth.max(sample.sample_max_bandwidth.unwrap());
918
919            let bandwidth_sample = BandwidthSample {
920                bandwidth: max_bandwidth,
921                rtt: sample.sample_rtt.unwrap(),
922                send_rate: None,
923                state_at_send: sample.last_packet_send_state,
924            };
925            assert!(bandwidth_sample.state_at_send.is_valid);
926            bandwidth_sample
927        }
928
929        fn lose_packet(&mut self, pkt_num: u64) -> SendTimeState {
930            let size = self.get_packet_size(pkt_num);
931            self.bytes_in_flight -= size;
932
933            let sample = self.sampler.on_congestion_event(
934                self.clock,
935                &[],
936                &[self.make_lost_packet(pkt_num)],
937                Some(self.max_bandwidth),
938                self.est_bandwidth_upper_bound,
939                self.round_trip_count,
940            );
941
942            assert!(sample.last_packet_send_state.is_valid);
943            assert_eq!(sample.sample_max_bandwidth, None);
944            assert_eq!(sample.sample_rtt, None);
945            sample.last_packet_send_state
946        }
947
948        fn on_congestion_event(
949            &mut self, acked: &[u64], lost: &[u64],
950        ) -> CongestionEventSample {
951            let acked = acked
952                .into_iter()
953                .map(|pkt| {
954                    let acked_size = self.get_packet_size(*pkt);
955                    self.bytes_in_flight -= acked_size;
956
957                    let acked = self.make_acked_packet(*pkt);
958                    acked
959                })
960                .collect::<Vec<_>>();
961
962            let lost = lost
963                .into_iter()
964                .map(|pkt| {
965                    let lost = self.make_lost_packet(*pkt);
966                    self.bytes_in_flight -= lost.bytes_lost;
967                    lost
968                })
969                .collect::<Vec<_>>();
970
971            let sample = self.sampler.on_congestion_event(
972                self.clock,
973                &acked,
974                &lost,
975                Some(self.max_bandwidth),
976                self.est_bandwidth_upper_bound,
977                self.round_trip_count,
978            );
979
980            self.max_bandwidth =
981                self.max_bandwidth.max(sample.sample_max_bandwidth.unwrap());
982
983            sample
984        }
985
986        fn send_packet(
987            &mut self, pkt_num: u64, pkt_sz: usize,
988            has_retransmittable_data: bool,
989        ) {
990            self.sampler.on_packet_sent(
991                self.clock,
992                pkt_num,
993                pkt_sz,
994                self.bytes_in_flight,
995                has_retransmittable_data,
996            );
997            if has_retransmittable_data {
998                self.bytes_in_flight += pkt_sz;
999            }
1000        }
1001
1002        fn advance_time(&mut self, delta: Duration) {
1003            self.clock += delta;
1004        }
1005
1006        // Sends one packet and acks it.  Then, send 20 packets.  Finally, send
1007        // another 20 packets while acknowledging previous 20.
1008        fn send_40_and_ack_first_20(&mut self, time_between_packets: Duration) {
1009            // Send 20 packets at a constant inter-packet time.
1010            for i in 1..=20 {
1011                self.send_packet(i, REGULAR_PACKET_SIZE, true);
1012                self.advance_time(time_between_packets);
1013            }
1014
1015            // Ack packets 1 to 20, while sending new packets at the same rate as
1016            // before.
1017            for i in 1..=20 {
1018                self.ack_packet(i);
1019                self.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
1020                self.advance_time(time_between_packets);
1021            }
1022        }
1023    }
1024
1025    #[test]
1026    fn send_and_wait() {
1027        let mut test_sender = TestSender::new();
1028        let mut time_between_packets = Duration::from_millis(10);
1029        let mut expected_bandwidth =
1030            Bandwidth::from_bytes_per_second(REGULAR_PACKET_SIZE as u64 * 100);
1031
1032        // Send packets at the constant bandwidth.
1033        for i in 1..20 {
1034            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1035            test_sender.advance_time(time_between_packets);
1036            let current_sample = test_sender.ack_packet(i);
1037            assert_eq!(expected_bandwidth, current_sample.bandwidth);
1038        }
1039
1040        // Send packets at the exponentially decreasing bandwidth.
1041        for i in 20..25 {
1042            time_between_packets = time_between_packets * 2;
1043            expected_bandwidth = expected_bandwidth * 0.5;
1044
1045            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1046            test_sender.advance_time(time_between_packets);
1047            let current_sample = test_sender.ack_packet(i);
1048            assert_eq!(expected_bandwidth, current_sample.bandwidth);
1049        }
1050
1051        test_sender.sampler.remove_obsolete_packets(25);
1052        assert_eq!(0, test_sender.number_of_tracked_packets());
1053        assert_eq!(0, test_sender.bytes_in_flight);
1054    }
1055
    /// Checks the `SendTimeState` snapshot recorded for each packet: the
    /// totals of bytes sent, acked and lost, and the bytes in flight at the
    /// moment the packet was sent.
    #[test]
    fn send_time_state() {
        let mut test_sender = TestSender::new();
        let time_between_packets = Duration::from_millis(10);

        // Send packets 1-5, checking the running total of bytes sent.
        for i in 1..=5 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            assert_eq!(
                test_sender.sampler.total_bytes_sent,
                REGULAR_PACKET_SIZE * i as usize
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ack packet 1. It was sent before any ack or loss, so its snapshot
        // has no acked or lost bytes.
        let send_time_state = test_sender.ack_packet(1).state_at_send;
        assert_eq!(REGULAR_PACKET_SIZE * 1, send_time_state.total_bytes_sent);
        assert_eq!(0, send_time_state.total_bytes_acked);
        assert_eq!(0, send_time_state.total_bytes_lost);
        assert_eq!(
            REGULAR_PACKET_SIZE * 1,
            test_sender.sampler.total_bytes_acked
        );

        // Lose packet 2. Its snapshot predates the ack of packet 1.
        let send_time_state = test_sender.lose_packet(2);
        assert_eq!(REGULAR_PACKET_SIZE * 2, send_time_state.total_bytes_sent);
        assert_eq!(0, send_time_state.total_bytes_acked);
        assert_eq!(0, send_time_state.total_bytes_lost);
        assert_eq!(
            REGULAR_PACKET_SIZE * 1,
            test_sender.sampler.total_bytes_lost
        );

        // Lose packet 3.
        let send_time_state = test_sender.lose_packet(3);
        assert_eq!(REGULAR_PACKET_SIZE * 3, send_time_state.total_bytes_sent);
        assert_eq!(0, send_time_state.total_bytes_acked);
        assert_eq!(0, send_time_state.total_bytes_lost);
        assert_eq!(
            REGULAR_PACKET_SIZE * 2,
            test_sender.sampler.total_bytes_lost
        );

        // Send packets 6-10. These are sent after 1 packet was acked and 2 were
        // lost.
        for i in 6..=10 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            assert_eq!(
                test_sender.sampler.total_bytes_sent,
                REGULAR_PACKET_SIZE * i as usize
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ack all packets still in flight (4-10).
        let mut acked_packet_count = 1;
        assert_eq!(
            REGULAR_PACKET_SIZE * acked_packet_count,
            test_sender.sampler.total_bytes_acked
        );
        for i in 4..=10 {
            let send_time_state = test_sender.ack_packet(i).state_at_send;
            acked_packet_count += 1;
            assert_eq!(
                REGULAR_PACKET_SIZE * acked_packet_count,
                test_sender.sampler.total_bytes_acked
            );
            assert_eq!(
                REGULAR_PACKET_SIZE * i as usize,
                send_time_state.total_bytes_sent
            );

            // Packets 4 and 5 were sent before any ack/loss. Packets 6-10 were
            // sent after packet 1 was acked and packets 2-3 were lost.
            if i <= 5 {
                assert_eq!(0, send_time_state.total_bytes_acked);
                assert_eq!(0, send_time_state.total_bytes_lost);
            } else {
                assert_eq!(
                    REGULAR_PACKET_SIZE * 1,
                    send_time_state.total_bytes_acked
                );
                assert_eq!(
                    REGULAR_PACKET_SIZE * 2,
                    send_time_state.total_bytes_lost
                );
            }

            // This equation works because there is no neutered bytes.
            assert_eq!(
                send_time_state.total_bytes_sent -
                    send_time_state.total_bytes_acked -
                    send_time_state.total_bytes_lost,
                send_time_state.bytes_in_flight
            );

            test_sender.advance_time(time_between_packets);
        }
    }
1154
1155    /// Test the sampler during regular windowed sender scenario with fixed CWND
1156    /// of 20.
1157    #[test]
1158    fn send_paced() {
1159        let mut test_sender = TestSender::new();
1160        let time_between_packets = Duration::from_millis(1);
1161        let expected_bandwidth =
1162            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1163
1164        test_sender.send_40_and_ack_first_20(time_between_packets);
1165
1166        // Ack the packets 21 to 40, arriving at the correct bandwidth.
1167        for i in 21..=40 {
1168            let last_bandwidth = test_sender.ack_packet(i).bandwidth;
1169            assert_eq!(expected_bandwidth, last_bandwidth);
1170            test_sender.advance_time(time_between_packets);
1171        }
1172        test_sender.sampler.remove_obsolete_packets(41);
1173        assert_eq!(0, test_sender.number_of_tracked_packets());
1174        assert_eq!(0, test_sender.bytes_in_flight);
1175    }
1176
    /// Test the sampler in a scenario where 50% of packets are consistently
    /// lost (every odd-numbered packet).
    #[test]
    fn send_with_losses() {
        let mut test_sender = TestSender::new();
        let time_between_packets = Duration::from_millis(1);
        // Only half of the bytes get delivered, so the expected bandwidth is
        // half of one packet per millisecond.
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);

        // Send 20 packets, each 1 ms apart.
        for i in 1..=20 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Ack the even-numbered packets 2 to 20 and lose the odd-numbered ones,
        // while sending new packets 21 to 40 at the same rate as before.
        for i in 1..=20 {
            if i % 2 == 0 {
                test_sender.ack_packet(i);
            } else {
                test_sender.lose_packet(i);
            }
            test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Handle packets 21 to 40 with the same pattern: ack even, lose odd.
        for i in 21..=40 {
            if i % 2 == 0 {
                let last_bandwidth = test_sender.ack_packet(i).bandwidth;
                assert_eq!(expected_bandwidth, last_bandwidth);
            } else {
                test_sender.lose_packet(i);
            }
            test_sender.advance_time(time_between_packets);
        }
        test_sender.sampler.remove_obsolete_packets(41);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1218
    /// Test the sampler in a scenario where half of the packets are not
    /// congestion controlled (specifically, non-retransmittable data is not
    /// congestion controlled).  Should be functionally consistent in behavior
    /// with the [`send_with_losses`] test.
    #[test]
    fn not_congestion_controlled() {
        let mut test_sender = TestSender::new();
        let time_between_packets = Duration::from_millis(1);
        // Only the even-numbered half of the traffic is counted, so the
        // expected bandwidth is half of one packet per millisecond.
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);

        // Send 20 packets, each 1 ms apart. Every odd-numbered packet carries
        // no retransmittable data and is therefore not congestion controlled.
        for i in 1..=20 {
            let has_retransmittable_data = i % 2 == 0;
            test_sender.send_packet(
                i,
                REGULAR_PACKET_SIZE,
                has_retransmittable_data,
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ensure only congestion controlled packets are tracked.
        assert_eq!(10, test_sender.number_of_tracked_packets());

        // Ack the even-numbered packets 2 to 20 (the odd ones were never
        // tracked), while sending packets 21 to 40 with the same pattern.
        for i in 1..=20 {
            if i % 2 == 0 {
                test_sender.ack_packet(i);
            }
            let has_retransmittable_data = i % 2 == 0;
            test_sender.send_packet(
                i + 20,
                REGULAR_PACKET_SIZE,
                has_retransmittable_data,
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ack the even-numbered packets 22 to 40 with the same pattern.
        for i in 21..=40 {
            if i % 2 == 0 {
                let last_bandwidth = test_sender.ack_packet(i).bandwidth;
                assert_eq!(expected_bandwidth, last_bandwidth);
            }
            test_sender.advance_time(time_between_packets);
        }

        test_sender.sampler.remove_obsolete_packets(41);
        // Since only congestion controlled packets are entered into the map, it
        // has to be empty at this point.
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1275
1276    /// Simulate a situation where ACKs arrive in burst and earlier than usual,
1277    /// thus producing an ACK rate which is higher than the original send rate.
1278    #[test]
1279    fn compressed_ack() {
1280        let mut test_sender = TestSender::new();
1281        let time_between_packets = Duration::from_millis(1);
1282        let expected_bandwidth =
1283            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1284
1285        test_sender.send_40_and_ack_first_20(time_between_packets);
1286
1287        // Simulate an RTT somewhat lower than the one for 1-to-21 transmission.
1288        test_sender.advance_time(time_between_packets * 15);
1289
1290        // Ack the packets 21 to 40 almost immediately at once.
1291        let ridiculously_small_time_delta = Duration::from_micros(20);
1292        let mut last_bandwidth = Bandwidth::zero();
1293        for i in 21..=40 {
1294            last_bandwidth = test_sender.ack_packet(i).bandwidth;
1295            test_sender.advance_time(ridiculously_small_time_delta);
1296        }
1297        assert_eq!(expected_bandwidth, last_bandwidth);
1298
1299        test_sender.sampler.remove_obsolete_packets(41);
1300        // Since only congestion controlled packets are entered into the map, it
1301        // has to be empty at this point.
1302        assert_eq!(0, test_sender.number_of_tracked_packets());
1303        assert_eq!(0, test_sender.bytes_in_flight);
1304    }
1305
1306    /// Tests receiving ACK packets in the reverse order.
1307    #[test]
1308    fn reordered_ack() {
1309        let mut test_sender = TestSender::new();
1310        let time_between_packets = Duration::from_millis(1);
1311        let expected_bandwidth =
1312            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1313
1314        test_sender.send_40_and_ack_first_20(time_between_packets);
1315
1316        // Ack the packets 21 to 40 in the reverse order, while sending packets 41
1317        // to 60.
1318        for i in 0..20 {
1319            let last_bandwidth = test_sender.ack_packet(40 - i).bandwidth;
1320            assert_eq!(expected_bandwidth, last_bandwidth);
1321            test_sender.send_packet(41 + i, REGULAR_PACKET_SIZE, true);
1322            test_sender.advance_time(time_between_packets);
1323        }
1324
1325        // Ack the packets 41 to 60, now in the regular order.
1326        for i in 41..=60 {
1327            let last_bandwidth = test_sender.ack_packet(i).bandwidth;
1328            assert_eq!(expected_bandwidth, last_bandwidth);
1329            test_sender.advance_time(time_between_packets);
1330        }
1331
1332        test_sender.sampler.remove_obsolete_packets(61);
1333        assert_eq!(0, test_sender.number_of_tracked_packets());
1334        assert_eq!(0, test_sender.bytes_in_flight);
1335    }
1336
    /// Test the app-limited logic. Packets sent after `on_app_limited()` must
    /// yield app-limited samples that underestimate the bandwidth. Packets
    /// sent once traffic resumes must yield normal samples again.
    #[test]
    fn app_limited() {
        let mut test_sender = TestSender::new();
        let time_between_packets = Duration::from_millis(1);
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);

        // Send packets 1 to 20, one millisecond apart.
        for i in 1..=20 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Ack packets 1 to 20 while sending 21 to 40. Their app-limited flag
        // must match the sampler's state right after construction.
        for i in 1..=20 {
            let sample = test_sender.ack_packet(i);
            assert_eq!(
                sample.state_at_send.is_app_limited,
                test_sender.sampler_app_limited_at_start
            );
            test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // We are now app-limited. Ack 21 to 40 as usual, but do not send anything
        // for now. Packets 21 to 40 were sent before `on_app_limited()`, so
        // their samples are not app-limited.
        test_sender.sampler.on_app_limited();
        for i in 21..=40 {
            let sample = test_sender.ack_packet(i);
            assert!(!sample.state_at_send.is_app_limited);
            assert_eq!(expected_bandwidth, sample.bandwidth);
            test_sender.advance_time(time_between_packets);
        }

        // Enter quiescence.
        test_sender.advance_time(Duration::from_secs(1));

        // Send packets 41 to 60, all of which would be marked as app-limited.
        for i in 41..=60 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Ack packets 41 to 60, while sending packets 61 to 80.  41 to 60 should
        // be app-limited and underestimate the bandwidth due to that.
        for i in 41..=60 {
            let sample = test_sender.ack_packet(i);
            assert!(sample.state_at_send.is_app_limited);
            assert!(sample.bandwidth < expected_bandwidth * 0.7);
            test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Run out of packets, and then ack packet 61 to 80, all of which should
        // have correct non-app-limited samples.
        for i in 61..=80 {
            let sample = test_sender.ack_packet(i);
            assert!(!sample.state_at_send.is_app_limited);
            assert_eq!(sample.bandwidth, expected_bandwidth);
            test_sender.advance_time(time_between_packets);
        }

        test_sender.sampler.remove_obsolete_packets(81);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1402
1403    /// Test the samples taken at the first flight of packets sent.
1404    #[test]
1405    fn first_round_trip() {
1406        let mut test_sender = TestSender::new();
1407        let time_between_packets = Duration::from_millis(1);
1408        let rtt = Duration::from_millis(800);
1409        let num_packets = 10;
1410        let num_bytes = REGULAR_PACKET_SIZE * num_packets;
1411        let real_bandwidth = Bandwidth::from_bytes_and_time_delta(num_bytes, rtt);
1412
1413        for i in 1..=10 {
1414            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1415            test_sender.advance_time(time_between_packets);
1416        }
1417        test_sender.advance_time(rtt - time_between_packets * num_packets as _);
1418
1419        let mut last_sample = Bandwidth::zero();
1420        for i in 1..=10 {
1421            let sample = test_sender.ack_packet(i).bandwidth;
1422            assert!(sample > last_sample);
1423            last_sample = sample;
1424            test_sender.advance_time(time_between_packets);
1425        }
1426
1427        // The final measured sample for the first flight of sample is expected to
1428        // be smaller than the real bandwidth, yet it should not lose more
1429        // than 10%. The specific value of the error depends on the
1430        // difference between the RTT and the time it takes to exhaust the
1431        // congestion window (i.e. in the limit when all packets are sent
1432        // simultaneously, last sample would indicate the real bandwidth).
1433        assert!(last_sample < real_bandwidth);
1434        assert!(last_sample > real_bandwidth * 0.9);
1435    }
1436
1437    /// Test sampler's ability to remove obsolete packets.
1438    #[test]
1439    fn remove_obsolete_packets() {
1440        let mut test_sender = TestSender::new();
1441
1442        for i in 1..=5 {
1443            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1444        }
1445        test_sender.advance_time(Duration::from_millis(100));
1446        assert_eq!(5, test_sender.number_of_tracked_packets());
1447        test_sender.sampler.remove_obsolete_packets(4);
1448        assert_eq!(2, test_sender.number_of_tracked_packets());
1449        test_sender.lose_packet(4);
1450        test_sender.sampler.remove_obsolete_packets(5);
1451        assert_eq!(1, test_sender.number_of_tracked_packets());
1452        test_sender.ack_packet(5);
1453        test_sender.sampler.remove_obsolete_packets(6);
1454        assert_eq!(0, test_sender.number_of_tracked_packets());
1455    }
1456
1457    #[test]
1458    fn neuter_packet() {
1459        let mut test_sender = TestSender::new();
1460        test_sender.send_packet(1, REGULAR_PACKET_SIZE, true);
1461        assert_eq!(test_sender.sampler.total_bytes_neutered, 0);
1462        test_sender.advance_time(Duration::from_millis(10));
1463        test_sender.sampler.on_packet_neutered(1);
1464        assert!(0 < test_sender.sampler.total_bytes_neutered);
1465        assert_eq!(0, test_sender.sampler.total_bytes_acked);
1466
1467        // If packet 1 is acked it should not produce a bandwidth sample.
1468        let acked = Acked {
1469            pkt_num: 1,
1470            time_sent: test_sender.clock,
1471        };
1472        test_sender.advance_time(Duration::from_millis(10));
1473        let sample = test_sender.sampler.on_congestion_event(
1474            test_sender.clock,
1475            &[acked],
1476            &[],
1477            Some(test_sender.max_bandwidth),
1478            test_sender.est_bandwidth_upper_bound,
1479            test_sender.round_trip_count,
1480        );
1481
1482        assert_eq!(0, test_sender.sampler.total_bytes_acked);
1483        assert!(sample.sample_max_bandwidth.is_none());
1484        assert!(!sample.sample_is_app_limited);
1485        assert!(sample.sample_rtt.is_none());
1486        assert_eq!(sample.sample_max_inflight, 0);
1487        assert_eq!(sample.extra_acked, 0);
1488    }
1489
1490    /// Make sure a default constructed [`CongestionEventSample`] has the
1491    /// correct initial values for
1492    /// [`BandwidthSampler::on_congestion_event()`] to work.
1493    #[test]
1494    fn congestion_event_sample_default_values() {
1495        let sample = CongestionEventSample::default();
1496        assert!(sample.sample_max_bandwidth.is_none());
1497        assert!(!sample.sample_is_app_limited);
1498        assert!(sample.sample_rtt.is_none());
1499        assert_eq!(sample.sample_max_inflight, 0);
1500        assert_eq!(sample.extra_acked, 0);
1501    }
1502
1503    /// 1) Send 2 packets, 2) Ack both in 1 event, 3) Repeat.
1504    #[test]
1505    fn two_acked_packets_per_event() {
1506        let mut test_sender = TestSender::new();
1507        let time_between_packets = Duration::from_millis(10);
1508        let sending_rate = Bandwidth::from_bytes_and_time_delta(
1509            REGULAR_PACKET_SIZE,
1510            time_between_packets,
1511        );
1512
1513        for i in 1..21 {
1514            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1515            test_sender.advance_time(time_between_packets);
1516            if i % 2 != 0 {
1517                continue;
1518            }
1519
1520            let sample = test_sender.on_congestion_event(&[i - 1, i], &[]);
1521            assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap());
1522            assert_eq!(time_between_packets, sample.sample_rtt.unwrap());
1523            assert_eq!(2 * REGULAR_PACKET_SIZE, sample.sample_max_inflight);
1524            assert!(sample.last_packet_send_state.is_valid);
1525            assert_eq!(
1526                2 * REGULAR_PACKET_SIZE,
1527                sample.last_packet_send_state.bytes_in_flight
1528            );
1529            assert_eq!(
1530                i as usize * REGULAR_PACKET_SIZE,
1531                sample.last_packet_send_state.total_bytes_sent
1532            );
1533            assert_eq!(
1534                (i - 2) as usize * REGULAR_PACKET_SIZE,
1535                sample.last_packet_send_state.total_bytes_acked
1536            );
1537            assert_eq!(0, sample.last_packet_send_state.total_bytes_lost);
1538            test_sender.sampler.remove_obsolete_packets(i - 2);
1539        }
1540    }
1541
1542    #[test]
1543    fn lose_every_other_packet() {
1544        let mut test_sender = TestSender::new();
1545        let time_between_packets = Duration::from_millis(10);
1546        let sending_rate = Bandwidth::from_bytes_and_time_delta(
1547            REGULAR_PACKET_SIZE,
1548            time_between_packets,
1549        );
1550
1551        for i in 1..21 {
1552            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1553            test_sender.advance_time(time_between_packets);
1554            if i % 2 != 0 {
1555                continue;
1556            }
1557            // Ack packet i and lose i-1.
1558            let sample = test_sender.on_congestion_event(&[i], &[i - 1]);
1559            // Losing 50% packets means sending rate is twice the bandwidth.
1560
1561            assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap() * 2.);
1562            assert_eq!(time_between_packets, sample.sample_rtt.unwrap());
1563            assert_eq!(REGULAR_PACKET_SIZE, sample.sample_max_inflight);
1564            assert!(sample.last_packet_send_state.is_valid);
1565            assert_eq!(
1566                2 * REGULAR_PACKET_SIZE,
1567                sample.last_packet_send_state.bytes_in_flight
1568            );
1569            assert_eq!(
1570                i as usize * REGULAR_PACKET_SIZE,
1571                sample.last_packet_send_state.total_bytes_sent
1572            );
1573            assert_eq!(
1574                (i - 2) as usize * REGULAR_PACKET_SIZE / 2,
1575                sample.last_packet_send_state.total_bytes_acked
1576            );
1577            assert_eq!(
1578                (i - 2) as usize * REGULAR_PACKET_SIZE / 2,
1579                sample.last_packet_send_state.total_bytes_lost
1580            );
1581            test_sender.sampler.remove_obsolete_packets(i - 2);
1582        }
1583    }
1584
1585    #[test]
1586    fn ack_height_respect_bandwidth_estimate_upper_bound() {
1587        let mut test_sender = TestSender::new();
1588        let time_between_packets = Duration::from_millis(10);
1589        let first_packet_sending_rate = Bandwidth::from_bytes_and_time_delta(
1590            REGULAR_PACKET_SIZE,
1591            time_between_packets,
1592        );
1593
1594        // Send packets 1 to 4 and ack packet 1.
1595        test_sender.send_packet(1, REGULAR_PACKET_SIZE, true);
1596        test_sender.advance_time(time_between_packets);
1597        test_sender.send_packet(2, REGULAR_PACKET_SIZE, true);
1598        test_sender.send_packet(3, REGULAR_PACKET_SIZE, true);
1599        test_sender.send_packet(4, REGULAR_PACKET_SIZE, true);
1600        let sample = test_sender.on_congestion_event(&[1], &[]);
1601        assert_eq!(
1602            first_packet_sending_rate,
1603            sample.sample_max_bandwidth.unwrap()
1604        );
1605        assert_eq!(first_packet_sending_rate, test_sender.max_bandwidth);
1606
1607        // Ack packet 2, 3 and 4, all of which uses S(1) to calculate ack rate
1608        // since there were no acks at the time they were sent.
1609        test_sender.round_trip_count += 1;
1610        test_sender.est_bandwidth_upper_bound = first_packet_sending_rate * 0.3;
1611        test_sender.advance_time(time_between_packets);
1612
1613        let sample = test_sender.on_congestion_event(&[2, 3, 4], &[]);
1614
1615        assert_eq!(
1616            first_packet_sending_rate * 2.,
1617            sample.sample_max_bandwidth.unwrap()
1618        );
1619        assert_eq!(
1620            test_sender.max_bandwidth,
1621            sample.sample_max_bandwidth.unwrap()
1622        );
1623        assert!(2 * REGULAR_PACKET_SIZE < sample.extra_acked);
1624    }
1625}
1626
#[cfg(test)]
mod max_ack_height_tracker_tests {
    use super::*;

    /// Test harness that feeds synthetic ack events into a
    /// `MaxAckHeightTracker` using a simulated clock.
    struct TestTracker {
        tracker: MaxAckHeightTracker,
        /// Bandwidth handed to every `tracker.update()` call made by
        /// `aggregation_episode()`. It also sets the length of the quiet
        /// period that follows an episode.
        bandwidth: Bandwidth,
        /// Time origin for `round_trip_count()`.
        start: Instant,
        /// Current simulated time.
        now: Instant,
        last_sent_packet_number: u64,
        last_acked_packet_number: u64,
        /// Fixed round-trip time used by `round_trip_count()`.
        rtt: Duration,
    }

    impl TestTracker {
        /// Builds a harness whose tracker is created with
        /// `MaxAckHeightTracker::new(10, false)`. The tracker's ack-aggregation
        /// bandwidth threshold is 1.8, and
        /// `start_new_aggregation_epoch_after_full_round` is enabled. The
        /// bandwidth is 10,000 bytes/s and the RTT is 60ms. The simulated clock
        /// starts 1ms after `start`.
        fn new() -> Self {
            let mut tracker = MaxAckHeightTracker::new(10, false);
            tracker.ack_aggregation_bandwidth_threshold = 1.8;
            tracker.start_new_aggregation_epoch_after_full_round = true;
            let start = Instant::now();
            TestTracker {
                tracker,
                start,
                now: start + Duration::from_millis(1),
                bandwidth: Bandwidth::from_bytes_per_second(10 * 1000),
                last_sent_packet_number: 0,
                last_acked_packet_number: 0,
                rtt: Duration::from_millis(60),
            }
        }

        /// Runs a full aggregation episode: one or more aggregated acks,
        /// followed by a quiet period in which no ack happens.
        ///
        /// `aggregation_bandwidth * aggregation_duration` bytes are acked in
        /// chunks of `bytes_per_ack`, evenly spaced over
        /// `aggregation_duration`. The episode as a whole (acks plus quiet
        /// period) lasts as long as delivering those bytes at
        /// `self.bandwidth` would take.
        ///
        /// Expectations checked on every ack:
        /// - if `aggregation_bandwidth == self.bandwidth`, extra acked bytes
        ///   must be 0;
        /// - otherwise, if `expect_new_aggregation_epoch` is set, the first
        ///   ack must report 0 extra acked bytes;
        /// - every other ack must report strictly more extra acked bytes
        ///   than the previous ack (or than 0, for the first ack).
        ///
        /// After this function returns, `self.now` is the earliest point at
        /// which any ack event will cause `tracker.update()` to start a new
        /// aggregation.
        ///
        /// # Panics
        ///
        /// Panics (failing the test) if `bytes_per_ack` is 0 or the episode
        /// would contain no acks. It also panics if the bytes or the duration
        /// do not divide evenly into whole acks / whole microseconds, or if
        /// any expectation above is violated.
        fn aggregation_episode(
            &mut self, aggregation_bandwidth: Bandwidth,
            aggregation_duration: Duration, bytes_per_ack: usize,
            expect_new_aggregation_epoch: bool,
        ) {
            assert!(aggregation_bandwidth >= self.bandwidth);
            // Guard the divisions and `step_by` below with a clear message
            // instead of an opaque divide-by-zero panic.
            assert!(bytes_per_ack > 0, "bytes_per_ack must be non-zero");
            let start_time = self.now;

            let aggregation_bytes =
                (aggregation_bandwidth * aggregation_duration) as usize;

            let num_acks = aggregation_bytes / bytes_per_ack;
            assert!(num_acks > 0, "episode must contain at least one ack");
            assert_eq!(aggregation_bytes, num_acks * bytes_per_ack);

            let time_between_acks = Duration::from_micros(
                aggregation_duration.as_micros() as u64 / num_acks as u64,
            );
            assert_eq!(aggregation_duration, time_between_acks * num_acks as u32);

            // The total duration of aggregation time and quiet period.
            let total_duration = Duration::from_micros(
                (aggregation_bytes as u64 * 8 * 1_000_000) /
                    self.bandwidth.to_bits_per_second() as u64,
            );

            assert_eq!(aggregation_bytes as u64, self.bandwidth * total_duration);

            // Loop-invariant: acks arriving at exactly `self.bandwidth` are
            // not really aggregated.
            let not_aggregating = aggregation_bandwidth == self.bandwidth;
            let mut last_extra_acked = 0;

            for bytes in (0..aggregation_bytes).step_by(bytes_per_ack) {
                let extra_acked = self.tracker.update(
                    self.bandwidth,
                    true,
                    self.round_trip_count(),
                    self.last_sent_packet_number,
                    self.last_acked_packet_number,
                    self.now,
                    bytes_per_ack,
                );
                // `extra_acked` should be 0 if either
                // [1] we are at the beginning of an aggregation epoch
                //     (bytes == 0) and the current tracker implementation can
                //     identify it, or
                // [2] we are not really aggregating acks.
                let at_epoch_start = bytes == 0 && expect_new_aggregation_epoch;
                if at_epoch_start || not_aggregating {
                    assert_eq!(0, extra_acked);
                } else {
                    assert!(last_extra_acked < extra_acked);
                }
                self.now += time_between_acks;
                last_extra_acked = extra_acked;
            }

            // Advance past the quiet period.
            self.now = start_time + total_duration;
        }

        /// Number of whole `rtt` periods elapsed between `start` and `now`.
        fn round_trip_count(&self) -> usize {
            ((self.now - self.start).as_micros() / self.rtt.as_micros()) as usize
        }
    }

    /// Runs two aggregation episodes at `bandwidth_gain` times the harness
    /// bandwidth, each expected to start a new aggregation epoch. It then
    /// rewinds the clock by 1ms and runs a third episode. If the tracker's ack
    /// aggregation bandwidth threshold exceeds 1.1, the third episode is
    /// expected to start a new epoch (3 epochs in total). Otherwise it is not
    /// (2 epochs in total).
    fn test_inner(
        bandwidth_gain: f64, agg_duration: Duration, byte_per_ack: usize,
    ) {
        let mut test_tracker = TestTracker::new();

        let rnd = |tracker: &mut TestTracker, expect: bool| {
            tracker.aggregation_episode(
                tracker.bandwidth * bandwidth_gain,
                agg_duration,
                byte_per_ack,
                expect,
            );
        };

        rnd(&mut test_tracker, true);
        rnd(&mut test_tracker, true);

        // Start the third episode 1ms before the earliest point at which the
        // previous episode guarantees a new aggregation epoch.
        test_tracker.now = test_tracker
            .now
            .checked_sub(Duration::from_millis(1))
            .unwrap();

        if test_tracker.tracker.ack_aggregation_bandwidth_threshold > 1.1 {
            rnd(&mut test_tracker, true);
            assert_eq!(3, test_tracker.tracker.num_ack_aggregation_epochs);
        } else {
            rnd(&mut test_tracker, false);
            assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs);
        }
    }

    /// 20x the bandwidth for 6ms, acked in 1200-byte chunks.
    #[test]
    fn very_aggregated_large_acks() {
        test_inner(20.0, Duration::from_millis(6), 1200);
    }

    /// 20x the bandwidth for 6ms, acked in 300-byte chunks.
    #[test]
    fn very_aggregated_small_acks() {
        test_inner(20.0, Duration::from_millis(6), 300);
    }

    /// 2x the bandwidth for 50ms, acked in 1000-byte chunks.
    #[test]
    fn somewhat_aggregated_large_acks() {
        test_inner(2.0, Duration::from_millis(50), 1000);
    }

    /// 2x the bandwidth for 50ms, acked in 100-byte chunks.
    #[test]
    fn somewhat_aggregated_small_acks() {
        test_inner(2.0, Duration::from_millis(50), 100);
    }

    /// Acks delivered at exactly the harness bandwidth must never report extra
    /// acked bytes, and more than two aggregation epochs are expected over the
    /// 100ms episode.
    #[test]
    fn not_aggregated() {
        let mut test_tracker = TestTracker::new();
        test_tracker.aggregation_episode(
            test_tracker.bandwidth,
            Duration::from_millis(100),
            100,
            true,
        );
        assert!(test_tracker.tracker.num_ack_aggregation_epochs > 2);
    }

    /// An ack of a packet sent after the episode's last sent packet must start
    /// a new epoch, even when the expected bytes acked would otherwise let the
    /// current epoch continue.
    #[test]
    fn start_new_epoch_after_a_full_round() {
        let mut test_tracker = TestTracker::new();

        test_tracker.last_sent_packet_number = 10;

        test_tracker.aggregation_episode(
            test_tracker.bandwidth * 2.0,
            Duration::from_millis(50),
            100,
            true,
        );

        test_tracker.last_acked_packet_number = 11;

        // Update with a tiny bandwidth causes a very low expected bytes acked,
        // which in turn causes the current epoch to continue if the
        // `tracker` doesn't check the packet numbers.
        test_tracker.tracker.update(
            test_tracker.bandwidth * 0.1,
            true,
            test_tracker.round_trip_count(),
            test_tracker.last_sent_packet_number,
            test_tracker.last_acked_packet_number,
            test_tracker.now,
            100,
        );

        assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs);
    }
}