1use std::collections::VecDeque;
32use std::time::Duration;
33use std::time::Instant;
34
35use super::Acked;
36use crate::recovery::gcongestion::Bandwidth;
37use crate::recovery::gcongestion::Lost;
38
39use super::windowed_filter::WindowedFilter;
40
/// Per-packet state keyed by packet number, kept in strictly ascending
/// packet-number order.
///
/// Taking a value leaves a `None` "tombstone" in place so the deque stays
/// sorted and binary-searchable. Tombstones that reach the front of the deque
/// are dropped eagerly by `take` and `remove_obsolete`, so out-of-order
/// removals do not leave stale entries behind.
#[derive(Debug)]
struct ConnectionStateMap<T> {
    packet_map: VecDeque<(u64, Option<T>)>,
}

impl<T> Default for ConnectionStateMap<T> {
    fn default() -> Self {
        ConnectionStateMap {
            packet_map: VecDeque::new(),
        }
    }
}

impl<T> ConnectionStateMap<T> {
    /// Records `val` for `pkt_num`.
    ///
    /// # Panics
    ///
    /// Panics if `pkt_num` is not strictly greater than the last inserted
    /// packet number; lookups rely on that ordering for binary search.
    fn insert(&mut self, pkt_num: u64, val: T) {
        if let Some((last_pkt, _)) = self.packet_map.back() {
            assert!(pkt_num > *last_pkt, "{} > {}", pkt_num, *last_pkt);
        }

        self.packet_map.push_back((pkt_num, Some(val)));
    }

    /// Removes and returns the value stored for `pkt_num`.
    ///
    /// Returns `None` if the packet was never inserted, was already taken,
    /// or was dropped by `remove_obsolete`. On every path, tombstones at the
    /// front of the map are reclaimed afterwards.
    fn take(&mut self, pkt_num: u64) -> Option<T> {
        // Fast path: packets are usually taken in send order, so the target
        // is often the front entry and the binary search can be skipped.
        let idx = match self.packet_map.front() {
            Some(&(first, _)) if first == pkt_num => Some(0),
            _ => self
                .packet_map
                .binary_search_by_key(&pkt_num, |&(n, _)| n)
                .ok(),
        };

        let ret = idx
            .and_then(|i| self.packet_map.get_mut(i))
            .and_then(|(_, v)| v.take());

        self.drop_leading_tombstones();
        ret
    }

    #[cfg(test)]
    fn peek(&self, pkt_num: u64) -> Option<&T> {
        match self.packet_map.binary_search_by_key(&pkt_num, |&(n, _)| n) {
            Ok(found) => self.packet_map.get(found).and_then(|(_, v)| v.as_ref()),
            Err(_) => None,
        }
    }

    /// Drops every entry whose packet number is below `least_acked`, then
    /// any tombstones that this exposes at the front.
    fn remove_obsolete(&mut self, least_acked: u64) {
        while matches!(
            self.packet_map.front(),
            Some(&(p, _)) if p < least_acked
        ) {
            self.packet_map.pop_front();
        }

        self.drop_leading_tombstones();
    }

    /// Pops entries from the front whose value has already been taken.
    fn drop_leading_tombstones(&mut self) {
        while matches!(self.packet_map.front(), Some((_, None))) {
            self.packet_map.pop_front();
        }
    }
}
103
/// Produces bandwidth, RTT and in-flight samples from packet send, ack and
/// loss events, and tracks ack aggregation ("extra acked") via a
/// `MaxAckHeightTracker`.
#[derive(Debug)]
pub struct BandwidthSampler {
    /// Sum of the sizes of all retransmittable packets passed to
    /// `on_packet_sent`.
    total_bytes_sent: usize,
    /// Sum of the sizes of tracked packets that have been acknowledged.
    total_bytes_acked: usize,
    /// Sum of `bytes_lost` over every reported loss, tracked or not.
    total_bytes_lost: usize,
    /// Sum of the sizes of tracked packets removed by `on_packet_neutered`.
    total_bytes_neutered: usize,
    /// Number of the most recently sent packet, retransmittable or not.
    last_sent_packet: u64,
    /// Number of the most recently acknowledged packet; updated even when
    /// that packet is not tracked.
    last_acked_packet: u64,
    /// Set by `on_app_limited`; cleared once a packet sent after
    /// `end_of_app_limited_phase` is acknowledged.
    is_app_limited: bool,
    /// Ack time of the last tracked ack, or the send time of a packet sent
    /// with nothing in flight.
    last_acked_packet_ack_time: Instant,
    /// `total_bytes_sent` as recorded when the last acked packet was sent.
    total_bytes_sent_at_last_acked_packet: usize,
    /// Send time of the last acked packet, or of a packet sent with nothing
    /// in flight.
    last_acked_packet_sent_time: Instant,
    /// The two most recent ack points; only maintained when
    /// `overestimate_avoidance` is enabled.
    recent_ack_points: RecentAckPoints,
    /// Candidate A0 points for ack-rate computation; only used when
    /// `overestimate_avoidance` is enabled.
    a0_candidates: VecDeque<AckPoint>,
    /// Send-time snapshots of in-flight retransmittable packets.
    connection_state_map: ConnectionStateMap<ConnectionStateOnSentPacket>,
    /// Tracks the windowed maximum of bytes acked beyond the estimate.
    max_ack_height_tracker: MaxAckHeightTracker,
    /// `last_sent_packet` at the time `on_app_limited` was last called.
    end_of_app_limited_phase: Option<u64>,
    /// Enables A0-candidate based ack-rate computation.
    overestimate_avoidance: bool,
    /// Selects the corrected A0-candidate choice in `choose_a0_point` and
    /// `RecentAckPoints::less_recent_point`.
    choose_a0_point_fix: bool,
    /// When set, the send rate also bounds the bandwidth fed to the max ack
    /// height tracker; `new` initialises it to `false`.
    limit_max_ack_height_tracker_by_send_rate: bool,

    /// `total_bytes_acked` at the end of the previous ack event; used to
    /// compute the bytes newly acked by each event.
    total_bytes_acked_after_last_ack_event: usize,
}
134
/// Snapshot of the sampler's counters taken when a packet was sent.
#[derive(Debug, Default, Clone, Copy)]
pub struct SendTimeState {
    /// True for snapshots taken from a tracked packet; the `Default` value
    /// has this set to `false`.
    pub is_valid: bool,
    /// Whether the sampler was app-limited when the packet was sent.
    pub is_app_limited: bool,
    /// Total bytes sent, including this packet.
    pub total_bytes_sent: usize,
    /// Total bytes acked at the time the packet was sent.
    pub total_bytes_acked: usize,
    /// Total bytes lost at the time the packet was sent.
    #[allow(dead_code)]
    pub total_bytes_lost: usize,
    /// Bytes in flight recorded for the packet; `on_packet_sent` passes the
    /// in-flight count before the send plus the packet's own size.
    pub bytes_in_flight: usize,
}
160
/// One ack-aggregation measurement stored in the max ack height filter.
///
/// The derived `Ord` compares fields in declaration order, so
/// `extra_acked` dominates the comparison; do not reorder fields.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)]
struct ExtraAckedEvent {
    /// Bytes acked beyond what the bandwidth estimate predicts.
    extra_acked: usize,
    /// Bytes acked over the aggregation epoch.
    bytes_acked: usize,
    /// Length of the aggregation epoch.
    time_delta: Duration,
    /// Round trip count associated with the event; used as the filter time
    /// when the event is re-inserted after a bandwidth increase.
    round: usize,
}
171
/// Result of acknowledging a single tracked packet.
struct BandwidthSample {
    /// `min(send_rate, ack_rate)`, or just `ack_rate` when no send rate
    /// could be computed.
    bandwidth: Bandwidth,
    /// Time between the packet's send and its ack.
    rtt: Duration,
    /// Rate at which data was sent; `None` when the packet was not sent
    /// after the previously acked packet.
    send_rate: Option<Bandwidth>,
    /// Rate at which data was acknowledged since the chosen A0 point.
    ack_rate: Bandwidth,
    /// Sampler state captured when the packet was sent.
    state_at_send: SendTimeState,
}
191
/// Total bytes acked as of a given ack time.
#[derive(Debug, Clone, Copy)]
struct AckPoint {
    ack_time: Instant,
    total_bytes_acked: usize,
}
198
/// The two most recent ack points: index 0 is the older one and index 1 the
/// most recent.
#[derive(Debug, Default)]
struct RecentAckPoints {
    ack_points: [Option<AckPoint>; 2],
}
205
/// Per-packet snapshot of sampler state recorded at send time.
#[derive(Debug)]
struct ConnectionStateOnSentPacket {
    /// When the packet was sent.
    sent_time: Instant,
    /// Packet size in bytes.
    size: usize,
    /// Sampler's `total_bytes_sent_at_last_acked_packet` at send time.
    total_bytes_sent_at_last_acked_packet: usize,
    /// Sampler's `last_acked_packet_sent_time` at send time.
    last_acked_packet_sent_time: Instant,
    /// Sampler's `last_acked_packet_ack_time` at send time.
    last_acked_packet_ack_time: Instant,
    /// Sampler counters at send time.
    send_time_state: SendTimeState,
}
229
/// Tracks how many bytes are acked beyond what the bandwidth estimate
/// predicts over an aggregation epoch, keeping a windowed maximum keyed by
/// round trip count.
#[derive(Debug)]
struct MaxAckHeightTracker {
    /// Windowed max of extra-acked events; time is a round trip count.
    max_ack_height_filter: WindowedFilter<ExtraAckedEvent, usize, usize>,
    /// Start of the current aggregation epoch; `None` before the first one.
    aggregation_epoch_start_time: Option<Instant>,
    /// Bytes acked so far in the current epoch.
    aggregation_epoch_bytes: usize,
    /// Last sent packet number when the current epoch started.
    last_sent_packet_number_before_epoch: u64,
    /// Number of epochs started so far.
    num_ack_aggregation_epochs: u64,
    /// Multiplier on expected bytes; an epoch restarts when its bytes do not
    /// exceed `threshold * expected`.
    ack_aggregation_bandwidth_threshold: f64,
    /// Force a new epoch once a packet sent after the epoch start is acked.
    start_new_aggregation_epoch_after_full_round: bool,
    /// Recompute stored heights when a new max bandwidth is observed.
    reduce_extra_acked_on_bandwidth_increase: bool,
}
252
/// Aggregate result of `BandwidthSampler::on_congestion_event`.
#[derive(Default)]
pub(crate) struct CongestionEventSample {
    /// Largest bandwidth among valid ack samples; `None` if there were none.
    pub sample_max_bandwidth: Option<Bandwidth>,
    /// App-limited flag of the packet that produced `sample_max_bandwidth`.
    pub sample_is_app_limited: bool,
    /// Smallest RTT among valid ack samples.
    pub sample_rtt: Option<Duration>,
    /// Largest "bytes acked since send" value among acked packets.
    pub sample_max_inflight: usize,
    /// Send state of the later (by packet number) of the last valid acked
    /// and last valid lost packet.
    pub last_packet_send_state: SendTimeState,
    /// Extra bytes acked reported by the max ack height tracker for this
    /// event.
    pub extra_acked: usize,

