1use std::collections::VecDeque;
32use std::time::Duration;
33use std::time::Instant;
34
35use super::Acked;
36use crate::recovery::gcongestion::Bandwidth;
37use crate::recovery::gcongestion::Lost;
38
39use super::windowed_filter::WindowedFilter;
40
/// Per-packet state keyed by packet number.
///
/// Packet numbers must be inserted in strictly increasing order, so the deque
/// is always sorted by packet number and lookups can use binary search.
/// Taking an entry from the middle leaves a `None` tombstone behind so the
/// ordering stays intact; tombstones that reach the front are reclaimed.
#[derive(Debug)]
struct ConnectionStateMap<T> {
    packet_map: VecDeque<(u64, Option<T>)>,
}

impl<T> Default for ConnectionStateMap<T> {
    fn default() -> Self {
        ConnectionStateMap {
            packet_map: VecDeque::new(),
        }
    }
}

impl<T> ConnectionStateMap<T> {
    /// Stores `val` for `pkt_num`.
    ///
    /// # Panics
    ///
    /// Panics if `pkt_num` is not strictly greater than the last packet
    /// number inserted.
    fn insert(&mut self, pkt_num: u64, val: T) {
        if let Some((last_pkt, _)) = self.packet_map.back() {
            assert!(pkt_num > *last_pkt, "{} > {}", pkt_num, *last_pkt);
        }

        self.packet_map.push_back((pkt_num, Some(val)));
    }

    /// Removes and returns the value stored for `pkt_num`.
    ///
    /// Returns `None` if the packet was never inserted, was already taken,
    /// or has been discarded by `remove_obsolete`.
    fn take(&mut self, pkt_num: u64) -> Option<T> {
        let first_pkt = self.packet_map.front()?.0;

        let ret = if first_pkt == pkt_num {
            // Fast path: packets are commonly acked or lost in send order.
            self.packet_map.pop_front().and_then(|(_, v)| v)
        } else {
            match self.packet_map.binary_search_by_key(&pkt_num, |&(n, _)| n) {
                // Leave a tombstone in place so the deque stays sorted.
                Ok(found) =>
                    self.packet_map.get_mut(found).and_then(|(_, v)| v.take()),
                Err(_) => None,
            }
        };

        // Popping the front can expose tombstones left by earlier
        // out-of-order takes; reclaim them so they do not accumulate.
        self.pop_front_tombstones();

        ret
    }

    /// Returns a reference to the value stored for `pkt_num`, if any.
    #[cfg(test)]
    fn peek(&self, pkt_num: u64) -> Option<&T> {
        match self.packet_map.binary_search_by_key(&pkt_num, |&(n, _)| n) {
            Ok(found) => self.packet_map.get(found).and_then(|(_, v)| v.as_ref()),
            Err(_) => None,
        }
    }

    /// Discards every entry whose packet number is below `least_acked`,
    /// plus any tombstones that end up at the front afterwards.
    fn remove_obsolete(&mut self, least_acked: u64) {
        while matches!(self.packet_map.front(), Some(&(p, _)) if p < least_acked)
        {
            self.packet_map.pop_front();
        }

        self.pop_front_tombstones();
    }