    /// Largest send rate among valid ack samples.
    pub sample_max_send_rate: Option<Bandwidth>,
    /// Largest ack rate among valid ack samples.
    pub sample_max_ack_rate: Option<Bandwidth>,
}
283
284impl MaxAckHeightTracker {
285 pub(crate) fn new(window: usize, overestimate_avoidance: bool) -> Self {
286 MaxAckHeightTracker {
287 max_ack_height_filter: WindowedFilter::new(window),
288 aggregation_epoch_start_time: None,
289 aggregation_epoch_bytes: 0,
290 last_sent_packet_number_before_epoch: 0,
291 num_ack_aggregation_epochs: 0,
292 ack_aggregation_bandwidth_threshold: if overestimate_avoidance {
293 2.0
294 } else {
295 1.0
296 },
297 start_new_aggregation_epoch_after_full_round: true,
298 reduce_extra_acked_on_bandwidth_increase: true,
299 }
300 }
301
302 #[allow(dead_code)]
303 fn reset(&mut self, new_height: usize, new_time: usize) {
304 self.max_ack_height_filter.reset(
305 ExtraAckedEvent {
306 extra_acked: new_height,
307 bytes_acked: 0,
308 time_delta: Duration::ZERO,
309 round: new_time,
310 },
311 new_time,
312 );
313 }
314
315 #[allow(clippy::too_many_arguments)]
316 fn update(
317 &mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
318 round_trip_count: usize, last_sent_packet_number: u64,
319 last_acked_packet_number: u64, ack_time: Instant, bytes_acked: usize,
320 ) -> usize {
321 let mut force_new_epoch = false;
322
323 if self.reduce_extra_acked_on_bandwidth_increase && is_new_max_bandwidth {
324 let mut best =
326 self.max_ack_height_filter.get_best().unwrap_or_default();
327 let mut second_best = self
328 .max_ack_height_filter
329 .get_second_best()
330 .unwrap_or_default();
331 let mut third_best = self
332 .max_ack_height_filter
333 .get_third_best()
334 .unwrap_or_default();
335 self.max_ack_height_filter.clear();
336
337 let expected_bytes_acked =
339 bandwidth_estimate.to_bytes_per_period(best.time_delta) as usize;
340 if expected_bytes_acked < best.bytes_acked {
341 best.extra_acked = best.bytes_acked - expected_bytes_acked;
342 self.max_ack_height_filter.update(best, best.round);
343 }
344
345 let expected_bytes_acked = bandwidth_estimate
346 .to_bytes_per_period(second_best.time_delta)
347 as usize;
348 if expected_bytes_acked < second_best.bytes_acked {
349 second_best.extra_acked =
350 second_best.bytes_acked - expected_bytes_acked;
351 self.max_ack_height_filter
352 .update(second_best, second_best.round);
353 }
354
355 let expected_bytes_acked = bandwidth_estimate
356 .to_bytes_per_period(third_best.time_delta)
357 as usize;
358 if expected_bytes_acked < third_best.bytes_acked {
359 third_best.extra_acked =
360 third_best.bytes_acked - expected_bytes_acked;
361 self.max_ack_height_filter
362 .update(third_best, third_best.round);
363 }
364 }
365
366 if self.start_new_aggregation_epoch_after_full_round &&
369 last_acked_packet_number >
370 self.last_sent_packet_number_before_epoch
371 {
372 force_new_epoch = true;
373 }
374
375 let epoch_start_time = match self.aggregation_epoch_start_time {
376 Some(time) if !force_new_epoch => time,
377 _ => {
378 self.aggregation_epoch_bytes = bytes_acked;
379 self.aggregation_epoch_start_time = Some(ack_time);
380 self.last_sent_packet_number_before_epoch =
381 last_sent_packet_number;
382 self.num_ack_aggregation_epochs += 1;
383 return 0;
384 },
385 };
386
387 let aggregation_delta = ack_time.duration_since(epoch_start_time);
390 let expected_bytes_acked =
391 bandwidth_estimate.to_bytes_per_period(aggregation_delta) as usize;
392 if self.aggregation_epoch_bytes <=
395 (self.ack_aggregation_bandwidth_threshold *
396 expected_bytes_acked as f64) as usize
397 {
398 self.aggregation_epoch_bytes = bytes_acked;
400 self.aggregation_epoch_start_time = Some(ack_time);
401 self.last_sent_packet_number_before_epoch = last_sent_packet_number;
402 self.num_ack_aggregation_epochs += 1;
403 return 0;
404 }
405
406 self.aggregation_epoch_bytes += bytes_acked;
407
408 let extra_bytes_acked =
410 self.aggregation_epoch_bytes - expected_bytes_acked;
411
412 let new_event = ExtraAckedEvent {
413 extra_acked: extra_bytes_acked,
414 bytes_acked: self.aggregation_epoch_bytes,
415 time_delta: aggregation_delta,
416 round: 0,
417 };
418
419 self.max_ack_height_filter
420 .update(new_event, round_trip_count);
421 extra_bytes_acked
422 }
423}
424
425impl From<(Instant, usize, usize, &BandwidthSampler)>
426 for ConnectionStateOnSentPacket
427{
428 fn from(
429 (sent_time, size, bytes_in_flight, sampler): (
430 Instant,
431 usize,
432 usize,
433 &BandwidthSampler,
434 ),
435 ) -> Self {
436 ConnectionStateOnSentPacket {
437 sent_time,
438 size,
439 total_bytes_sent_at_last_acked_packet: sampler
440 .total_bytes_sent_at_last_acked_packet,
441 last_acked_packet_sent_time: sampler.last_acked_packet_sent_time,
442 last_acked_packet_ack_time: sampler.last_acked_packet_ack_time,
443 send_time_state: SendTimeState {
444 is_valid: true,
445 is_app_limited: sampler.is_app_limited,
446 total_bytes_sent: sampler.total_bytes_sent,
447 total_bytes_acked: sampler.total_bytes_acked,
448 total_bytes_lost: sampler.total_bytes_lost,
449 bytes_in_flight,
450 },
451 }
452 }
453}
454
455impl RecentAckPoints {
456 fn update(&mut self, ack_time: Instant, total_bytes_acked: usize) {
457 assert!(
458 total_bytes_acked >=
459 self.ack_points[1].map(|p| p.total_bytes_acked).unwrap_or(0)
460 );
461
462 self.ack_points[0] = self.ack_points[1];
463 self.ack_points[1] = Some(AckPoint {
464 ack_time,
465 total_bytes_acked,
466 });
467 }
468
469 fn clear(&mut self) {
470 self.ack_points = Default::default();
471 }
472
473 fn most_recent(&self) -> Option<AckPoint> {
474 self.ack_points[1]
475 }
476
477 fn less_recent_point(&self, choose_a0_point_fix: bool) -> Option<AckPoint> {
478 if choose_a0_point_fix {
479 self.ack_points[0]
480 .filter(|ack_point| ack_point.total_bytes_acked > 0)
481 .or(self.ack_points[1])
482 } else {
483 self.ack_points[0].or(self.ack_points[1])
484 }
485 }
486}
487
impl BandwidthSampler {
    /// Creates a sampler with no traffic recorded.
    ///
    /// `max_height_tracker_window_length` is the window handed to the max ack
    /// height filter. `overestimate_avoidance` and `choose_a0_point_fix`
    /// toggle the A0-candidate logic used when computing ack rates. The
    /// sampler starts out app-limited.
    pub(crate) fn new(
        max_height_tracker_window_length: usize, overestimate_avoidance: bool,
        choose_a0_point_fix: bool,
    ) -> Self {
        BandwidthSampler {
            total_bytes_sent: 0,
            total_bytes_acked: 0,
            total_bytes_lost: 0,
            total_bytes_neutered: 0,
            total_bytes_sent_at_last_acked_packet: 0,
            // Both timestamps are overwritten by `on_packet_sent` whenever a
            // retransmittable packet is sent with nothing in flight.
            last_acked_packet_sent_time: Instant::now(),
            last_acked_packet_ack_time: Instant::now(),
            is_app_limited: true,
            connection_state_map: ConnectionStateMap::default(),
            max_ack_height_tracker: MaxAckHeightTracker::new(
                max_height_tracker_window_length,
                overestimate_avoidance,
            ),
            total_bytes_acked_after_last_ack_event: 0,
            overestimate_avoidance,
            choose_a0_point_fix,
            limit_max_ack_height_tracker_by_send_rate: false,

            last_sent_packet: 0,
            last_acked_packet: 0,
            recent_ack_points: RecentAckPoints::default(),
            a0_candidates: VecDeque::new(),
            end_of_app_limited_phase: None,
        }
    }