    /// Pops `None` tombstones off the front of the deque.
    fn pop_front_tombstones(&mut self) {
        while let Some((_, None)) = self.packet_map.front() {
            self.packet_map.pop_front();
        }
    }
}
103
104#[derive(Debug)]
105pub struct BandwidthSampler {
106 total_bytes_sent: usize,
109 total_bytes_acked: usize,
110 total_bytes_lost: usize,
111 total_bytes_neutered: usize,
112 last_sent_packet: u64,
113 last_acked_packet: u64,
114 is_app_limited: bool,
115 last_acked_packet_ack_time: Instant,
116 total_bytes_sent_at_last_acked_packet: usize,
117 last_acked_packet_sent_time: Instant,
118 recent_ack_points: RecentAckPoints,
119 a0_candidates: VecDeque<AckPoint>,
120 connection_state_map: ConnectionStateMap<ConnectionStateOnSentPacket>,
121 max_ack_height_tracker: MaxAckHeightTracker,
122 end_of_app_limited_phase: Option<u64>,
125 overestimate_avoidance: bool,
126 choose_a0_point_fix: bool,
130 limit_max_ack_height_tracker_by_send_rate: bool,
131
132 total_bytes_acked_after_last_ack_event: usize,
133}
134
/// Snapshot of the sampler's counters taken when a packet was sent.
///
/// The `Default` value has `is_valid == false` and stands for "no tracked
/// state available".
#[derive(Debug, Default, Clone, Copy)]
pub struct SendTimeState {
    /// Whether this snapshot came from a tracked packet.
    pub is_valid: bool,
    /// Whether the sampler was app-limited when the packet was sent.
    pub is_app_limited: bool,
    /// Total bytes sent at send time, including the packet itself.
    pub total_bytes_sent: usize,
    /// Total bytes acknowledged at send time.
    pub total_bytes_acked: usize,
    /// Total bytes lost at send time.
    // The sampler only writes this field (tests read it), hence the allow.
    #[allow(dead_code)]
    pub total_bytes_lost: usize,
    /// Bytes in flight right after the packet was sent, including it.
    pub bytes_in_flight: usize,
}
160
/// One ack-aggregation observation stored in the max ack height filter.
///
/// The derived ordering compares fields in declaration order, so events are
/// ranked primarily by `extra_acked`; keep `extra_acked` first.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
struct ExtraAckedEvent {
    /// Bytes acknowledged beyond what the bandwidth estimate predicted.
    extra_acked: usize,
    /// Bytes acknowledged during the aggregation epoch.
    bytes_acked: usize,
    /// Time elapsed since the aggregation epoch started.
    time_delta: Duration,
    /// Round trip count the event is associated with.
    round: usize,
}
171
172struct BandwidthSample {
173 bandwidth: Bandwidth,
175 rtt: Duration,
178 send_rate: Option<Bandwidth>,
181 state_at_send: SendTimeState,
183}
184
/// An ack event: when it was processed and the cumulative bytes acked then.
#[derive(Debug, Clone, Copy)]
struct AckPoint {
    /// Time the ack was processed.
    ack_time: Instant,
    /// Cumulative bytes acknowledged at `ack_time`.
    total_bytes_acked: usize,
}
191
192#[derive(Debug, Default)]
195struct RecentAckPoints {
196 ack_points: [Option<AckPoint>; 2],
197}
198
199#[derive(Debug)]
204struct ConnectionStateOnSentPacket {
205 sent_time: Instant,
207 size: usize,
209 total_bytes_sent_at_last_acked_packet: usize,
212 last_acked_packet_sent_time: Instant,
215 last_acked_packet_ack_time: Instant,
218 send_time_state: SendTimeState,
221}
222
223#[derive(Debug)]
227struct MaxAckHeightTracker {
228 max_ack_height_filter: WindowedFilter<ExtraAckedEvent, usize, usize>,
231 aggregation_epoch_start_time: Option<Instant>,
234 aggregation_epoch_bytes: usize,
235 last_sent_packet_number_before_epoch: u64,
238 num_ack_aggregation_epochs: u64,
241 ack_aggregation_bandwidth_threshold: f64,
242 start_new_aggregation_epoch_after_full_round: bool,
243 reduce_extra_acked_on_bandwidth_increase: bool,
244}
245
246#[derive(Default)]
247pub(crate) struct CongestionEventSample {
248 pub sample_max_bandwidth: Option<Bandwidth>,
250 pub sample_is_app_limited: bool,
252 pub sample_rtt: Option<Duration>,
254 pub sample_max_inflight: usize,
258 pub last_packet_send_state: SendTimeState,
262 pub extra_acked: usize,
266}
267
268impl MaxAckHeightTracker {
269 pub(crate) fn new(window: usize, overestimate_avoidance: bool) -> Self {
270 MaxAckHeightTracker {
271 max_ack_height_filter: WindowedFilter::new(window),
272 aggregation_epoch_start_time: None,
273 aggregation_epoch_bytes: 0,
274 last_sent_packet_number_before_epoch: 0,
275 num_ack_aggregation_epochs: 0,
276 ack_aggregation_bandwidth_threshold: if overestimate_avoidance {
277 2.0
278 } else {
279 1.0
280 },
281 start_new_aggregation_epoch_after_full_round: true,
282 reduce_extra_acked_on_bandwidth_increase: true,
283 }
284 }
285
286 #[allow(dead_code)]
287 fn reset(&mut self, new_height: usize, new_time: usize) {
288 self.max_ack_height_filter.reset(
289 ExtraAckedEvent {
290 extra_acked: new_height,
291 bytes_acked: 0,
292 time_delta: Duration::ZERO,
293 round: new_time,
294 },
295 new_time,
296 );
297 }
298
299 #[allow(clippy::too_many_arguments)]
300 fn update(
301 &mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
302 round_trip_count: usize, last_sent_packet_number: u64,
303 last_acked_packet_number: u64, ack_time: Instant, bytes_acked: usize,
304 ) -> usize {
305 let mut force_new_epoch = false;
306
307 if self.reduce_extra_acked_on_bandwidth_increase && is_new_max_bandwidth {
308 let mut best =
310 self.max_ack_height_filter.get_best().unwrap_or_default();
311 let mut second_best = self
312 .max_ack_height_filter
313 .get_second_best()
314 .unwrap_or_default();
315 let mut third_best = self
316 .max_ack_height_filter
317 .get_third_best()
318 .unwrap_or_default();
319 self.max_ack_height_filter.clear();
320
321 let expected_bytes_acked =
323 bandwidth_estimate.to_bytes_per_period(best.time_delta) as usize;
324 if expected_bytes_acked < best.bytes_acked {
325 best.extra_acked = best.bytes_acked - expected_bytes_acked;
326 self.max_ack_height_filter.update(best, best.round);
327 }
328
329 let expected_bytes_acked = bandwidth_estimate
330 .to_bytes_per_period(second_best.time_delta)
331 as usize;
332 if expected_bytes_acked < second_best.bytes_acked {
333 second_best.extra_acked =
334 second_best.bytes_acked - expected_bytes_acked;
335 self.max_ack_height_filter
336 .update(second_best, second_best.round);
337 }
338
339 let expected_bytes_acked = bandwidth_estimate
340 .to_bytes_per_period(third_best.time_delta)
341 as usize;
342 if expected_bytes_acked < third_best.bytes_acked {
343 third_best.extra_acked =
344 third_best.bytes_acked - expected_bytes_acked;
345 self.max_ack_height_filter
346 .update(third_best, third_best.round);
347 }
348 }
349
350 if self.start_new_aggregation_epoch_after_full_round &&
353 last_acked_packet_number >
354 self.last_sent_packet_number_before_epoch
355 {
356 force_new_epoch = true;
357 }
358
359 let epoch_start_time = match self.aggregation_epoch_start_time {
360 Some(time) if !force_new_epoch => time,
361 _ => {
362 self.aggregation_epoch_bytes = bytes_acked;
363 self.aggregation_epoch_start_time = Some(ack_time);
364 self.last_sent_packet_number_before_epoch =
365 last_sent_packet_number;
366 self.num_ack_aggregation_epochs += 1;
367 return 0;
368 },
369 };
370
371 let aggregation_delta = ack_time.duration_since(epoch_start_time);
374 let expected_bytes_acked =
375 bandwidth_estimate.to_bytes_per_period(aggregation_delta) as usize;
376 if self.aggregation_epoch_bytes <=
379 (self.ack_aggregation_bandwidth_threshold *
380 expected_bytes_acked as f64) as usize
381 {
382 self.aggregation_epoch_bytes = bytes_acked;
384 self.aggregation_epoch_start_time = Some(ack_time);
385 self.last_sent_packet_number_before_epoch = last_sent_packet_number;
386 self.num_ack_aggregation_epochs += 1;
387 return 0;
388 }
389
390 self.aggregation_epoch_bytes += bytes_acked;
391
392 let extra_bytes_acked =
394 self.aggregation_epoch_bytes - expected_bytes_acked;
395
396 let new_event = ExtraAckedEvent {
397 extra_acked: extra_bytes_acked,
398 bytes_acked: self.aggregation_epoch_bytes,
399 time_delta: aggregation_delta,
400 round: 0,
401 };
402
403 self.max_ack_height_filter
404 .update(new_event, round_trip_count);
405 extra_bytes_acked
406 }
407}
408
409impl From<(Instant, usize, usize, &BandwidthSampler)>
410 for ConnectionStateOnSentPacket
411{
412 fn from(
413 (sent_time, size, bytes_in_flight, sampler): (
414 Instant,
415 usize,
416 usize,
417 &BandwidthSampler,
418 ),
419 ) -> Self {
420 ConnectionStateOnSentPacket {
421 sent_time,
422 size,
423 total_bytes_sent_at_last_acked_packet: sampler
424 .total_bytes_sent_at_last_acked_packet,
425 last_acked_packet_sent_time: sampler.last_acked_packet_sent_time,
426 last_acked_packet_ack_time: sampler.last_acked_packet_ack_time,
427 send_time_state: SendTimeState {
428 is_valid: true,
429 is_app_limited: sampler.is_app_limited,
430 total_bytes_sent: sampler.total_bytes_sent,
431 total_bytes_acked: sampler.total_bytes_acked,
432 total_bytes_lost: sampler.total_bytes_lost,
433 bytes_in_flight,
434 },
435 }
436 }
437}
438
439impl RecentAckPoints {
440 fn update(&mut self, ack_time: Instant, total_bytes_acked: usize) {
441 assert!(
442 total_bytes_acked >=
443 self.ack_points[1].map(|p| p.total_bytes_acked).unwrap_or(0)
444 );
445
446 self.ack_points[0] = self.ack_points[1];
447 self.ack_points[1] = Some(AckPoint {
448 ack_time,
449 total_bytes_acked,
450 });
451 }
452
453 fn clear(&mut self) {
454 self.ack_points = Default::default();
455 }
456
457 fn most_recent(&self) -> Option<AckPoint> {
458 self.ack_points[1]
459 }
460
461 fn less_recent_point(&self, choose_a0_point_fix: bool) -> Option<AckPoint> {
462 if choose_a0_point_fix {
463 self.ack_points[0]
464 .filter(|ack_point| ack_point.total_bytes_acked > 0)
465 .or(self.ack_points[1])
466 } else {
467 self.ack_points[0].or(self.ack_points[1])
468 }
469 }
470}
471
472impl BandwidthSampler {
473 pub(crate) fn new(
474 max_height_tracker_window_length: usize, overestimate_avoidance: bool,
475 choose_a0_point_fix: bool,
476 ) -> Self {
477 BandwidthSampler {
478 total_bytes_sent: 0,
479 total_bytes_acked: 0,
480 total_bytes_lost: 0,
481 total_bytes_neutered: 0,
482 total_bytes_sent_at_last_acked_packet: 0,
483 last_acked_packet_sent_time: Instant::now(),
484 last_acked_packet_ack_time: Instant::now(),
485 is_app_limited: true,
486 connection_state_map: ConnectionStateMap::default(),
487 max_ack_height_tracker: MaxAckHeightTracker::new(
488 max_height_tracker_window_length,
489 overestimate_avoidance,
490 ),
491 total_bytes_acked_after_last_ack_event: 0,
492 overestimate_avoidance,
493 choose_a0_point_fix,
494 limit_max_ack_height_tracker_by_send_rate: false,
495
496 last_sent_packet: 0,
497 last_acked_packet: 0,
498 recent_ack_points: RecentAckPoints::default(),
499 a0_candidates: VecDeque::new(),
500 end_of_app_limited_phase: None,
501 }
502 }
503
504 #[allow(dead_code)]
505 pub(crate) fn is_app_limited(&self) -> bool {
506 self.is_app_limited
507 }
508
509 pub(crate) fn on_packet_sent(
510 &mut self, sent_time: Instant, packet_number: u64, bytes: usize,
511 bytes_in_flight: usize, has_retransmittable_data: bool,
512 ) {
513 self.last_sent_packet = packet_number;
514
515 if !has_retransmittable_data {
516 return;
517 }
518
519 self.total_bytes_sent += bytes;
520
521 if bytes_in_flight == 0 {
529 self.last_acked_packet_ack_time = sent_time;
530 if self.overestimate_avoidance {
531 self.recent_ack_points.clear();
532 self.recent_ack_points
533 .update(sent_time, self.total_bytes_acked);
534 self.a0_candidates.clear();
535 self.a0_candidates
536 .push_back(self.recent_ack_points.most_recent().unwrap());
537 }
538
539 self.total_bytes_sent_at_last_acked_packet = self.total_bytes_sent;
540
541 self.last_acked_packet_sent_time = sent_time;
544 }
545
546 self.connection_state_map.insert(
547 packet_number,
548 (sent_time, bytes, bytes_in_flight + bytes, &*self).into(),
549 );
550 }
551
552 pub(crate) fn on_packet_neutered(&mut self, packet_number: u64) {
553 if let Some(pkt) = self.connection_state_map.take(packet_number) {
554 self.total_bytes_neutered += pkt.size;
555 }
556 }
557
558 pub(crate) fn on_congestion_event(
559 &mut self, ack_time: Instant, acked_packets: &[Acked],
560 lost_packets: &[Lost], mut max_bandwidth: Option<Bandwidth>,
561 est_bandwidth_upper_bound: Bandwidth, round_trip_count: usize,
562 ) -> CongestionEventSample {
563 let mut last_lost_packet_send_state = SendTimeState::default();
564 let mut last_acked_packet_send_state = SendTimeState::default();
565 let mut last_lost_packet_num = 0u64;
566 let mut last_acked_packet_num = 0u64;
567
568 for packet in lost_packets {
569 let send_state =
570 self.on_packet_lost(packet.packet_number, packet.bytes_lost);
571 if send_state.is_valid {
572 last_lost_packet_send_state = send_state;
573 last_lost_packet_num = packet.packet_number;
574 }
575 }
576
577 if acked_packets.is_empty() {
578 return CongestionEventSample {
580 last_packet_send_state: last_lost_packet_send_state,
581 ..Default::default()
582 };
583 }
584
585 let mut event_sample = CongestionEventSample::default();
586
587 let mut max_send_rate = None;
588 for packet in acked_packets {
589 let sample =
590 match self.on_packet_acknowledged(ack_time, packet.pkt_num) {
591 Some(sample) if sample.state_at_send.is_valid => sample,
592 _ => continue,
593 };
594
595 last_acked_packet_send_state = sample.state_at_send;
596 last_acked_packet_num = packet.pkt_num;
597
598 event_sample.sample_rtt = Some(
599 sample
600 .rtt
601 .min(*event_sample.sample_rtt.get_or_insert(sample.rtt)),
602 );
603
604 if Some(sample.bandwidth) > event_sample.sample_max_bandwidth {
605 event_sample.sample_max_bandwidth = Some(sample.bandwidth);
606 event_sample.sample_is_app_limited =
607 sample.state_at_send.is_app_limited;
608 }
609 max_send_rate = max_send_rate.max(sample.send_rate);
610
611 let inflight_sample = self.total_bytes_acked -
612 last_acked_packet_send_state.total_bytes_acked;
613 if inflight_sample > event_sample.sample_max_inflight {
614 event_sample.sample_max_inflight = inflight_sample;
615 }
616 }
617
618 if !last_lost_packet_send_state.is_valid {
619 event_sample.last_packet_send_state = last_acked_packet_send_state;
620 } else if !last_acked_packet_send_state.is_valid {
621 event_sample.last_packet_send_state = last_lost_packet_send_state;
622 } else {
623 event_sample.last_packet_send_state =
628 if last_acked_packet_num > last_lost_packet_num {
629 last_acked_packet_send_state
630 } else {
631 last_lost_packet_send_state
632 };
633 }
634
635 let is_new_max_bandwidth =
636 event_sample.sample_max_bandwidth > max_bandwidth;
637 max_bandwidth = event_sample.sample_max_bandwidth.max(max_bandwidth);
638
639 if self.limit_max_ack_height_tracker_by_send_rate {
640 max_bandwidth = max_bandwidth.max(max_send_rate);
641 }
642
643 let bandwidth_estimate = if let Some(max_bandwidth) = max_bandwidth {
644 max_bandwidth.min(est_bandwidth_upper_bound)
645 } else {
646 est_bandwidth_upper_bound
647 };
648
649 event_sample.extra_acked = self.on_ack_event_end(
650 bandwidth_estimate,
651 is_new_max_bandwidth,
652 round_trip_count,
653 );
654
655 event_sample
656 }
657
658 fn on_packet_lost(
659 &mut self, packet_number: u64, bytes_lost: usize,
660 ) -> SendTimeState {
661 let mut send_time_state = SendTimeState::default();
662
663 self.total_bytes_lost += bytes_lost;
664 if let Some(state) = self.connection_state_map.take(packet_number) {
665 send_time_state = state.send_time_state;
666 send_time_state.is_valid = true;
667 }
668
669 send_time_state
670 }
671
672 fn on_ack_event_end(
673 &mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
674 round_trip_count: usize,
675 ) -> usize {
676 let newly_acked_bytes =
677 self.total_bytes_acked - self.total_bytes_acked_after_last_ack_event;
678
679 if newly_acked_bytes == 0 {
680 return 0;
681 }
682
683 self.total_bytes_acked_after_last_ack_event = self.total_bytes_acked;
684 let extra_acked = self.max_ack_height_tracker.update(
685 bandwidth_estimate,
686 is_new_max_bandwidth,
687 round_trip_count,
688 self.last_sent_packet,
689 self.last_acked_packet,
690 self.last_acked_packet_ack_time,
691 newly_acked_bytes,
692 );
693 if self.overestimate_avoidance && extra_acked == 0 {
697 self.a0_candidates.push_back(
698 self.recent_ack_points
699 .less_recent_point(self.choose_a0_point_fix)
700 .unwrap(),
701 );
702 }
703
704 extra_acked
705 }
706
707 fn on_packet_acknowledged(
708 &mut self, ack_time: Instant, packet_number: u64,
709 ) -> Option<BandwidthSample> {
710 self.last_acked_packet = packet_number;
711 let sent_packet = self.connection_state_map.take(packet_number)?;
712
713 self.total_bytes_acked += sent_packet.size;
714 self.total_bytes_sent_at_last_acked_packet =
715 sent_packet.send_time_state.total_bytes_sent;
716 self.last_acked_packet_sent_time = sent_packet.sent_time;
717 self.last_acked_packet_ack_time = ack_time;
718 if self.overestimate_avoidance {
719 self.recent_ack_points
720 .update(ack_time, self.total_bytes_acked);
721 }
722
723 if self.is_app_limited {
724 if self.end_of_app_limited_phase.is_none() ||
730 Some(packet_number) > self.end_of_app_limited_phase
731 {
732 self.is_app_limited = false;
733 }
734 }
735
736 let send_rate = if sent_packet.sent_time >
739 sent_packet.last_acked_packet_sent_time
740 {
741 Some(Bandwidth::from_bytes_and_time_delta(
742 sent_packet.send_time_state.total_bytes_sent -
743 sent_packet.total_bytes_sent_at_last_acked_packet,
744 sent_packet.sent_time - sent_packet.last_acked_packet_sent_time,
745 ))
746 } else {
747 None
748 };
749
750 let a0 = if self.overestimate_avoidance {
751 Self::choose_a0_point(
752 &mut self.a0_candidates,
753 sent_packet.send_time_state.total_bytes_acked,
754 self.choose_a0_point_fix,
755 )
756 } else {
757 None
758 };
759
760 let a0 = a0.unwrap_or(AckPoint {
761 ack_time: sent_packet.last_acked_packet_ack_time,
762 total_bytes_acked: sent_packet.send_time_state.total_bytes_acked,
763 });
764
765 if ack_time <= a0.ack_time {
769 return None;
770 }
771
772 let ack_rate = Bandwidth::from_bytes_and_time_delta(
773 self.total_bytes_acked - a0.total_bytes_acked,
774 ack_time.duration_since(a0.ack_time),
775 );
776
777 let bandwidth = if let Some(send_rate) = send_rate {
778 send_rate.min(ack_rate)
779 } else {
780 ack_rate
781 };
782
783 let rtt = ack_time.duration_since(sent_packet.sent_time);
787
788 Some(BandwidthSample {
789 bandwidth,
790 rtt,
791 send_rate,
792 state_at_send: SendTimeState {
793 is_valid: true,
794 ..sent_packet.send_time_state
795 },
796 })
797 }
798
799 fn choose_a0_point(
800 a0_candidates: &mut VecDeque<AckPoint>, total_bytes_acked: usize,
801 choose_a0_point_fix: bool,
802 ) -> Option<AckPoint> {
803 if a0_candidates.is_empty() {
804 return None;
805 }
806
807 while let Some(candidate) = a0_candidates.get(1) {
808 if candidate.total_bytes_acked > total_bytes_acked {
809 if choose_a0_point_fix {
810 break;
811 } else {
812 return Some(*candidate);
813 }
814 }
815 a0_candidates.pop_front();
816 }
817
818 Some(a0_candidates[0])
819 }
820
821 pub(crate) fn total_bytes_acked(&self) -> usize {
822 self.total_bytes_acked
823 }
824
825 pub(crate) fn total_bytes_lost(&self) -> usize {
826 self.total_bytes_lost
827 }
828
829 #[allow(dead_code)]
830 pub(crate) fn reset_max_ack_height_tracker(
831 &mut self, new_height: usize, new_time: usize,
832 ) {
833 self.max_ack_height_tracker.reset(new_height, new_time);
834 }
835
836 pub(crate) fn max_ack_height(&self) -> Option<usize> {
837 self.max_ack_height_tracker
838 .max_ack_height_filter
839 .get_best()
840 .map(|b| b.extra_acked)
841 }
842
843 pub(crate) fn on_app_limited(&mut self) {
844 self.is_app_limited = true;
845 self.end_of_app_limited_phase = Some(self.last_sent_packet);
846 }
847
848 pub(crate) fn remove_obsolete_packets(&mut self, least_acked: u64) {
849 self.connection_state_map.remove_obsolete(least_acked);
857 }
858}
859
860#[cfg(test)]
861mod bandwidth_sampler_tests {
862 use rstest::rstest;
863
864 use super::*;
865
866 const REGULAR_PACKET_SIZE: usize = 1280;
867
868 struct TestSender {
869 sampler: BandwidthSampler,
870 sampler_app_limited_at_start: bool,
871 bytes_in_flight: usize,
872 clock: Instant,
873 max_bandwidth: Bandwidth,
874 est_bandwidth_upper_bound: Bandwidth,
875 round_trip_count: usize,
876 }
877
878 impl TestSender {
879 fn new(overestimate_avoidance: bool, choose_a0_point_fix: bool) -> Self {
880 let sampler = BandwidthSampler::new(
881 0,
882 overestimate_avoidance,
883 choose_a0_point_fix,
884 );
885 TestSender {
886 sampler_app_limited_at_start: sampler.is_app_limited(),
887 sampler,
888 bytes_in_flight: 0,
889 clock: Instant::now(),
890 max_bandwidth: Bandwidth::zero(),
891 est_bandwidth_upper_bound: Bandwidth::infinite(),
892 round_trip_count: 0,
893 }
894 }
895
896 fn get_packet_size(&self, pkt_num: u64) -> usize {
897 self.sampler
898 .connection_state_map
899 .peek(pkt_num)
900 .unwrap()
901 .size
902 }
903
904 fn get_packet_time(&self, pkt_num: u64) -> Instant {
905 self.sampler
906 .connection_state_map
907 .peek(pkt_num)
908 .unwrap()
909 .sent_time
910 }
911
912 fn number_of_tracked_packets(&self) -> usize {
913 self.sampler.connection_state_map.packet_map.len()
914 }
915
916 fn make_acked_packet(&self, pkt_num: u64) -> Acked {
917 let time_sent = self.get_packet_time(pkt_num);
918
919 Acked { pkt_num, time_sent }
920 }
921
922 fn make_lost_packet(&self, pkt_num: u64) -> Lost {
923 let size = self.get_packet_size(pkt_num);
924 Lost {
925 packet_number: pkt_num,
926 bytes_lost: size,
927 }
928 }
929
930 fn ack_packet(&mut self, pkt_num: u64) -> BandwidthSample {
931 let size = self.get_packet_size(pkt_num);
932 self.bytes_in_flight -= size;
933
934 let sample = self.sampler.on_congestion_event(
935 self.clock,
936 &[self.make_acked_packet(pkt_num)],
937 &[],
938 Some(self.max_bandwidth),
939 self.est_bandwidth_upper_bound,
940 self.round_trip_count,
941 );
942
943 let sample_max_bandwidth = sample.sample_max_bandwidth.unwrap();
944 self.max_bandwidth = self.max_bandwidth.max(sample_max_bandwidth);
945
946 let bandwidth_sample = BandwidthSample {
947 bandwidth: sample_max_bandwidth,
948 rtt: sample.sample_rtt.unwrap(),
949 send_rate: None,
950 state_at_send: sample.last_packet_send_state,
951 };
952 assert!(bandwidth_sample.state_at_send.is_valid);
953 bandwidth_sample
954 }
955
956 fn lose_packet(&mut self, pkt_num: u64) -> SendTimeState {
957 let size = self.get_packet_size(pkt_num);
958 self.bytes_in_flight -= size;
959
960 let sample = self.sampler.on_congestion_event(
961 self.clock,
962 &[],
963 &[self.make_lost_packet(pkt_num)],
964 Some(self.max_bandwidth),
965 self.est_bandwidth_upper_bound,
966 self.round_trip_count,
967 );
968
969 assert!(sample.last_packet_send_state.is_valid);
970 assert_eq!(sample.sample_max_bandwidth, None);
971 assert_eq!(sample.sample_rtt, None);
972 sample.last_packet_send_state
973 }
974
975 fn on_congestion_event(
976 &mut self, acked: &[u64], lost: &[u64],
977 ) -> CongestionEventSample {
978 let acked = acked
979 .iter()
980 .map(|pkt| {
981 let acked_size = self.get_packet_size(*pkt);
982 self.bytes_in_flight -= acked_size;
983
984 self.make_acked_packet(*pkt)
985 })
986 .collect::<Vec<_>>();
987
988 let lost = lost
989 .iter()
990 .map(|pkt| {
991 let lost = self.make_lost_packet(*pkt);
992 self.bytes_in_flight -= lost.bytes_lost;
993 lost
994 })
995 .collect::<Vec<_>>();
996
997 let sample = self.sampler.on_congestion_event(
998 self.clock,
999 &acked,
1000 &lost,
1001 Some(self.max_bandwidth),
1002 self.est_bandwidth_upper_bound,
1003 self.round_trip_count,
1004 );
1005
1006 self.max_bandwidth =
1007 self.max_bandwidth.max(sample.sample_max_bandwidth.unwrap());
1008
1009 sample
1010 }
1011
1012 fn send_packet(
1013 &mut self, pkt_num: u64, pkt_sz: usize,
1014 has_retransmittable_data: bool,
1015 ) {
1016 self.sampler.on_packet_sent(
1017 self.clock,
1018 pkt_num,
1019 pkt_sz,
1020 self.bytes_in_flight,
1021 has_retransmittable_data,
1022 );
1023 if has_retransmittable_data {
1024 self.bytes_in_flight += pkt_sz;
1025 }
1026 }
1027
1028 fn advance_time(&mut self, delta: Duration) {
1029 self.clock += delta;
1030 }
1031
1032 fn send_40_and_ack_first_20(&mut self, time_between_packets: Duration) {
1035 for i in 1..=20 {
1037 self.send_packet(i, REGULAR_PACKET_SIZE, true);
1038 self.advance_time(time_between_packets);
1039 }
1040
1041 for i in 1..=20 {
1044 self.ack_packet(i);
1045 self.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
1046 self.advance_time(time_between_packets);
1047 }
1048 }
1049 }
1050
1051 #[rstest]
1052 fn send_and_wait(
1053 #[values(false, true)] overestimate_avoidance: bool,
1054 #[values(false, true)] choose_a0_point_fix: bool,
1055 ) {
1056 let mut test_sender =
1057 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1058 let mut time_between_packets = Duration::from_millis(10);
1059 let mut expected_bandwidth =
1060 Bandwidth::from_bytes_per_second(REGULAR_PACKET_SIZE as u64 * 100);
1061
1062 for i in 1..20 {
1064 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1065 test_sender.advance_time(time_between_packets);
1066 let current_sample = test_sender.ack_packet(i);
1067 assert_eq!(expected_bandwidth, current_sample.bandwidth);
1068 }
1069
1070 for i in 20..25 {
1072 time_between_packets *= 2;
1073 expected_bandwidth = expected_bandwidth * 0.5;
1074
1075 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1076 test_sender.advance_time(time_between_packets);
1077 let current_sample = test_sender.ack_packet(i);
1078 assert_eq!(expected_bandwidth, current_sample.bandwidth);
1079 }
1080
1081 test_sender.sampler.remove_obsolete_packets(25);
1082 assert_eq!(0, test_sender.number_of_tracked_packets());
1083 assert_eq!(0, test_sender.bytes_in_flight);
1084 }
1085
1086 #[rstest]
1087 fn send_time_state(
1088 #[values(false, true)] overestimate_avoidance: bool,
1089 #[values(false, true)] choose_a0_point_fix: bool,
1090 ) {
1091 let mut test_sender =
1092 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1093 let time_between_packets = Duration::from_millis(10);
1094
1095 for i in 1..=5 {
1097 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1098 assert_eq!(
1099 test_sender.sampler.total_bytes_sent,
1100 REGULAR_PACKET_SIZE * i as usize
1101 );
1102 test_sender.advance_time(time_between_packets);
1103 }
1104
1105 let send_time_state = test_sender.ack_packet(1).state_at_send;
1107 assert_eq!(REGULAR_PACKET_SIZE, send_time_state.total_bytes_sent);
1108 assert_eq!(0, send_time_state.total_bytes_acked);
1109 assert_eq!(0, send_time_state.total_bytes_lost);
1110 assert_eq!(REGULAR_PACKET_SIZE, test_sender.sampler.total_bytes_acked);
1111
1112 let send_time_state = test_sender.lose_packet(2);
1114 assert_eq!(REGULAR_PACKET_SIZE * 2, send_time_state.total_bytes_sent);
1115 assert_eq!(0, send_time_state.total_bytes_acked);
1116 assert_eq!(0, send_time_state.total_bytes_lost);
1117 assert_eq!(REGULAR_PACKET_SIZE, test_sender.sampler.total_bytes_lost);
1118
1119 let send_time_state = test_sender.lose_packet(3);
1121 assert_eq!(REGULAR_PACKET_SIZE * 3, send_time_state.total_bytes_sent);
1122 assert_eq!(0, send_time_state.total_bytes_acked);
1123 assert_eq!(0, send_time_state.total_bytes_lost);
1124 assert_eq!(
1125 REGULAR_PACKET_SIZE * 2,
1126 test_sender.sampler.total_bytes_lost
1127 );
1128
1129 for i in 6..=10 {
1131 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1132 assert_eq!(
1133 test_sender.sampler.total_bytes_sent,
1134 REGULAR_PACKET_SIZE * i as usize
1135 );
1136 test_sender.advance_time(time_between_packets);
1137 }
1138
1139 let mut acked_packet_count = 1;
1141 assert_eq!(
1142 REGULAR_PACKET_SIZE * acked_packet_count,
1143 test_sender.sampler.total_bytes_acked
1144 );
1145 for i in 4..=10 {
1146 let send_time_state = test_sender.ack_packet(i).state_at_send;
1147 acked_packet_count += 1;
1148 assert_eq!(
1149 REGULAR_PACKET_SIZE * acked_packet_count,
1150 test_sender.sampler.total_bytes_acked
1151 );
1152 assert_eq!(
1153 REGULAR_PACKET_SIZE * i as usize,
1154 send_time_state.total_bytes_sent
1155 );
1156
1157 if i <= 5 {
1158 assert_eq!(0, send_time_state.total_bytes_acked);
1159 assert_eq!(0, send_time_state.total_bytes_lost);
1160 } else {
1161 assert_eq!(
1162 REGULAR_PACKET_SIZE,
1163 send_time_state.total_bytes_acked
1164 );
1165 assert_eq!(
1166 REGULAR_PACKET_SIZE * 2,
1167 send_time_state.total_bytes_lost
1168 );
1169 }
1170
1171 assert_eq!(
1173 send_time_state.total_bytes_sent -
1174 send_time_state.total_bytes_acked -
1175 send_time_state.total_bytes_lost,
1176 send_time_state.bytes_in_flight
1177 );
1178
1179 test_sender.advance_time(time_between_packets);
1180 }
1181 }
1182
1183 #[rstest]
1186 fn send_paced(
1187 #[values(false, true)] overestimate_avoidance: bool,
1188 #[values(false, true)] choose_a0_point_fix: bool,
1189 ) {
1190 let mut test_sender =
1191 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1192 let time_between_packets = Duration::from_millis(1);
1193 let expected_bandwidth =
1194 Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1195
1196 test_sender.send_40_and_ack_first_20(time_between_packets);
1197
1198 for i in 21..=40 {
1200 let last_bandwidth = test_sender.ack_packet(i).bandwidth;
1201 assert_eq!(expected_bandwidth, last_bandwidth);
1202 test_sender.advance_time(time_between_packets);
1203 }
1204 test_sender.sampler.remove_obsolete_packets(41);
1205 assert_eq!(0, test_sender.number_of_tracked_packets());
1206 assert_eq!(0, test_sender.bytes_in_flight);
1207 }
1208
1209 #[rstest]
1212 fn send_with_losses(
1213 #[values(false, true)] overestimate_avoidance: bool,
1214 #[values(false, true)] choose_a0_point_fix: bool,
1215 ) {
1216 let mut test_sender =
1217 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1218 let time_between_packets = Duration::from_millis(1);
1219 let expected_bandwidth =
1220 Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);
1221
1222 for i in 1..=20 {
1224 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1225 test_sender.advance_time(time_between_packets);
1226 }
1227
1228 for i in 1..=20 {
1231 if i % 2 == 0 {
1232 test_sender.ack_packet(i);
1233 } else {
1234 test_sender.lose_packet(i);
1235 }
1236 test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
1237 test_sender.advance_time(time_between_packets);
1238 }
1239
1240 for i in 21..=40 {
1242 if i % 2 == 0 {
1243 let last_bandwidth = test_sender.ack_packet(i).bandwidth;
1244 assert_eq!(expected_bandwidth, last_bandwidth);
1245 } else {
1246 test_sender.lose_packet(i);
1247 }
1248 test_sender.advance_time(time_between_packets);
1249 }
1250 test_sender.sampler.remove_obsolete_packets(41);
1251 assert_eq!(0, test_sender.number_of_tracked_packets());
1252 assert_eq!(0, test_sender.bytes_in_flight);
1253 }
1254
    // Packets without retransmittable data are not tracked by the sampler.
    // Samples taken from the tracked (even-numbered) packets must still
    // reflect the rate at which tracked bytes are delivered.
    #[rstest]
    fn not_congestion_controlled(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(1);
        // Only every other packet is acked, so the expected rate is half a
        // regular packet per millisecond.
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);

        // Send 20 packets, 1 ms apart; only even-numbered packets carry
        // retransmittable data.
        for i in 1..=20 {
            let has_retransmittable_data = i % 2 == 0;
            test_sender.send_packet(
                i,
                REGULAR_PACKET_SIZE,
                has_retransmittable_data,
            );
            test_sender.advance_time(time_between_packets);
        }

        // Only the 10 packets with retransmittable data are tracked.
        assert_eq!(10, test_sender.number_of_tracked_packets());

        // Ack the even packets among 1..=20 while sending 21..=40 with the
        // same retransmittable/non-retransmittable pattern.
        for i in 1..=20 {
            if i % 2 == 0 {
                test_sender.ack_packet(i);
            }
            let has_retransmittable_data = i % 2 == 0;
            test_sender.send_packet(
                i + 20,
                REGULAR_PACKET_SIZE,
                has_retransmittable_data,
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ack the even packets among 21..=40; every sample must match the
        // expected (half) rate.
        for i in 21..=40 {
            if i % 2 == 0 {
                let last_bandwidth = test_sender.ack_packet(i).bandwidth;
                assert_eq!(expected_bandwidth, last_bandwidth);
            }
            test_sender.advance_time(time_between_packets);
        }

        // Only tracked packets ever entered the map, and all of them have been
        // acked, so nothing may remain.
        test_sender.sampler.remove_obsolete_packets(41);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1315
1316 #[rstest]
1319 fn compressed_ack(
1320 #[values(false, true)] overestimate_avoidance: bool,
1321 #[values(false, true)] choose_a0_point_fix: bool,
1322 ) {
1323 let mut test_sender =
1324 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1325 let time_between_packets = Duration::from_millis(1);
1326 let expected_bandwidth =
1327 Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1328
1329 test_sender.send_40_and_ack_first_20(time_between_packets);
1330
1331 test_sender.advance_time(time_between_packets * 15);
1333
1334 let ridiculously_small_time_delta = Duration::from_micros(20);
1336 let mut last_bandwidth = Bandwidth::zero();
1337 for i in 21..=40 {
1338 last_bandwidth = test_sender.ack_packet(i).bandwidth;
1339 test_sender.advance_time(ridiculously_small_time_delta);
1340 }
1341 assert_eq!(expected_bandwidth, last_bandwidth);
1342
1343 test_sender.sampler.remove_obsolete_packets(41);
1344 assert_eq!(0, test_sender.number_of_tracked_packets());
1347 assert_eq!(0, test_sender.bytes_in_flight);
1348 }
1349
1350 #[rstest]
1352 fn reordered_ack(
1353 #[values(false, true)] overestimate_avoidance: bool,
1354 #[values(false, true)] choose_a0_point_fix: bool,
1355 ) {
1356 let mut test_sender =
1357 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1358 let time_between_packets = Duration::from_millis(1);
1359 let expected_bandwidth =
1360 Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1361
1362 test_sender.send_40_and_ack_first_20(time_between_packets);
1363
1364 for i in 0..20 {
1367 let last_bandwidth = test_sender.ack_packet(40 - i).bandwidth;
1368 assert_eq!(expected_bandwidth, last_bandwidth);
1369 test_sender.send_packet(41 + i, REGULAR_PACKET_SIZE, true);
1370 test_sender.advance_time(time_between_packets);
1371 }
1372
1373 for i in 41..=60 {
1375 let last_bandwidth = test_sender.ack_packet(i).bandwidth;
1376 assert_eq!(expected_bandwidth, last_bandwidth);
1377 test_sender.advance_time(time_between_packets);
1378 }
1379
1380 test_sender.sampler.remove_obsolete_packets(61);
1381 assert_eq!(0, test_sender.number_of_tracked_packets());
1382 assert_eq!(0, test_sender.bytes_in_flight);
1383 }
1384
    // Checks the app-limited flag recorded in each packet's send state and
    // the bandwidth reported before, during and after an app-limited phase.
    #[rstest]
    fn app_limited(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(1);
        // One regular packet per millisecond.
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);

        // Send packets 1..=20, 1 ms apart.
        for i in 1..=20 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Ack 1..=20 while sending 21..=40. Their app-limited flag must match
        // whatever the sampler started with.
        for i in 1..=20 {
            let sample = test_sender.ack_packet(i);
            assert_eq!(
                sample.state_at_send.is_app_limited,
                test_sender.sampler_app_limited_at_start,
                "{i}"
            );
            test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Enter the app-limited state. Packets 21..=40 were sent before this
        // call, so they are not app-limited and report the full rate.
        test_sender.sampler.on_app_limited();
        for i in 21..=40 {
            let sample = test_sender.ack_packet(i);
            assert!(!sample.state_at_send.is_app_limited, "{i}");
            assert_eq!(expected_bandwidth, sample.bandwidth, "{i}");
            test_sender.advance_time(time_between_packets);
        }

        // Go idle for a second.
        test_sender.advance_time(Duration::from_secs(1));

        // Send 41..=60 while app-limited.
        for i in 41..=60 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Ack 41..=60 while sending 61..=80. All of them are flagged
        // app-limited. Their samples are expected to fall well below the
        // sending rate, except that with overestimate avoidance on and the A0
        // fix off, packets from 43 onward already report the full rate.
        for i in 41..=60 {
            let sample = test_sender.ack_packet(i);
            assert!(sample.state_at_send.is_app_limited, "{i}");
            if !overestimate_avoidance || choose_a0_point_fix || i < 43 {
                assert!(
                    sample.bandwidth < expected_bandwidth * 0.7,
                    "{} {:?} vs {:?}",
                    i,
                    sample.bandwidth,
                    expected_bandwidth * 0.7
                );
            } else {
                assert_eq!(sample.bandwidth, expected_bandwidth, "{i}");
            }
            test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Packets 61..=80 are no longer app-limited and report the full rate.
        for i in 61..=80 {
            let sample = test_sender.ack_packet(i);
            assert!(!sample.state_at_send.is_app_limited, "{i}");
            assert_eq!(sample.bandwidth, expected_bandwidth, "{i}");
            test_sender.advance_time(time_between_packets);
        }

        test_sender.sampler.remove_obsolete_packets(81);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1467
    // During the first round trip the sampler can only underestimate.
    // Successive samples must grow strictly, and the last one must land just
    // below (within 10% of) the real bandwidth.
    #[rstest]
    fn first_round_trip(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(1);
        let rtt = Duration::from_millis(800);
        let num_packets = 10;
        let num_bytes = REGULAR_PACKET_SIZE * num_packets;
        // Delivering the whole flight once per RTT.
        let real_bandwidth = Bandwidth::from_bytes_and_time_delta(num_bytes, rtt);

        // Send 10 packets, 1 ms apart.
        for i in 1..=10 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }
        // Wait out the rest of the RTT, so packet 1 is acked exactly one RTT
        // after it was sent.
        test_sender.advance_time(rtt - time_between_packets * num_packets as _);

        // Ack the packets 1 ms apart; each sample must exceed the previous.
        let mut last_sample = Bandwidth::zero();
        for i in 1..=10 {
            let sample = test_sender.ack_packet(i).bandwidth;
            assert!(sample > last_sample);
            last_sample = sample;
            test_sender.advance_time(time_between_packets);
        }

        // The final estimate approaches the real bandwidth from below.
        assert!(last_sample < real_bandwidth);
        assert!(last_sample > real_bandwidth * 0.9);
    }
1505
    // `remove_obsolete_packets(n)` drops tracked packets numbered below `n`.
    // Lost and acked packets also stop being tracked.
    #[rstest]
    fn remove_obsolete_packets(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);

        // Send five packets at the same instant.
        for i in 1..=5 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
        }
        test_sender.advance_time(Duration::from_millis(100));
        assert_eq!(5, test_sender.number_of_tracked_packets());
        // Dropping everything below 4 leaves packets 4 and 5.
        test_sender.sampler.remove_obsolete_packets(4);
        assert_eq!(2, test_sender.number_of_tracked_packets());
        // Lose packet 4, then drop everything below 5: only 5 remains.
        test_sender.lose_packet(4);
        test_sender.sampler.remove_obsolete_packets(5);
        assert_eq!(1, test_sender.number_of_tracked_packets());
        // Ack packet 5 and drop everything below 6: nothing remains.
        test_sender.ack_packet(5);
        test_sender.sampler.remove_obsolete_packets(6);
        assert_eq!(0, test_sender.number_of_tracked_packets());
    }
1529
    // A neutered packet is counted in `total_bytes_neutered`. Acking it later
    // must neither count its bytes as acked nor produce any sample.
    #[rstest]
    fn neuter_packet(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        test_sender.send_packet(1, REGULAR_PACKET_SIZE, true);
        assert_eq!(test_sender.sampler.total_bytes_neutered, 0);
        test_sender.advance_time(Duration::from_millis(10));
        test_sender.sampler.on_packet_neutered(1);
        assert!(0 < test_sender.sampler.total_bytes_neutered);
        assert_eq!(0, test_sender.sampler.total_bytes_acked);

        // Ack the neutered packet in a congestion event 10 ms later.
        let acked = Acked {
            pkt_num: 1,
            time_sent: test_sender.clock,
        };
        test_sender.advance_time(Duration::from_millis(10));
        let sample = test_sender.sampler.on_congestion_event(
            test_sender.clock,
            &[acked],
            &[],
            Some(test_sender.max_bandwidth),
            test_sender.est_bandwidth_upper_bound,
            test_sender.round_trip_count,
        );

        // Nothing is credited and the sample is empty.
        assert_eq!(0, test_sender.sampler.total_bytes_acked);
        assert!(sample.sample_max_bandwidth.is_none());
        assert!(!sample.sample_is_app_limited);
        assert!(sample.sample_rtt.is_none());
        assert_eq!(sample.sample_max_inflight, 0);
        assert_eq!(sample.extra_acked, 0);
    }
1566
1567 #[rstest]
1571 fn congestion_event_sample_default_values() {
1572 let sample = CongestionEventSample::default();
1573 assert!(sample.sample_max_bandwidth.is_none());
1574 assert!(!sample.sample_is_app_limited);
1575 assert!(sample.sample_rtt.is_none());
1576 assert_eq!(sample.sample_max_inflight, 0);
1577 assert_eq!(sample.extra_acked, 0);
1578 }
1579
1580 #[rstest]
1582 fn two_acked_packets_per_event(
1583 #[values(false, true)] overestimate_avoidance: bool,
1584 #[values(false, true)] choose_a0_point_fix: bool,
1585 ) {
1586 let mut test_sender =
1587 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1588 let time_between_packets = Duration::from_millis(10);
1589 let sending_rate = Bandwidth::from_bytes_and_time_delta(
1590 REGULAR_PACKET_SIZE,
1591 time_between_packets,
1592 );
1593
1594 for i in 1..21 {
1595 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1596 test_sender.advance_time(time_between_packets);
1597 if i % 2 != 0 {
1598 continue;
1599 }
1600
1601 let sample = test_sender.on_congestion_event(&[i - 1, i], &[]);
1602 assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap());
1603 assert_eq!(time_between_packets, sample.sample_rtt.unwrap());
1604 assert_eq!(2 * REGULAR_PACKET_SIZE, sample.sample_max_inflight);
1605 assert!(sample.last_packet_send_state.is_valid);
1606 assert_eq!(
1607 2 * REGULAR_PACKET_SIZE,
1608 sample.last_packet_send_state.bytes_in_flight
1609 );
1610 assert_eq!(
1611 i as usize * REGULAR_PACKET_SIZE,
1612 sample.last_packet_send_state.total_bytes_sent
1613 );
1614 assert_eq!(
1615 (i - 2) as usize * REGULAR_PACKET_SIZE,
1616 sample.last_packet_send_state.total_bytes_acked
1617 );
1618 assert_eq!(0, sample.last_packet_send_state.total_bytes_lost);
1619 test_sender.sampler.remove_obsolete_packets(i - 2);
1620 }
1621 }
1622
    // Every congestion event acks the even packet and declares the preceding
    // odd packet lost. The reported bandwidth must be half the sending rate,
    // and the send state must split past traffic evenly between acked and
    // lost bytes.
    #[rstest]
    fn lose_every_other_packet(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(10);
        let sending_rate = Bandwidth::from_bytes_and_time_delta(
            REGULAR_PACKET_SIZE,
            time_between_packets,
        );

        for i in 1..21 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
            // Only raise an event after each even-numbered packet.
            if i % 2 != 0 {
                continue;
            }
            // Ack `i`, lose `i - 1`.
            let sample = test_sender.on_congestion_event(&[i], &[i - 1]);
            // Only half of the sent bytes are delivered.
            assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap() * 2.);
            assert_eq!(time_between_packets, sample.sample_rtt.unwrap());
            assert_eq!(REGULAR_PACKET_SIZE, sample.sample_max_inflight);
            assert!(sample.last_packet_send_state.is_valid);
            assert_eq!(
                2 * REGULAR_PACKET_SIZE,
                sample.last_packet_send_state.bytes_in_flight
            );
            assert_eq!(
                i as usize * REGULAR_PACKET_SIZE,
                sample.last_packet_send_state.total_bytes_sent
            );
            // Of the first `i - 2` packets, half were acked and half lost.
            assert_eq!(
                (i - 2) as usize * REGULAR_PACKET_SIZE / 2,
                sample.last_packet_send_state.total_bytes_acked
            );
            assert_eq!(
                (i - 2) as usize * REGULAR_PACKET_SIZE / 2,
                sample.last_packet_send_state.total_bytes_lost
            );
            test_sender.sampler.remove_obsolete_packets(i - 2);
        }
    }
1669
    // Checks the sample and `extra_acked` produced when the bandwidth estimate
    // upper bound is set far below the measured rate.
    #[rstest]
    fn ack_height_respect_bandwidth_estimate_upper_bound(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(10);
        let first_packet_sending_rate = Bandwidth::from_bytes_and_time_delta(
            REGULAR_PACKET_SIZE,
            time_between_packets,
        );

        // Send packet 1, wait one interval, then send 2, 3 and 4 together
        // and ack packet 1.
        test_sender.send_packet(1, REGULAR_PACKET_SIZE, true);
        test_sender.advance_time(time_between_packets);
        test_sender.send_packet(2, REGULAR_PACKET_SIZE, true);
        test_sender.send_packet(3, REGULAR_PACKET_SIZE, true);
        test_sender.send_packet(4, REGULAR_PACKET_SIZE, true);
        let sample = test_sender.on_congestion_event(&[1], &[]);
        assert_eq!(
            first_packet_sending_rate,
            sample.sample_max_bandwidth.unwrap()
        );
        assert_eq!(first_packet_sending_rate, test_sender.max_bandwidth);

        // Start a new round and cap the upper bound at 30% of the first
        // packet's rate.
        test_sender.round_trip_count += 1;
        test_sender.est_bandwidth_upper_bound = first_packet_sending_rate * 0.3;
        test_sender.advance_time(time_between_packets);

        // Ack the remaining three packets in one event.
        let sample = test_sender.on_congestion_event(&[2, 3, 4], &[]);

        assert_eq!(
            first_packet_sending_rate * 2.,
            sample.sample_max_bandwidth.unwrap()
        );
        assert_eq!(
            test_sender.max_bandwidth,
            sample.sample_max_bandwidth.unwrap()
        );
        // With the low upper bound, the ack height exceeds two packets.
        assert!(2 * REGULAR_PACKET_SIZE < sample.extra_acked);
    }