    /// Whether the sampler is currently in an app-limited phase.
    #[allow(dead_code)]
    pub(crate) fn is_app_limited(&self) -> bool {
        self.is_app_limited
    }

    /// Records a packet send.
    ///
    /// Non-retransmittable packets only update `last_sent_packet` and are
    /// not tracked. `bytes_in_flight` is the in-flight count *before* this
    /// packet is added.
    pub(crate) fn on_packet_sent(
        &mut self, sent_time: Instant, packet_number: u64, bytes: usize,
        bytes_in_flight: usize, has_retransmittable_data: bool,
    ) {
        self.last_sent_packet = packet_number;

        if !has_retransmittable_data {
            return;
        }

        self.total_bytes_sent += bytes;

        // Nothing in flight: use this send as the reference point. The
        // snapshot taken below then carries this send time as its
        // `last_acked_packet_*` times, which later act as the fallback A0
        // point when this packet is acked.
        if bytes_in_flight == 0 {
            self.last_acked_packet_ack_time = sent_time;
            if self.overestimate_avoidance {
                self.recent_ack_points.clear();
                self.recent_ack_points
                    .update(sent_time, self.total_bytes_acked);
                self.a0_candidates.clear();
                // `update` just stored a point, so `most_recent` is `Some`.
                self.a0_candidates
                    .push_back(self.recent_ack_points.most_recent().unwrap());
            }

            self.total_bytes_sent_at_last_acked_packet = self.total_bytes_sent;

            self.last_acked_packet_sent_time = sent_time;
        }

        // Snapshot taken after the updates above, so `total_bytes_sent`
        // already includes this packet.
        self.connection_state_map.insert(
            packet_number,
            (sent_time, bytes, bytes_in_flight + bytes, &*self).into(),
        );
    }

    /// Stops tracking `packet_number` without counting it as acked or lost;
    /// its size is added to `total_bytes_neutered`.
    pub(crate) fn on_packet_neutered(&mut self, packet_number: u64) {
        if let Some(pkt) = self.connection_state_map.take(packet_number) {
            self.total_bytes_neutered += pkt.size;
        }
    }

    /// Processes one congestion event: losses first, then acks.
    ///
    /// `max_bandwidth` is the caller's current max bandwidth. It decides
    /// whether this event produced a new maximum, and the resulting estimate
    /// (capped by `est_bandwidth_upper_bound`) is fed to the max ack height
    /// tracker along with `round_trip_count`.
    pub(crate) fn on_congestion_event(
        &mut self, ack_time: Instant, acked_packets: &[Acked],
        lost_packets: &[Lost], mut max_bandwidth: Option<Bandwidth>,
        est_bandwidth_upper_bound: Bandwidth, round_trip_count: usize,
    ) -> CongestionEventSample {
        let mut last_lost_packet_send_state = SendTimeState::default();
        let mut last_acked_packet_send_state = SendTimeState::default();
        let mut last_lost_packet_num = 0u64;
        let mut last_acked_packet_num = 0u64;

        for packet in lost_packets {
            let send_state =
                self.on_packet_lost(packet.packet_number, packet.bytes_lost);
            if send_state.is_valid {
                last_lost_packet_send_state = send_state;
                last_lost_packet_num = packet.packet_number;
            }
        }

        // Loss-only event: report just the last lost packet's send state.
        if acked_packets.is_empty() {
            return CongestionEventSample {
                last_packet_send_state: last_lost_packet_send_state,
                ..Default::default()
            };
        }

        let mut event_sample = CongestionEventSample::default();

        let mut max_send_rate = None;
        let mut max_ack_rate = None;
        for packet in acked_packets {
            let sample =
                match self.on_packet_acknowledged(ack_time, packet.pkt_num) {
                    Some(sample) if sample.state_at_send.is_valid => sample,
                    _ => continue,
                };

            last_acked_packet_send_state = sample.state_at_send;
            last_acked_packet_num = packet.pkt_num;

            // Keep the minimum RTT over the event.
            event_sample.sample_rtt = Some(
                sample
                    .rtt
                    .min(*event_sample.sample_rtt.get_or_insert(sample.rtt)),
            );

            if Some(sample.bandwidth) > event_sample.sample_max_bandwidth {
                event_sample.sample_max_bandwidth = Some(sample.bandwidth);
                event_sample.sample_is_app_limited =
                    sample.state_at_send.is_app_limited;
            }
            max_send_rate = max_send_rate.max(sample.send_rate);
            max_ack_rate = max_ack_rate.max(Some(sample.ack_rate));

            // Bytes acked since this packet was sent, including the packet
            // itself.
            let inflight_sample = self.total_bytes_acked -
                last_acked_packet_send_state.total_bytes_acked;
            if inflight_sample > event_sample.sample_max_inflight {
                event_sample.sample_max_inflight = inflight_sample;
            }
        }

        // Report the send state of whichever valid packet (acked or lost)
        // has the higher packet number.
        if !last_lost_packet_send_state.is_valid {
            event_sample.last_packet_send_state = last_acked_packet_send_state;
        } else if !last_acked_packet_send_state.is_valid {
            event_sample.last_packet_send_state = last_lost_packet_send_state;
        } else {
            event_sample.last_packet_send_state =
                if last_acked_packet_num > last_lost_packet_num {
                    last_acked_packet_send_state
                } else {
                    last_lost_packet_send_state
                };
        }

        let is_new_max_bandwidth =
            event_sample.sample_max_bandwidth > max_bandwidth;
        max_bandwidth = event_sample.sample_max_bandwidth.max(max_bandwidth);

        if self.limit_max_ack_height_tracker_by_send_rate {
            max_bandwidth = max_bandwidth.max(max_send_rate);
        }

        let bandwidth_estimate = if let Some(max_bandwidth) = max_bandwidth {
            max_bandwidth.min(est_bandwidth_upper_bound)
        } else {
            est_bandwidth_upper_bound
        };

        event_sample.extra_acked = self.on_ack_event_end(
            bandwidth_estimate,
            is_new_max_bandwidth,
            round_trip_count,
        );

        event_sample.sample_max_send_rate = max_send_rate;
        event_sample.sample_max_ack_rate = max_ack_rate;

        event_sample
    }

    /// Accounts `bytes_lost` (whether or not the packet is tracked) and
    /// returns the packet's send state. The returned state has
    /// `is_valid == false` when the packet was not tracked.
    fn on_packet_lost(
        &mut self, packet_number: u64, bytes_lost: usize,
    ) -> SendTimeState {
        let mut send_time_state = SendTimeState::default();

        self.total_bytes_lost += bytes_lost;
        if let Some(state) = self.connection_state_map.take(packet_number) {
            send_time_state = state.send_time_state;
            send_time_state.is_valid = true;
        }

        send_time_state
    }

    /// Feeds the bytes acked since the previous ack event into the max ack
    /// height tracker and returns its extra-acked value. Returns 0 without
    /// touching the tracker if nothing new was acked.
    fn on_ack_event_end(
        &mut self, bandwidth_estimate: Bandwidth, is_new_max_bandwidth: bool,
        round_trip_count: usize,
    ) -> usize {
        let newly_acked_bytes =
            self.total_bytes_acked - self.total_bytes_acked_after_last_ack_event;

        if newly_acked_bytes == 0 {
            return 0;
        }

        self.total_bytes_acked_after_last_ack_event = self.total_bytes_acked;
        let extra_acked = self.max_ack_height_tracker.update(
            bandwidth_estimate,
            is_new_max_bandwidth,
            round_trip_count,
            self.last_sent_packet,
            self.last_acked_packet,
            self.last_acked_packet_ack_time,
            newly_acked_bytes,
        );
        // A zero result means the tracker started a new aggregation epoch;
        // the older recent ack point is then saved as an A0 candidate.
        // `unwrap` is safe: `total_bytes_acked` only grows in
        // `on_packet_acknowledged`, which (with overestimate avoidance)
        // always records an ack point first.
        if self.overestimate_avoidance && extra_acked == 0 {
            self.a0_candidates.push_back(
                self.recent_ack_points
                    .less_recent_point(self.choose_a0_point_fix)
                    .unwrap(),
            );
        }

        extra_acked
    }

    /// Marks `packet_number` as acked at `ack_time` and computes its
    /// bandwidth sample.
    ///
    /// Returns `None` if the packet is not tracked, or if `ack_time` is not
    /// strictly after the chosen A0 ack time.
    fn on_packet_acknowledged(
        &mut self, ack_time: Instant, packet_number: u64,
    ) -> Option<BandwidthSample> {
        self.last_acked_packet = packet_number;
        let sent_packet = self.connection_state_map.take(packet_number)?;

        self.total_bytes_acked += sent_packet.size;
        self.total_bytes_sent_at_last_acked_packet =
            sent_packet.send_time_state.total_bytes_sent;
        self.last_acked_packet_sent_time = sent_packet.sent_time;
        self.last_acked_packet_ack_time = ack_time;
        if self.overestimate_avoidance {
            self.recent_ack_points
                .update(ack_time, self.total_bytes_acked);
        }

        // Leave the app-limited phase if no end was recorded, or once a
        // packet sent after the recorded end is acked.
        if self.is_app_limited {
            if self.end_of_app_limited_phase.is_none() ||
                Some(packet_number) > self.end_of_app_limited_phase
            {
                self.is_app_limited = false;
            }
        }

        // A send rate needs a non-zero send interval since the previously
        // acked packet's send time; otherwise only the ack rate is used.
        let send_rate = if sent_packet.sent_time >
            sent_packet.last_acked_packet_sent_time
        {
            Some(Bandwidth::from_bytes_and_time_delta(
                sent_packet.send_time_state.total_bytes_sent -
                    sent_packet.total_bytes_sent_at_last_acked_packet,
                sent_packet.sent_time - sent_packet.last_acked_packet_sent_time,
            ))
        } else {
            None
        };

        let a0 = if self.overestimate_avoidance {
            Self::choose_a0_point(
                &mut self.a0_candidates,
                sent_packet.send_time_state.total_bytes_acked,
                self.choose_a0_point_fix,
            )
        } else {
            None
        };

        // Fall back to the ack point recorded when the packet was sent.
        let a0 = a0.unwrap_or(AckPoint {
            ack_time: sent_packet.last_acked_packet_ack_time,
            total_bytes_acked: sent_packet.send_time_state.total_bytes_acked,
        });

        // The ack-rate interval must be strictly positive.
        if ack_time <= a0.ack_time {
            return None;
        }

        let ack_rate = Bandwidth::from_bytes_and_time_delta(
            self.total_bytes_acked - a0.total_bytes_acked,
            ack_time.duration_since(a0.ack_time),
        );

        let bandwidth = if let Some(send_rate) = send_rate {
            send_rate.min(ack_rate)
        } else {
            ack_rate
        };

        // Send-to-ack time; no ack delay is subtracted here.
        let rtt = ack_time.duration_since(sent_packet.sent_time);

        Some(BandwidthSample {
            bandwidth,
            rtt,
            send_rate,
            ack_rate,
            state_at_send: SendTimeState {
                is_valid: true,
                ..sent_packet.send_time_state
            },
        })
    }

    /// Picks an A0 point for a packet whose send-time `total_bytes_acked` is
    /// given, discarding candidates that can no longer be used.
    ///
    /// Returns `None` if there are no candidates. Leading candidates are
    /// popped while the next candidate's `total_bytes_acked` does not exceed
    /// `total_bytes_acked`. With `choose_a0_point_fix` the remaining front
    /// candidate is returned. Without it, the first candidate that exceeds
    /// `total_bytes_acked` is returned instead (without popping it).
    fn choose_a0_point(
        a0_candidates: &mut VecDeque<AckPoint>, total_bytes_acked: usize,
        choose_a0_point_fix: bool,
    ) -> Option<AckPoint> {
        if a0_candidates.is_empty() {
            return None;
        }

        while let Some(candidate) = a0_candidates.get(1) {
            if candidate.total_bytes_acked > total_bytes_acked {
                if choose_a0_point_fix {
                    break;
                } else {
                    return Some(*candidate);
                }
            }
            a0_candidates.pop_front();
        }

        // Non-empty: only index >= 1 entries are ever popped past the check.
        Some(a0_candidates[0])
    }

    /// Total bytes of tracked packets acknowledged so far.
    pub(crate) fn total_bytes_acked(&self) -> usize {
        self.total_bytes_acked
    }

    /// Total bytes reported lost so far.
    pub(crate) fn total_bytes_lost(&self) -> usize {
        self.total_bytes_lost
    }

    /// Resets the max ack height filter to a single entry.
    #[allow(dead_code)]
    pub(crate) fn reset_max_ack_height_tracker(
        &mut self, new_height: usize, new_time: usize,
    ) {
        self.max_ack_height_tracker.reset(new_height, new_time);
    }

    /// Current best extra-acked value, or `None` if the filter is empty.
    pub(crate) fn max_ack_height(&self) -> Option<usize> {
        self.max_ack_height_tracker
            .max_ack_height_filter
            .get_best()
            .map(|b| b.extra_acked)
    }

    /// Enters an app-limited phase that ends once a packet sent after the
    /// current `last_sent_packet` is acknowledged.
    pub(crate) fn on_app_limited(&mut self) {
        self.is_app_limited = true;
        self.end_of_app_limited_phase = Some(self.last_sent_packet);
    }