1714}
1715
#[cfg(test)]
mod max_ack_height_tracker_tests {
    use rstest::rstest;

    use super::*;

    // Harness around `MaxAckHeightTracker`. It simulates a sender whose
    // steady-state rate is `bandwidth` and feeds the tracker bursts of acks.
    struct TestTracker {
        tracker: MaxAckHeightTracker,
        // Rate passed to the tracker on every update. It also determines how
        // long each episode (burst plus quiet period) lasts.
        bandwidth: Bandwidth,
        // Simulated time origin, used to derive the round-trip count.
        start: Instant,
        // Current simulated time.
        now: Instant,
        last_sent_packet_number: u64,
        last_acked_packet_number: u64,
        // Round-trip duration used by `round_trip_count`.
        rtt: Duration,
    }

    impl TestTracker {
        // Creates the harness with these settings:
        // - `MaxAckHeightTracker::new(10, ..)`; the 10 is presumably the max
        //   filter window, to be confirmed against `MaxAckHeightTracker::new`.
        // - an aggregation bandwidth threshold of 1.8;
        // - a new epoch started after a full round.
        // The simulated clock starts 1 ms after `start`.
        fn new(overestimate_avoidance: bool) -> Self {
            let mut tracker =
                MaxAckHeightTracker::new(10, overestimate_avoidance);
            tracker.ack_aggregation_bandwidth_threshold = 1.8;
            tracker.start_new_aggregation_epoch_after_full_round = true;
            let start = Instant::now();
            TestTracker {
                tracker,
                start,
                now: start + Duration::from_millis(1),
                bandwidth: Bandwidth::from_bytes_per_second(10 * 1000),
                last_sent_packet_number: 0,
                last_acked_packet_number: 0,
                rtt: Duration::from_millis(60),
            }
        }