    /// Stops tracking every packet numbered below `least_acked`.
    pub(crate) fn remove_obsolete_packets(&mut self, least_acked: u64) {
        self.connection_state_map.remove_obsolete(least_acked);
    }
}
881
882#[cfg(test)]
883mod bandwidth_sampler_tests {
884 use rstest::rstest;
885
886 use super::*;
887
888 const REGULAR_PACKET_SIZE: usize = 1280;
889
    /// Drives a `BandwidthSampler` with a simulated clock while keeping its
    /// own count of bytes in flight.
    struct TestSender {
        sampler: BandwidthSampler,
        // Whether the sampler reported itself app-limited right after
        // construction.
        sampler_app_limited_at_start: bool,
        // Bytes of retransmittable packets sent and not yet acked or lost,
        // as maintained by the helper methods.
        bytes_in_flight: usize,
        // Simulated current time, advanced by `advance_time`.
        clock: Instant,
        // Running max of the sample bandwidths returned to the helpers.
        max_bandwidth: Bandwidth,
        est_bandwidth_upper_bound: Bandwidth,
        round_trip_count: usize,
    }
899
900 impl TestSender {
901 fn new(overestimate_avoidance: bool, choose_a0_point_fix: bool) -> Self {
902 let sampler = BandwidthSampler::new(
903 0,
904 overestimate_avoidance,
905 choose_a0_point_fix,
906 );
907 TestSender {
908 sampler_app_limited_at_start: sampler.is_app_limited(),
909 sampler,
910 bytes_in_flight: 0,
911 clock: Instant::now(),
912 max_bandwidth: Bandwidth::zero(),
913 est_bandwidth_upper_bound: Bandwidth::infinite(),
914 round_trip_count: 0,
915 }
916 }
917
918 fn get_packet_size(&self, pkt_num: u64) -> usize {
919 self.sampler
920 .connection_state_map
921 .peek(pkt_num)
922 .unwrap()
923 .size
924 }
925
926 fn get_packet_time(&self, pkt_num: u64) -> Instant {
927 self.sampler
928 .connection_state_map
929 .peek(pkt_num)
930 .unwrap()
931 .sent_time
932 }
933
934 fn number_of_tracked_packets(&self) -> usize {
935 self.sampler.connection_state_map.packet_map.len()
936 }
937
938 fn make_acked_packet(&self, pkt_num: u64) -> Acked {
939 let time_sent = self.get_packet_time(pkt_num);
940
941 Acked { pkt_num, time_sent }
942 }
943
944 fn make_lost_packet(&self, pkt_num: u64) -> Lost {
945 let size = self.get_packet_size(pkt_num);
946 Lost {
947 packet_number: pkt_num,
948 bytes_lost: size,
949 }
950 }
951
952 fn ack_packet(&mut self, pkt_num: u64) -> BandwidthSample {
953 let size = self.get_packet_size(pkt_num);
954 self.bytes_in_flight -= size;
955
956 let sample = self.sampler.on_congestion_event(
957 self.clock,
958 &[self.make_acked_packet(pkt_num)],
959 &[],
960 Some(self.max_bandwidth),
961 self.est_bandwidth_upper_bound,
962 self.round_trip_count,
963 );
964
965 let sample_max_bandwidth = sample.sample_max_bandwidth.unwrap();
966 self.max_bandwidth = self.max_bandwidth.max(sample_max_bandwidth);
967
968 let bandwidth_sample = BandwidthSample {
969 bandwidth: sample_max_bandwidth,
970 rtt: sample.sample_rtt.unwrap(),
971 send_rate: None,
972 ack_rate: Bandwidth::zero(),
974 state_at_send: sample.last_packet_send_state,
975 };
976 assert!(bandwidth_sample.state_at_send.is_valid);
977 bandwidth_sample
978 }
979
980 fn lose_packet(&mut self, pkt_num: u64) -> SendTimeState {
981 let size = self.get_packet_size(pkt_num);
982 self.bytes_in_flight -= size;
983
984 let sample = self.sampler.on_congestion_event(
985 self.clock,
986 &[],
987 &[self.make_lost_packet(pkt_num)],
988 Some(self.max_bandwidth),
989 self.est_bandwidth_upper_bound,
990 self.round_trip_count,
991 );
992
993 assert!(sample.last_packet_send_state.is_valid);
994 assert_eq!(sample.sample_max_bandwidth, None);
995 assert_eq!(sample.sample_rtt, None);
996 sample.last_packet_send_state
997 }
998
999 fn on_congestion_event(
1000 &mut self, acked: &[u64], lost: &[u64],
1001 ) -> CongestionEventSample {
1002 let acked = acked
1003 .iter()
1004 .map(|pkt| {
1005 let acked_size = self.get_packet_size(*pkt);
1006 self.bytes_in_flight -= acked_size;
1007
1008 self.make_acked_packet(*pkt)
1009 })
1010 .collect::<Vec<_>>();
1011
1012 let lost = lost
1013 .iter()
1014 .map(|pkt| {
1015 let lost = self.make_lost_packet(*pkt);
1016 self.bytes_in_flight -= lost.bytes_lost;
1017 lost
1018 })
1019 .collect::<Vec<_>>();
1020
1021 let sample = self.sampler.on_congestion_event(
1022 self.clock,
1023 &acked,
1024 &lost,
1025 Some(self.max_bandwidth),
1026 self.est_bandwidth_upper_bound,
1027 self.round_trip_count,
1028 );
1029
1030 self.max_bandwidth =
1031 self.max_bandwidth.max(sample.sample_max_bandwidth.unwrap());
1032
1033 sample
1034 }
1035
1036 fn send_packet(
1037 &mut self, pkt_num: u64, pkt_sz: usize,
1038 has_retransmittable_data: bool,
1039 ) {
1040 self.sampler.on_packet_sent(
1041 self.clock,
1042 pkt_num,
1043 pkt_sz,
1044 self.bytes_in_flight,
1045 has_retransmittable_data,
1046 );
1047 if has_retransmittable_data {
1048 self.bytes_in_flight += pkt_sz;
1049 }
1050 }
1051
1052 fn advance_time(&mut self, delta: Duration) {
1053 self.clock += delta;
1054 }
1055
1056 fn send_40_and_ack_first_20(&mut self, time_between_packets: Duration) {
1059 for i in 1..=20 {
1061 self.send_packet(i, REGULAR_PACKET_SIZE, true);
1062 self.advance_time(time_between_packets);
1063 }
1064
1065 for i in 1..=20 {
1068 self.ack_packet(i);
1069 self.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
1070 self.advance_time(time_between_packets);
1071 }
1072 }
1073 }
1074
    /// Sends one packet at a time and acks it before sending the next, so
    /// each sample should equal one packet per inter-packet gap. The gap is
    /// then doubled five times and the sample must halve each time.
    #[rstest]
    fn send_and_wait(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let mut time_between_packets = Duration::from_millis(10);
        // One packet per 10 ms.
        let mut expected_bandwidth =
            Bandwidth::from_bytes_per_second(REGULAR_PACKET_SIZE as u64 * 100);

        // Constant rate.
        for i in 1..20 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
            let current_sample = test_sender.ack_packet(i);
            assert_eq!(expected_bandwidth, current_sample.bandwidth);
        }

        // Exponentially decreasing rate.
        for i in 20..25 {
            time_between_packets *= 2;
            expected_bandwidth = expected_bandwidth * 0.5;

            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
            let current_sample = test_sender.ack_packet(i);
            assert_eq!(expected_bandwidth, current_sample.bandwidth);
        }

        test_sender.sampler.remove_obsolete_packets(25);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1109
    /// Checks the `SendTimeState` snapshots reported for acked and lost
    /// packets against the sampler's running totals.
    #[rstest]
    fn send_time_state(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(10);

        // Send packets 1-5.
        for i in 1..=5 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            assert_eq!(
                test_sender.sampler.total_bytes_sent,
                REGULAR_PACKET_SIZE * i as usize
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ack packet 1: its snapshot includes only its own bytes sent.
        let send_time_state = test_sender.ack_packet(1).state_at_send;
        assert_eq!(REGULAR_PACKET_SIZE, send_time_state.total_bytes_sent);
        assert_eq!(0, send_time_state.total_bytes_acked);
        assert_eq!(0, send_time_state.total_bytes_lost);
        assert_eq!(REGULAR_PACKET_SIZE, test_sender.sampler.total_bytes_acked);

        // Lose packet 2.
        let send_time_state = test_sender.lose_packet(2);
        assert_eq!(REGULAR_PACKET_SIZE * 2, send_time_state.total_bytes_sent);
        assert_eq!(0, send_time_state.total_bytes_acked);
        assert_eq!(0, send_time_state.total_bytes_lost);
        assert_eq!(REGULAR_PACKET_SIZE, test_sender.sampler.total_bytes_lost);

        // Lose packet 3.
        let send_time_state = test_sender.lose_packet(3);
        assert_eq!(REGULAR_PACKET_SIZE * 3, send_time_state.total_bytes_sent);
        assert_eq!(0, send_time_state.total_bytes_acked);
        assert_eq!(0, send_time_state.total_bytes_lost);
        assert_eq!(
            REGULAR_PACKET_SIZE * 2,
            test_sender.sampler.total_bytes_lost
        );

        // Send packets 6-10.
        for i in 6..=10 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            assert_eq!(
                test_sender.sampler.total_bytes_sent,
                REGULAR_PACKET_SIZE * i as usize
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ack packets 4-10; packets sent after the ack of 1 and the losses
        // of 2 and 3 must see those totals in their snapshots.
        let mut acked_packet_count = 1;
        assert_eq!(
            REGULAR_PACKET_SIZE * acked_packet_count,
            test_sender.sampler.total_bytes_acked
        );
        for i in 4..=10 {
            let send_time_state = test_sender.ack_packet(i).state_at_send;
            acked_packet_count += 1;
            assert_eq!(
                REGULAR_PACKET_SIZE * acked_packet_count,
                test_sender.sampler.total_bytes_acked
            );
            assert_eq!(
                REGULAR_PACKET_SIZE * i as usize,
                send_time_state.total_bytes_sent
            );

            if i <= 5 {
                assert_eq!(0, send_time_state.total_bytes_acked);
                assert_eq!(0, send_time_state.total_bytes_lost);
            } else {
                assert_eq!(
                    REGULAR_PACKET_SIZE,
                    send_time_state.total_bytes_acked
                );
                assert_eq!(
                    REGULAR_PACKET_SIZE * 2,
                    send_time_state.total_bytes_lost
                );
            }

            // The recorded in-flight bytes equal sent - acked - lost at
            // send time (the packet's own size is included in "sent").
            assert_eq!(
                send_time_state.total_bytes_sent -
                    send_time_state.total_bytes_acked -
                    send_time_state.total_bytes_lost,
                send_time_state.bytes_in_flight
            );

            test_sender.advance_time(time_between_packets);
        }
    }
1206
    /// Sends at a constant pace of one packet per millisecond; once the
    /// pipe is full, every ack must report exactly that rate.
    #[rstest]
    fn send_paced(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(1);
        // One packet per millisecond.
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);

        test_sender.send_40_and_ack_first_20(time_between_packets);

        // Ack the remaining packets at the same pace.
        for i in 21..=40 {
            let last_bandwidth = test_sender.ack_packet(i).bandwidth;
            assert_eq!(expected_bandwidth, last_bandwidth);
            test_sender.advance_time(time_between_packets);
        }
        test_sender.sampler.remove_obsolete_packets(41);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1232
    /// Paced sending where every odd-numbered packet is lost; the samples
    /// from acked packets must reflect half of the send rate.
    #[rstest]
    fn send_with_losses(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(1);
        // Half a packet per millisecond.
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);

        // Send 20 packets, 1 ms apart.
        for i in 1..=20 {
            test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Ack even packets, lose odd ones, and keep sending packets 21-40.
        for i in 1..=20 {
            if i % 2 == 0 {
                test_sender.ack_packet(i);
            } else {
                test_sender.lose_packet(i);
            }
            test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
            test_sender.advance_time(time_between_packets);
        }

        // Same pattern for packets 21-40, now checking the samples.
        for i in 21..=40 {
            if i % 2 == 0 {
                let last_bandwidth = test_sender.ack_packet(i).bandwidth;
                assert_eq!(expected_bandwidth, last_bandwidth);
            } else {
                test_sender.lose_packet(i);
            }
            test_sender.advance_time(time_between_packets);
        }
        test_sender.sampler.remove_obsolete_packets(41);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1278
    /// Only even-numbered packets carry retransmittable data. Odd packets
    /// are never tracked, so only half the sent packets contribute to the
    /// samples, which must report half of the send rate.
    #[rstest]
    fn not_congestion_controlled(
        #[values(false, true)] overestimate_avoidance: bool,
        #[values(false, true)] choose_a0_point_fix: bool,
    ) {
        let mut test_sender =
            TestSender::new(overestimate_avoidance, choose_a0_point_fix);
        let time_between_packets = Duration::from_millis(1);
        let expected_bandwidth =
            Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 / 2 * 8);

        // Send 20 packets; only the even ones are retransmittable.
        for i in 1..=20 {
            let has_retransmittable_data = i % 2 == 0;
            test_sender.send_packet(
                i,
                REGULAR_PACKET_SIZE,
                has_retransmittable_data,
            );
            test_sender.advance_time(time_between_packets);
        }

        // Only the retransmittable packets are tracked.
        assert_eq!(10, test_sender.number_of_tracked_packets());

        // Ack the even packets while sending packets 21-40.
        for i in 1..=20 {
            if i % 2 == 0 {
                test_sender.ack_packet(i);
            }
            let has_retransmittable_data = i % 2 == 0;
            test_sender.send_packet(
                i + 20,
                REGULAR_PACKET_SIZE,
                has_retransmittable_data,
            );
            test_sender.advance_time(time_between_packets);
        }

        // Ack the even packets of the second batch and check the samples.
        for i in 21..=40 {
            if i % 2 == 0 {
                let last_bandwidth = test_sender.ack_packet(i).bandwidth;
                assert_eq!(expected_bandwidth, last_bandwidth);
            }
            test_sender.advance_time(time_between_packets);
        }

        test_sender.sampler.remove_obsolete_packets(41);
        assert_eq!(0, test_sender.number_of_tracked_packets());
        assert_eq!(0, test_sender.bytes_in_flight);
    }
1339
1340 #[rstest]
1343 fn compressed_ack(
1344 #[values(false, true)] overestimate_avoidance: bool,
1345 #[values(false, true)] choose_a0_point_fix: bool,
1346 ) {
1347 let mut test_sender =
1348 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1349 let time_between_packets = Duration::from_millis(1);
1350 let expected_bandwidth =
1351 Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1352
1353 test_sender.send_40_and_ack_first_20(time_between_packets);
1354
1355 test_sender.advance_time(time_between_packets * 15);
1357
1358 let ridiculously_small_time_delta = Duration::from_micros(20);
1360 let mut last_bandwidth = Bandwidth::zero();
1361 for i in 21..=40 {
1362 last_bandwidth = test_sender.ack_packet(i).bandwidth;
1363 test_sender.advance_time(ridiculously_small_time_delta);
1364 }
1365 assert_eq!(expected_bandwidth, last_bandwidth);
1366
1367 test_sender.sampler.remove_obsolete_packets(41);
1368 assert_eq!(0, test_sender.number_of_tracked_packets());
1371 assert_eq!(0, test_sender.bytes_in_flight);
1372 }
1373
1374 #[rstest]
1376 fn reordered_ack(
1377 #[values(false, true)] overestimate_avoidance: bool,
1378 #[values(false, true)] choose_a0_point_fix: bool,
1379 ) {
1380 let mut test_sender =
1381 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1382 let time_between_packets = Duration::from_millis(1);
1383 let expected_bandwidth =
1384 Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1385
1386 test_sender.send_40_and_ack_first_20(time_between_packets);
1387
1388 for i in 0..20 {
1391 let last_bandwidth = test_sender.ack_packet(40 - i).bandwidth;
1392 assert_eq!(expected_bandwidth, last_bandwidth);
1393 test_sender.send_packet(41 + i, REGULAR_PACKET_SIZE, true);
1394 test_sender.advance_time(time_between_packets);
1395 }
1396
1397 for i in 41..=60 {
1399 let last_bandwidth = test_sender.ack_packet(i).bandwidth;
1400 assert_eq!(expected_bandwidth, last_bandwidth);
1401 test_sender.advance_time(time_between_packets);
1402 }
1403
1404 test_sender.sampler.remove_obsolete_packets(61);
1405 assert_eq!(0, test_sender.number_of_tracked_packets());
1406 assert_eq!(0, test_sender.bytes_in_flight);
1407 }
1408
1409 #[rstest]
1411 fn app_limited(
1412 #[values(false, true)] overestimate_avoidance: bool,
1413 #[values(false, true)] choose_a0_point_fix: bool,
1414 ) {
1415 let mut test_sender =
1416 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1417 let time_between_packets = Duration::from_millis(1);
1418 let expected_bandwidth =
1419 Bandwidth::from_kbits_per_second(REGULAR_PACKET_SIZE as u64 * 8);
1420
1421 for i in 1..=20 {
1422 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1423 test_sender.advance_time(time_between_packets);
1424 }
1425
1426 for i in 1..=20 {
1427 let sample = test_sender.ack_packet(i);
1428 assert_eq!(
1429 sample.state_at_send.is_app_limited,
1430 test_sender.sampler_app_limited_at_start,
1431 "{i}"
1432 );
1433 test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
1434 test_sender.advance_time(time_between_packets);
1435 }
1436
1437 test_sender.sampler.on_app_limited();
1440 for i in 21..=40 {
1441 let sample = test_sender.ack_packet(i);
1442 assert!(!sample.state_at_send.is_app_limited, "{i}");
1443 assert_eq!(expected_bandwidth, sample.bandwidth, "{i}");
1444 test_sender.advance_time(time_between_packets);
1445 }
1446
1447 test_sender.advance_time(Duration::from_secs(1));
1449
1450 for i in 41..=60 {
1452 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1453 test_sender.advance_time(time_between_packets);
1454 }
1455
1456 for i in 41..=60 {
1459 let sample = test_sender.ack_packet(i);
1460 assert!(sample.state_at_send.is_app_limited, "{i}");
1461 if !overestimate_avoidance || choose_a0_point_fix || i < 43 {
1462 assert!(
1463 sample.bandwidth < expected_bandwidth * 0.7,
1464 "{} {:?} vs {:?}",
1465 i,
1466 sample.bandwidth,
1467 expected_bandwidth * 0.7
1468 );
1469 } else {
1470 assert_eq!(sample.bandwidth, expected_bandwidth, "{i}");
1473 }
1474 test_sender.send_packet(i + 20, REGULAR_PACKET_SIZE, true);
1475 test_sender.advance_time(time_between_packets);
1476 }
1477
1478 for i in 61..=80 {
1481 let sample = test_sender.ack_packet(i);
1482 assert!(!sample.state_at_send.is_app_limited, "{i}");
1483 assert_eq!(sample.bandwidth, expected_bandwidth, "{i}");
1484 test_sender.advance_time(time_between_packets);
1485 }
1486
1487 test_sender.sampler.remove_obsolete_packets(81);
1488 assert_eq!(0, test_sender.number_of_tracked_packets());
1489 assert_eq!(0, test_sender.bytes_in_flight);
1490 }
1491
1492 #[rstest]
1494 fn first_round_trip(
1495 #[values(false, true)] overestimate_avoidance: bool,
1496 #[values(false, true)] choose_a0_point_fix: bool,
1497 ) {
1498 let mut test_sender =
1499 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1500 let time_between_packets = Duration::from_millis(1);
1501 let rtt = Duration::from_millis(800);
1502 let num_packets = 10;
1503 let num_bytes = REGULAR_PACKET_SIZE * num_packets;
1504 let real_bandwidth = Bandwidth::from_bytes_and_time_delta(num_bytes, rtt);
1505
1506 for i in 1..=10 {
1507 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1508 test_sender.advance_time(time_between_packets);
1509 }
1510 test_sender.advance_time(rtt - time_between_packets * num_packets as _);
1511
1512 let mut last_sample = Bandwidth::zero();
1513 for i in 1..=10 {
1514 let sample = test_sender.ack_packet(i).bandwidth;
1515 assert!(sample > last_sample);
1516 last_sample = sample;
1517 test_sender.advance_time(time_between_packets);
1518 }
1519
1520 assert!(last_sample < real_bandwidth);
1527 assert!(last_sample > real_bandwidth * 0.9);
1528 }
1529
1530 #[rstest]
1532 fn remove_obsolete_packets(
1533 #[values(false, true)] overestimate_avoidance: bool,
1534 #[values(false, true)] choose_a0_point_fix: bool,
1535 ) {
1536 let mut test_sender =
1537 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1538
1539 for i in 1..=5 {
1540 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1541 }
1542 test_sender.advance_time(Duration::from_millis(100));
1543 assert_eq!(5, test_sender.number_of_tracked_packets());
1544 test_sender.sampler.remove_obsolete_packets(4);
1545 assert_eq!(2, test_sender.number_of_tracked_packets());
1546 test_sender.lose_packet(4);
1547 test_sender.sampler.remove_obsolete_packets(5);
1548 assert_eq!(1, test_sender.number_of_tracked_packets());
1549 test_sender.ack_packet(5);
1550 test_sender.sampler.remove_obsolete_packets(6);
1551 assert_eq!(0, test_sender.number_of_tracked_packets());
1552 }
1553
1554 #[rstest]
1555 fn neuter_packet(
1556 #[values(false, true)] overestimate_avoidance: bool,
1557 #[values(false, true)] choose_a0_point_fix: bool,
1558 ) {
1559 let mut test_sender =
1560 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1561 test_sender.send_packet(1, REGULAR_PACKET_SIZE, true);
1562 assert_eq!(test_sender.sampler.total_bytes_neutered, 0);
1563 test_sender.advance_time(Duration::from_millis(10));
1564 test_sender.sampler.on_packet_neutered(1);
1565 assert!(0 < test_sender.sampler.total_bytes_neutered);
1566 assert_eq!(0, test_sender.sampler.total_bytes_acked);
1567
1568 let acked = Acked {
1570 pkt_num: 1,
1571 time_sent: test_sender.clock,
1572 };
1573 test_sender.advance_time(Duration::from_millis(10));
1574 let sample = test_sender.sampler.on_congestion_event(
1575 test_sender.clock,
1576 &[acked],
1577 &[],
1578 Some(test_sender.max_bandwidth),
1579 test_sender.est_bandwidth_upper_bound,
1580 test_sender.round_trip_count,
1581 );
1582
1583 assert_eq!(0, test_sender.sampler.total_bytes_acked);
1584 assert!(sample.sample_max_bandwidth.is_none());
1585 assert!(!sample.sample_is_app_limited);
1586 assert!(sample.sample_rtt.is_none());
1587 assert_eq!(sample.sample_max_inflight, 0);
1588 assert_eq!(sample.extra_acked, 0);
1589 }
1590
1591 #[rstest]
1595 fn congestion_event_sample_default_values() {
1596 let sample = CongestionEventSample::default();
1597 assert!(sample.sample_max_bandwidth.is_none());
1598 assert!(!sample.sample_is_app_limited);
1599 assert!(sample.sample_rtt.is_none());
1600 assert_eq!(sample.sample_max_inflight, 0);
1601 assert_eq!(sample.extra_acked, 0);
1602 }
1603
1604 #[rstest]
1606 fn two_acked_packets_per_event(
1607 #[values(false, true)] overestimate_avoidance: bool,
1608 #[values(false, true)] choose_a0_point_fix: bool,
1609 ) {
1610 let mut test_sender =
1611 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1612 let time_between_packets = Duration::from_millis(10);
1613 let sending_rate = Bandwidth::from_bytes_and_time_delta(
1614 REGULAR_PACKET_SIZE,
1615 time_between_packets,
1616 );
1617
1618 for i in 1..21 {
1619 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1620 test_sender.advance_time(time_between_packets);
1621 if i % 2 != 0 {
1622 continue;
1623 }
1624
1625 let sample = test_sender.on_congestion_event(&[i - 1, i], &[]);
1626 assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap());
1627 assert_eq!(time_between_packets, sample.sample_rtt.unwrap());
1628 assert_eq!(2 * REGULAR_PACKET_SIZE, sample.sample_max_inflight);
1629 assert!(sample.last_packet_send_state.is_valid);
1630 assert_eq!(
1631 2 * REGULAR_PACKET_SIZE,
1632 sample.last_packet_send_state.bytes_in_flight
1633 );
1634 assert_eq!(
1635 i as usize * REGULAR_PACKET_SIZE,
1636 sample.last_packet_send_state.total_bytes_sent
1637 );
1638 assert_eq!(
1639 (i - 2) as usize * REGULAR_PACKET_SIZE,
1640 sample.last_packet_send_state.total_bytes_acked
1641 );
1642 assert_eq!(0, sample.last_packet_send_state.total_bytes_lost);
1643 test_sender.sampler.remove_obsolete_packets(i - 2);
1644 }
1645 }
1646
1647 #[rstest]
1648 fn lose_every_other_packet(
1649 #[values(false, true)] overestimate_avoidance: bool,
1650 #[values(false, true)] choose_a0_point_fix: bool,
1651 ) {
1652 let mut test_sender =
1653 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1654 let time_between_packets = Duration::from_millis(10);
1655 let sending_rate = Bandwidth::from_bytes_and_time_delta(
1656 REGULAR_PACKET_SIZE,
1657 time_between_packets,
1658 );
1659
1660 for i in 1..21 {
1661 test_sender.send_packet(i, REGULAR_PACKET_SIZE, true);
1662 test_sender.advance_time(time_between_packets);
1663 if i % 2 != 0 {
1664 continue;
1665 }
1666 let sample = test_sender.on_congestion_event(&[i], &[i - 1]);
1668 assert_eq!(sending_rate, sample.sample_max_bandwidth.unwrap() * 2.);
1671 assert_eq!(time_between_packets, sample.sample_rtt.unwrap());
1672 assert_eq!(REGULAR_PACKET_SIZE, sample.sample_max_inflight);
1673 assert!(sample.last_packet_send_state.is_valid);
1674 assert_eq!(
1675 2 * REGULAR_PACKET_SIZE,
1676 sample.last_packet_send_state.bytes_in_flight
1677 );
1678 assert_eq!(
1679 i as usize * REGULAR_PACKET_SIZE,
1680 sample.last_packet_send_state.total_bytes_sent
1681 );
1682 assert_eq!(
1683 (i - 2) as usize * REGULAR_PACKET_SIZE / 2,
1684 sample.last_packet_send_state.total_bytes_acked
1685 );
1686 assert_eq!(
1687 (i - 2) as usize * REGULAR_PACKET_SIZE / 2,
1688 sample.last_packet_send_state.total_bytes_lost
1689 );
1690 test_sender.sampler.remove_obsolete_packets(i - 2);
1691 }
1692 }
1693
1694 #[rstest]
1695 fn ack_height_respect_bandwidth_estimate_upper_bound(
1696 #[values(false, true)] overestimate_avoidance: bool,
1697 #[values(false, true)] choose_a0_point_fix: bool,
1698 ) {
1699 let mut test_sender =
1700 TestSender::new(overestimate_avoidance, choose_a0_point_fix);
1701 let time_between_packets = Duration::from_millis(10);
1702 let first_packet_sending_rate = Bandwidth::from_bytes_and_time_delta(
1703 REGULAR_PACKET_SIZE,
1704 time_between_packets,
1705 );
1706
1707 test_sender.send_packet(1, REGULAR_PACKET_SIZE, true);
1709 test_sender.advance_time(time_between_packets);
1710 test_sender.send_packet(2, REGULAR_PACKET_SIZE, true);
1711 test_sender.send_packet(3, REGULAR_PACKET_SIZE, true);
1712 test_sender.send_packet(4, REGULAR_PACKET_SIZE, true);
1713 let sample = test_sender.on_congestion_event(&[1], &[]);
1714 assert_eq!(
1715 first_packet_sending_rate,
1716 sample.sample_max_bandwidth.unwrap()
1717 );
1718 assert_eq!(first_packet_sending_rate, test_sender.max_bandwidth);
1719
1720 test_sender.round_trip_count += 1;
1723 test_sender.est_bandwidth_upper_bound = first_packet_sending_rate * 0.3;
1724 test_sender.advance_time(time_between_packets);
1725
1726 let sample = test_sender.on_congestion_event(&[2, 3, 4], &[]);
1727
1728 assert_eq!(
1729 first_packet_sending_rate * 2.,
1730 sample.sample_max_bandwidth.unwrap()
1731 );
1732 assert_eq!(
1733 test_sender.max_bandwidth,
1734 sample.sample_max_bandwidth.unwrap()
1735 );
1736 assert!(2 * REGULAR_PACKET_SIZE < sample.extra_acked);
1737 }
1738}
1739
#[cfg(test)]
mod max_ack_height_tracker_tests {
    use rstest::rstest;

    use super::*;

    /// Test harness that feeds a `MaxAckHeightTracker` with acks on a
    /// simulated clock.
    struct TestTracker {
        tracker: MaxAckHeightTracker,
        /// Bandwidth estimate passed to every `update` call.
        bandwidth: Bandwidth,
        /// Origin of the simulated clock; round trips are counted from here.
        start: Instant,
        /// Current simulated time.
        now: Instant,
        last_sent_packet_number: u64,
        last_acked_packet_number: u64,
        /// Fixed RTT used by `round_trip_count`.
        rtt: Duration,
    }

    impl TestTracker {
        /// Creates a harness with these settings:
        ///
        /// - The tracker uses a 1.8x aggregation threshold and starts a new
        ///   aggregation epoch after a full round.
        /// - The bandwidth estimate is 10 kB/s and the RTT is 60 ms.
        /// - The simulated clock starts 1 ms after its origin.
        fn new(overestimate_avoidance: bool) -> Self {
            // NOTE(review): `10` is presumably the filter window length in
            // rounds; confirm against `MaxAckHeightTracker::new`.
            let mut tracker =
                MaxAckHeightTracker::new(10, overestimate_avoidance);
            tracker.ack_aggregation_bandwidth_threshold = 1.8;
            tracker.start_new_aggregation_epoch_after_full_round = true;
            let start = Instant::now();
            TestTracker {
                tracker,
                start,
                now: start + Duration::from_millis(1),
                bandwidth: Bandwidth::from_bytes_per_second(10 * 1000),
                last_sent_packet_number: 0,
                last_acked_packet_number: 0,
                rtt: Duration::from_millis(60),
            }
        }

        /// Simulates one episode in which acks of `bytes_per_ack` bytes
        /// arrive at `aggregation_bandwidth` for `aggregation_duration`.
        ///
        /// The extra-acked value reported for each ack must be zero in two
        /// cases: for the first ack when `expect_new_aggregation_epoch` is
        /// set, and for every ack when the aggregation bandwidth equals the
        /// estimate. Otherwise it must strictly increase from ack to ack.
        /// Afterwards the clock is moved to the point at which the episode's
        /// bytes would have been delivered at the estimated bandwidth.
        ///
        /// # Panics
        ///
        /// Panics in any of these cases:
        ///
        /// - `aggregation_bandwidth` is below the estimate.
        /// - `bytes_per_ack` is zero.
        /// - The episode does not split into a whole, non-zero number of
        ///   acks.
        /// - The timing is not exactly representable in microseconds.
        /// - Any per-ack extra-acked check fails.
        fn aggregation_episode(
            &mut self, aggregation_bandwidth: Bandwidth,
            aggregation_duration: Duration, bytes_per_ack: usize,
            expect_new_aggregation_epoch: bool,
        ) {
            assert!(aggregation_bandwidth >= self.bandwidth);
            // Fail with a diagnostic rather than an opaque divide-by-zero or
            // `step_by(0)` panic further down.
            assert!(bytes_per_ack > 0, "bytes_per_ack must be non-zero");

            let start_time = self.now;

            let aggregation_bytes =
                (aggregation_bandwidth * aggregation_duration) as usize;

            let num_acks = aggregation_bytes / bytes_per_ack;
            assert_eq!(aggregation_bytes, num_acks * bytes_per_ack);
            assert!(
                num_acks > 0,
                "an episode of {aggregation_bytes} bytes yields no acks"
            );

            let time_between_acks = Duration::from_micros(
                aggregation_duration.as_micros() as u64 / num_acks as u64,
            );
            assert_eq!(aggregation_duration, time_between_acks * num_acks as u32);

            // Time the episode's bytes would take at the estimated bandwidth.
            let total_duration = Duration::from_micros(
                (aggregation_bytes as u64 * 8 * 1000000) /
                    self.bandwidth.to_bits_per_second(),
            );

            assert_eq!(aggregation_bytes as u64, self.bandwidth * total_duration);

            let mut last_extra_acked = 0;

            for bytes in (0..aggregation_bytes).step_by(bytes_per_ack) {
                let extra_acked = self.tracker.update(
                    self.bandwidth,
                    true,
                    self.round_trip_count(),
                    self.last_sent_packet_number,
                    self.last_acked_packet_number,
                    self.now,
                    bytes_per_ack,
                );
                // No extra acked bytes at the start of a new epoch, or at all
                // when acks arrive exactly at the estimated bandwidth.
                if (bytes == 0 && expect_new_aggregation_epoch) ||
                    (aggregation_bandwidth == self.bandwidth)
                {
                    assert_eq!(0, extra_acked);
                } else {
                    assert!(last_extra_acked < extra_acked);
                }
                self.now += time_between_acks;
                last_extra_acked = extra_acked;
            }

            self.now = start_time + total_duration;
        }

        /// Number of whole RTTs elapsed since `start` on the simulated clock.
        fn round_trip_count(&self) -> usize {
            ((self.now - self.start).as_micros() / self.rtt.as_micros()) as usize
        }
    }

    /// Runs three aggregation episodes at `bandwidth_gain` times the
    /// estimated bandwidth and checks how many aggregation epochs were
    /// counted. The clock is moved back 1 ms after the second episode, so
    /// the third starts 1 ms before the second one ended.
    fn test_inner(
        overestimate_avoidance: bool, bandwidth_gain: f64,
        agg_duration: Duration, byte_per_ack: usize,
    ) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);

        let rnd = |tracker: &mut TestTracker, expect: bool| {
            tracker.aggregation_episode(
                tracker.bandwidth * bandwidth_gain,
                agg_duration,
                byte_per_ack,
                expect,
            );
        };

        rnd(&mut test_tracker, true);
        rnd(&mut test_tracker, true);

        test_tracker.now = test_tracker
            .now
            .checked_sub(Duration::from_millis(1))
            .unwrap();

        // `TestTracker::new` sets the threshold to 1.8, so the first branch
        // is the one that currently runs; the second covers lower
        // thresholds.
        if test_tracker.tracker.ack_aggregation_bandwidth_threshold > 1.1 {
            rnd(&mut test_tracker, true);
            assert_eq!(3, test_tracker.tracker.num_ack_aggregation_epochs);
        } else {
            rnd(&mut test_tracker, false);
            assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs);
        }
    }

    #[rstest]
    fn very_aggregated_large_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 20.0, Duration::from_millis(6), 1200)
    }

    #[rstest]
    fn very_aggregated_small_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 20., Duration::from_millis(6), 300)
    }

    #[rstest]
    fn somewhat_aggregated_large_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 2.0, Duration::from_millis(50), 1000)
    }

    #[rstest]
    fn somewhat_aggregated_small_acks(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        test_inner(overestimate_avoidance, 2.0, Duration::from_millis(50), 100)
    }

    /// Acks arriving exactly at the estimated bandwidth never report extra
    /// acked bytes, and the episode spans more than two aggregation epochs.
    #[rstest]
    fn not_aggregated(#[values(false, true)] overestimate_avoidance: bool) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);
        test_tracker.aggregation_episode(
            test_tracker.bandwidth,
            Duration::from_millis(100),
            100,
            true,
        );
        assert!(2 < test_tracker.tracker.num_ack_aggregation_epochs);
    }

    /// Once a packet sent after the aggregation episode (11 > 10) is acked,
    /// a low-bandwidth update starts a second aggregation epoch.
    #[rstest]
    fn start_new_epoch_after_a_full_round(
        #[values(false, true)] overestimate_avoidance: bool,
    ) {
        let mut test_tracker = TestTracker::new(overestimate_avoidance);

        test_tracker.last_sent_packet_number = 10;

        test_tracker.aggregation_episode(
            test_tracker.bandwidth * 2.0,
            Duration::from_millis(50),
            100,
            true,
        );

        test_tracker.last_acked_packet_number = 11;

        test_tracker.tracker.update(
            test_tracker.bandwidth * 0.1,
            true,
            test_tracker.round_trip_count(),
            test_tracker.last_sent_packet_number,
            test_tracker.last_acked_packet_number,
            test_tracker.now,
            100,
        );

        assert_eq!(2, test_tracker.tracker.num_ack_aggregation_epochs)
    }
}