        // Feeds one aggregation episode to the tracker.
        //
        // Acks totalling `aggregation_bandwidth * aggregation_duration` bytes
        // arrive in evenly spaced chunks of `bytes_per_ack` over
        // `aggregation_duration`. Afterwards the clock jumps to the instant at
        // which `self.bandwidth` would have delivered those bytes, so the
        // burst is followed by a quiet period.
        //
        // Panics if the parameters do not divide evenly (the episode would
        // not be exactly representable). It also panics if `extra_acked`
        // misbehaves. `extra_acked` must be zero on the first ack when
        // `expect_new_aggregation_epoch` is set, or when acks are not
        // aggregated. Otherwise it must strictly increase.
        fn aggregation_episode(
            &mut self, aggregation_bandwidth: Bandwidth,
            aggregation_duration: Duration, bytes_per_ack: usize,
            expect_new_aggregation_epoch: bool,
        ) {
            assert!(
                aggregation_bandwidth >= self.bandwidth,
                "aggregation bandwidth {aggregation_bandwidth:?} is below the \
                 steady-state bandwidth {:?}",
                self.bandwidth
            );
            assert!(bytes_per_ack > 0, "bytes_per_ack must be non-zero");
            let start_time = self.now;

            let aggregation_bytes =
                (aggregation_bandwidth * aggregation_duration) as usize;

            let num_acks = aggregation_bytes / bytes_per_ack;
            // Guards the divisions below against a zero ack count.
            assert!(
                num_acks > 0,
                "aggregation of {aggregation_bytes} bytes is smaller than one \
                 ack of {bytes_per_ack} bytes"
            );
            assert_eq!(
                aggregation_bytes,
                num_acks * bytes_per_ack,
                "aggregation_bytes: {aggregation_bytes} \
                 [{aggregation_bandwidth:?} in {aggregation_duration:?}], \
                 bytes_per_ack: {bytes_per_ack}"
            );

            let time_between_acks = Duration::from_micros(
                aggregation_duration.as_micros() as u64 / num_acks as u64,
            );
            let num_acks_u32 =
                u32::try_from(num_acks).expect("ack count fits in u32");
            assert_eq!(
                aggregation_duration,
                time_between_acks * num_acks_u32,
                "aggregation_duration: {aggregation_duration:?}, \
                 num_acks: {num_acks}, \
                 time_between_acks: {time_between_acks:?}"
            );

            // Length of the burst plus the quiet period after it. This is the
            // time `self.bandwidth` needs to deliver `aggregation_bytes`.
            let total_duration = Duration::from_micros(
                (aggregation_bytes as u64 * 8 * 1000000) /
                    self.bandwidth.to_bits_per_second(),
            );
            assert_eq!(
                aggregation_bytes as u64,
                self.bandwidth * total_duration,
                "total_duration: {total_duration:?}, bandwidth: {:?}",
                self.bandwidth
            );

            let mut last_extra_acked = 0;

            for bytes in (0..aggregation_bytes).step_by(bytes_per_ack) {
                let extra_acked = self.tracker.update(
                    self.bandwidth,
                    true,
                    self.round_trip_count(),
                    self.last_sent_packet_number,
                    self.last_acked_packet_number,
                    self.now,
                    bytes_per_ack,
                );
                // `extra_acked` must be zero when either
                // [1] this is the first ack of the episode and the tracker is
                //     expected to open a new aggregation epoch, or
                // [2] acks are not actually aggregated (they arrive at the
                //     steady-state bandwidth).
                if (bytes == 0 && expect_new_aggregation_epoch) ||
                    (aggregation_bandwidth == self.bandwidth)
                {
                    assert_eq!(0, extra_acked, "bytes: {bytes}");
                } else {
                    assert!(
                        last_extra_acked < extra_acked,
                        "bytes: {bytes}, extra_acked did not grow: \
                         {last_extra_acked} -> {extra_acked}"
                    );
                }
                self.now += time_between_acks;
                last_extra_acked = extra_acked;
            }

            // Skip past the quiet period that follows the burst.
            self.now = start_time + total_duration;
        }

        // Number of whole `rtt`s elapsed since `start`.
        fn round_trip_count(&self) -> usize {
            ((self.now - self.start).as_micros() / self.rtt.as_micros()) as usize
        }
    }

    // Runs three identical aggregation episodes at `bandwidth_gain` times the
    // steady-state bandwidth. The first two must each open a new epoch. The
    // third starts 1 ms before the previous quiet period would end. Whether
    // it still opens a new epoch depends on the tracker's aggregation
    // bandwidth threshold.
    fn test_inner(
        overestimate_avoidance: bool, bandwidth_gain: f64,
        agg_duration: Duration, byte_per_ack: usize,
    ) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);

        let rnd = |tracker: &mut TestTracker, expect: bool| {
            tracker.aggregation_episode(
                tracker.bandwidth * bandwidth_gain,
                agg_duration,
                byte_per_ack,
                expect,
            );
        };

        rnd(&mut test_tracker, true);
        rnd(&mut test_tracker, true);

        // Cut the preceding quiet period short by 1 ms.
        test_tracker.now = test_tracker
            .now
            .checked_sub(Duration::from_millis(1))
            .unwrap();

        if test_tracker.tracker.ack_aggregation_bandwidth_threshold > 1.1 {
            rnd(&mut test_tracker, true);
            assert_eq!(3, test_tracker.tracker.num_ack_aggregation_epochs);
        } else {
            rnd(&mut test_tracker, false);
            assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs);
        }
    }

    #[rstest]
    fn very_aggregated_large_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 20.0, Duration::from_millis(6), 1200)
    }

    #[rstest]
    fn very_aggregated_small_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 20., Duration::from_millis(6), 300)
    }

    #[rstest]
    fn somewhat_aggregated_large_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 2.0, Duration::from_millis(50), 1000)
    }

    #[rstest]
    fn somewhat_aggregated_small_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 2.0, Duration::from_millis(50), 100)
    }

    // Acks arriving at exactly the steady-state bandwidth never accumulate
    // extra_acked. The tracker is expected to open more than two epochs
    // during a 100 ms episode.
    #[rstest]
    fn not_aggregated(#[values(false, true)] overestimate_avoidance: bool) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);
        test_tracker.aggregation_episode(
            test_tracker.bandwidth,
            Duration::from_millis(100),
            100,
            true,
        );
        assert!(2 < test_tracker.tracker.num_ack_aggregation_epochs);
    }

    // Acking packet 11 after packet 10 was the last one sent completes a
    // round. Because `start_new_aggregation_epoch_after_full_round` is set,
    // the following update is expected to open a second epoch. This holds
    // even though its tiny bandwidth would otherwise presumably keep the
    // current epoch going.
    #[rstest]
    fn start_new_epoch_after_a_full_round(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);

        test_tracker.last_sent_packet_number = 10;

        test_tracker.aggregation_episode(
            test_tracker.bandwidth * 2.0,
            Duration::from_millis(50),
            100,
            true,
        );

        test_tracker.last_acked_packet_number = 11;

        test_tracker.tracker.update(
            test_tracker.bandwidth * 0.1,
            true,
            test_tracker.round_trip_count(),
            test_tracker.last_sent_packet_number,
            test_tracker.last_acked_packet_number,
            test_tracker.now,
            100,
        );

        assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs)
    }
}