// tokio_quiche/quic/io/worker.rs
1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27use std::net::SocketAddr;
28use std::ops::ControlFlow;
29use std::sync::Arc;
30use std::task::Poll;
31use std::time::Duration;
32use std::time::Instant;
33#[cfg(feature = "perf-quic-listener-metrics")]
34use std::time::SystemTime;
35
36use super::connection_stage::Close;
37use super::connection_stage::ConnectionStage;
38use super::connection_stage::ConnectionStageContext;
39use super::connection_stage::Handshake;
40use super::connection_stage::RunningApplication;
41use super::gso::*;
42use super::utilization_estimator::BandwidthReporter;
43
44use crate::metrics::labels;
45use crate::metrics::Metrics;
46use crate::quic::connection::ApplicationOverQuic;
47use crate::quic::connection::HandshakeError;
48use crate::quic::connection::Incoming;
49use crate::quic::connection::QuicConnectionStats;
50use crate::quic::connection::SharedConnectionIdGenerator;
51use crate::quic::router::ConnectionMapCommand;
52use crate::quic::QuicheConnection;
53use crate::QuicResult;
54
55use boring::ssl::SslRef;
56use datagram_socket::DatagramSocketSend;
57use datagram_socket::DatagramSocketSendExt;
58use datagram_socket::MaybeConnectedSocket;
59use datagram_socket::QuicAuditStats;
60use foundations::telemetry::log;
61use quiche::ConnectionId;
62use quiche::Error as QuicheError;
63use quiche::SendInfo;
64use tokio::select;
65use tokio::sync::mpsc;
66use tokio::time;
67
// Number of incoming packets to be buffered in the incoming channel.
pub(crate) const INCOMING_QUEUE_SIZE: usize = 2048;

// Check if there are any incoming packets while sending data every this number
// of sent packets
pub(crate) const CHECK_INCOMING_QUEUE_RATIO: usize = INCOMING_QUEUE_SIZE / 16;

// Pacer release times closer than this to "now" are treated as immediately
// releasable instead of arming a wakeup timer (see `work_loop`).
const RELEASE_TIMER_THRESHOLD: Duration = Duration::from_micros(250);

/// Stop queuing GSO packets, if packet size is below this threshold.
const GSO_THRESHOLD: usize = 1_000;
/// Static configuration for a connection's outgoing packet writer.
pub struct WriterConfig {
    /// Connection ID registered while the connection was pending, if any.
    // NOTE(review): not read anywhere in this file — confirm semantics with
    // the code that consumes the params.
    pub pending_cid: Option<ConnectionId<'static>>,
    /// Peer address, used as the fallback destination when no network path
    /// has been selected for the current write cycle.
    pub peer_addr: SocketAddr,
    /// Local address of the connection; used to query quiche for available
    /// network paths.
    pub local_addr: SocketAddr,
    /// Whether to batch same-sized packets into a single GSO send.
    pub with_gso: bool,
    /// Whether to attach pacing timestamps to outgoing packets
    /// (see `IoWorker::pacing_enabled`).
    pub pacing_offload: bool,
    /// Whether a source address may be passed to the socket send path
    /// (only applied when a path's `from` address is known).
    pub with_pktinfo: bool,
}
88
/// Mutable scratch state for the current write cycle. Reset and filled by
/// `gather_data_from_quiche_conn()`, consumed by `flush_buffer_to_socket()`.
#[derive(Default)]
pub(crate) struct WriteState {
    // Snapshot of `qconn.is_established()` taken when packets were generated.
    conn_established: bool,
    // Total bytes currently staged in the send buffer.
    bytes_written: usize,
    // Size of each packet in the GSO batch (the last one may be smaller).
    segment_size: usize,
    // Number of packets staged in the send buffer.
    num_pkts: usize,
    // Pacing timestamp to hand to the socket send path, if pacing is enabled.
    tx_time: Option<Instant>,
    // Whether quiche may still have more data to send in this cycle.
    has_pending_data: bool,
    // If pacer schedules packets too far into the future, we want to pause
    // sending, until the future arrives
    next_release_time: Option<Instant>,
    // The selected source and destination addresses for the current write
    // cycle.
    selected_path: Option<(SocketAddr, SocketAddr)>,
    // Iterator over the network paths that haven't been flushed yet.
    pending_paths: quiche::SocketAddrIter,
}
106
/// Construction parameters for an [`IoWorker`]; fields are moved over
/// one-to-one by [`IoWorker::new`].
pub(crate) struct IoWorkerParams<Tx, M> {
    /// Socket (possibly connected) used to send outgoing datagrams.
    pub(crate) socket: MaybeConnectedSocket<Tx>,
    /// Dropped when the worker goes away; the listener detects the closure.
    pub(crate) shutdown_tx: mpsc::Sender<()>,
    /// Static writer configuration (addresses, GSO/pacing/pktinfo flags).
    pub(crate) cfg: WriterConfig,
    /// Shared audit-log statistics updated as the connection progresses.
    pub(crate) audit_log_stats: Arc<QuicAuditStats>,
    /// Scratch state for the current write cycle.
    pub(crate) write_state: WriteState,
    /// Channel to the ingress router for CID map/unmap commands.
    pub(crate) conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    /// Generator for fresh source connection IDs, if CID rotation is enabled.
    pub(crate) cid_generator: Option<SharedConnectionIdGenerator>,
    /// Receive time of the connection's first packet, for handshake-latency
    /// metrics. Taken (consumed) on the first flush.
    #[cfg(feature = "perf-quic-listener-metrics")]
    pub(crate) init_rx_time: Option<SystemTime>,
    /// Metrics sink.
    pub(crate) metrics: M,
}
119
/// Drives a single QUIC connection's socket I/O: receives packets into
/// quiche, generates and flushes outgoing packets (with optional GSO and
/// pacing), and reports metrics. `S` selects the connection stage
/// (e.g. handshake vs. running application).
pub(crate) struct IoWorker<Tx, M, S> {
    socket: MaybeConnectedSocket<Tx>,
    /// A field that signals to the listener task that the connection has gone
    /// away (nothing is sent here, listener task just detects the sender
    /// has dropped)
    shutdown_tx: mpsc::Sender<()>,
    cfg: WriterConfig,
    audit_log_stats: Arc<QuicAuditStats>,
    write_state: WriteState,
    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    cid_generator: Option<SharedConnectionIdGenerator>,
    #[cfg(feature = "perf-quic-listener-metrics")]
    init_rx_time: Option<SystemTime>,
    metrics: M,
    // Stage-specific behavior hooks (on_read/on_flush/post_wait/...).
    conn_stage: S,
    // Estimates utilized bandwidth and loss for audit logging.
    bw_estimator: BandwidthReporter,
}
137
138impl<Tx, M, S> IoWorker<Tx, M, S>
139where
140    Tx: DatagramSocketSend + Send,
141    M: Metrics,
142    S: ConnectionStage,
143{
144    pub(crate) fn new(params: IoWorkerParams<Tx, M>, conn_stage: S) -> Self {
145        let bw_estimator =
146            BandwidthReporter::new(params.metrics.utilized_bandwidth());
147
148        log::trace!("Creating IoWorker with stage: {conn_stage:?}");
149
150        Self {
151            socket: params.socket,
152            shutdown_tx: params.shutdown_tx,
153            cfg: params.cfg,
154            audit_log_stats: params.audit_log_stats,
155            write_state: params.write_state,
156            conn_map_cmd_tx: params.conn_map_cmd_tx,
157            cid_generator: params.cid_generator,
158            #[cfg(feature = "perf-quic-listener-metrics")]
159            init_rx_time: params.init_rx_time,
160            metrics: params.metrics,
161            conn_stage,
162            bw_estimator,
163        }
164    }
165
166    fn fill_available_scids(&self, qconn: &mut QuicheConnection) {
167        if qconn.scids_left() == 0 {
168            return;
169        }
170        let Some(cid_generator) = self.cid_generator.as_deref() else {
171            return;
172        };
173
174        let current_cid = qconn.source_id().into_owned();
175        for _ in 0..qconn.scids_left() {
176            // We don't emit stateless resets, so any unguessable value is fine
177            let reset_token = random_u128();
178            let new_cid = cid_generator.new_connection_id();
179
180            if self
181                .conn_map_cmd_tx
182                .send(ConnectionMapCommand::MapCid {
183                    existing_cid: current_cid.clone(),
184                    new_cid: new_cid.clone(),
185                })
186                .is_err()
187            {
188                // Can't do anything if the connection map is gone
189                return;
190            }
191
192            if qconn.new_scid(&new_cid, reset_token, false).is_err() {
193                // This only fails if we have reached the CID limit already
194                return;
195            }
196        }
197    }
198
199    fn unmap_cid(&self, cid: ConnectionId<'static>) {
200        // If the connection map is gone, the ID is already "unmapped"
201        let _ = self
202            .conn_map_cmd_tx
203            .send(ConnectionMapCommand::UnmapCid(cid));
204    }
205
206    fn refresh_connection_ids(&self, qconn: &mut QuicheConnection) {
207        // Top up the connection's active CIDs
208        self.fill_available_scids(qconn);
209
210        // Remove retired CIDs from the ingress router
211        while let Some(retired_cid) = qconn.retired_scid_next() {
212            self.unmap_cid(retired_cid);
213        }
214    }
215
    /// Core event loop shared by all connection stages.
    ///
    /// Each iteration: drains incoming packets into quiche, writes outgoing
    /// packets until quiche has nothing left (or the pacer defers them),
    /// updates bandwidth/loss stats, then sleeps until the next timeout,
    /// incoming packet, or application event. Returns when the connection is
    /// closed or a stage hook breaks out of the loop.
    async fn work_loop<A: ApplicationOverQuic>(
        &mut self, qconn: &mut QuicheConnection,
        ctx: &mut ConnectionStageContext<A>,
    ) -> QuicResult<()> {
        // Fallback sleep duration when no concrete deadline is pending.
        const DEFAULT_SLEEP: Duration = Duration::from_secs(60);
        let mut current_deadline: Option<Instant> = None;
        let sleep = time::sleep(DEFAULT_SLEEP);
        tokio::pin!(sleep);

        loop {
            let now = Instant::now();

            self.write_state.has_pending_data = true;

            while self.write_state.has_pending_data {
                let mut packets_sent = 0;

                // Try to clear all received packets every so often, because
                // incoming packets contain acks, and because the
                // receive queue has a very limited size, once it is full incoming
                // packets get stalled indefinitely
                let mut did_recv = false;
                while let Some(pkt) = ctx
                    .in_pkt
                    .take()
                    .or_else(|| ctx.incoming_pkt_receiver.try_recv().ok())
                {
                    self.process_incoming(qconn, pkt)?;
                    did_recv = true;
                }

                self.conn_stage.on_read(did_recv, qconn, ctx)?;
                self.refresh_connection_ids(qconn);

                // Only send now if the pacer's next release time is absent or
                // close enough that arming a timer isn't worthwhile.
                let can_release = match self.write_state.next_release_time {
                    None => true,
                    Some(next_release) =>
                        next_release
                            .checked_duration_since(now)
                            .unwrap_or_default() <
                            RELEASE_TIMER_THRESHOLD,
                };

                self.write_state.has_pending_data &= can_release;

                // Send in bursts of at most CHECK_INCOMING_QUEUE_RATIO packets
                // so the incoming queue gets drained regularly (see above).
                while self.write_state.has_pending_data &&
                    packets_sent < CHECK_INCOMING_QUEUE_RATIO
                {
                    self.gather_data_from_quiche_conn(qconn, ctx.buffer())?;

                    // Break if the connection is closed
                    if qconn.is_closed() {
                        return Ok(());
                    }

                    let mut flush_operation_token =
                        TrackMidHandshakeFlush::new(self.metrics.clone());

                    self.flush_buffer_to_socket(ctx.buffer()).await;

                    flush_operation_token.mark_complete();

                    packets_sent += self.write_state.num_pkts;

                    if let ControlFlow::Break(reason) =
                        self.conn_stage.on_flush(qconn, ctx)
                    {
                        return reason;
                    }
                }
            }

            self.bw_estimator.update(qconn, now);

            self.audit_log_stats
                .set_max_bandwidth(self.bw_estimator.max_bandwidth);
            self.audit_log_stats.set_max_loss_pct(
                (self.bw_estimator.max_loss_pct * 100_f32).round() as u8,
            );

            // Next wakeup is the earliest of: quiche's timeout, the pacer's
            // release time, and the stage-specific deadline.
            let new_deadline = min_of_some(
                qconn.timeout_instant(),
                self.write_state.next_release_time,
            );
            let new_deadline =
                min_of_some(new_deadline, self.conn_stage.wait_deadline());

            if new_deadline != current_deadline {
                current_deadline = new_deadline;

                sleep
                    .as_mut()
                    .reset(new_deadline.unwrap_or(now + DEFAULT_SLEEP).into());
            }

            let incoming_recv = &mut ctx.incoming_pkt_receiver;
            let application = &mut ctx.application;
            select! {
                biased;
                () = &mut sleep => {
                    // It's very important that we keep the timeout arm at the top of this loop so
                    // that we poll it every time we need to. Since this is a biased `select!`, if
                    // we put this behind another arm, we could theoretically starve the sleep arm
                    // and hang connections.
                    //
                    // See https://docs.rs/tokio/latest/tokio/macro.select.html#fairness for more
                    qconn.on_timeout();

                    self.write_state.next_release_time = None;
                    current_deadline = None;
                    sleep.as_mut().reset((now + DEFAULT_SLEEP).into());
                }
                Some(pkt) = incoming_recv.recv() => ctx.in_pkt = Some(pkt),
                // TODO(erittenhouse): would be nice to decouple wait_for_data from the
                // application, but wait_for_quiche relies on IOW methods, so we can't write a
                // default implementation for ConnectionStage
                status = self.wait_for_data_or_handshake(qconn, application) => status?,
            };

            if let ControlFlow::Break(reason) = self.conn_stage.post_wait(qconn) {
                return reason;
            }
        }
    }
340
341    #[cfg(feature = "perf-quic-listener-metrics")]
342    fn measure_complete_handshake_time(&mut self) {
343        if let Some(init_rx_time) = self.init_rx_time.take() {
344            if let Ok(delta) = init_rx_time.elapsed() {
345                self.metrics
346                    .handshake_time_seconds(
347                        labels::QuicHandshakeStage::HandshakeResponse,
348                    )
349                    .observe(delta.as_nanos() as u64);
350            }
351        }
352    }
353
    /// Generates outgoing packets from `qconn` into `send_buf`, batching
    /// equal-sized packets for a single GSO send where possible.
    ///
    /// Returns the outcome of the last packet write; the staged data is
    /// described by `self.write_state` (`bytes_written`, `num_pkts`,
    /// `segment_size`, `tx_time`, `selected_path`). If the pacer schedules
    /// the next send too far into the future, nothing is staged:
    /// `has_pending_data` is cleared, `next_release_time` is set, and `Ok(0)`
    /// is returned.
    fn gather_data_from_quiche_conn(
        &mut self, qconn: &mut QuicheConnection, send_buf: &mut [u8],
    ) -> QuicResult<usize> {
        // GSO segment size; fixed by the first packet in the batch.
        let mut segment_size = None;
        // SendInfo captured from the first generated packet.
        let mut send_info = None;

        self.write_state.num_pkts = 0;
        self.write_state.bytes_written = 0;

        self.write_state.selected_path = None;

        let now = Instant::now();

        // Never stage more than a single GSO send can carry.
        let send_buf = {
            let trunc = UDP_MAX_GSO_PACKET_SIZE.min(send_buf.len());
            &mut send_buf[..trunc]
        };

        #[cfg(feature = "gcongestion")]
        let gcongestion_enabled = true;

        #[cfg(not(feature = "gcongestion"))]
        let gcongestion_enabled = qconn.gcongestion_enabled().unwrap_or(false);

        // With gcongestion, consult the pacer up front: if the next release
        // is too far in the future, pause instead of staging packets.
        let initial_release_decision = if gcongestion_enabled {
            let initial_release_decision = qconn
                .get_next_release_time()
                .filter(|_| self.pacing_enabled(qconn));

            if let Some(future_release_time) =
                initial_release_decision.as_ref().and_then(|v| v.time(now))
            {
                let max_into_fut = qconn.max_release_into_future();

                if future_release_time.duration_since(now) >= max_into_fut {
                    // Wake up slightly early (80% of the window) so the
                    // release time hasn't already passed when we resume.
                    self.write_state.next_release_time =
                        Some(now + max_into_fut.mul_f32(0.8));
                    self.write_state.has_pending_data = false;
                    return Ok(0);
                }
            }

            initial_release_decision
        } else {
            None
        };

        let buffer_write_outcome = loop {
            let outcome = self.write_packet_to_buffer(
                qconn,
                send_buf,
                &mut send_info,
                segment_size,
            );

            let packet_size = match outcome {
                // Nothing more to send (on this path, at least).
                Ok(0) => break Ok(0),

                Ok(bytes_written) => bytes_written,

                Err(e) => break Err(e),
            };

            // Flush to network after generating a single packet when GSO
            // is disabled.
            if !self.cfg.with_gso {
                break outcome;
            }

            #[cfg(not(feature = "gcongestion"))]
            let max_send_size = if !gcongestion_enabled {
                // Only call qconn.send_quantum when !gcongestion_enabled.
                tune_max_send_size(
                    segment_size,
                    qconn.send_quantum(),
                    send_buf.len(),
                )
            } else {
                usize::MAX
            };

            #[cfg(feature = "gcongestion")]
            let max_send_size = usize::MAX;

            // If segment_size is known, update the maximum of
            // GSO sender buffer size to the multiple of
            // segment_size.
            let buffer_is_full = self.write_state.num_pkts ==
                UDP_MAX_SEGMENT_COUNT ||
                self.write_state.bytes_written >= max_send_size;

            if buffer_is_full {
                break outcome;
            }

            // Flush to network when the newly generated packet size is
            // different from previously written packet, as GSO needs packets
            // to have the same size, except for the last one in the buffer.
            // The last packet may be smaller than the previous size.
            match segment_size {
                Some(size)
                    if packet_size != size || packet_size < GSO_THRESHOLD =>
                    break outcome,
                None => segment_size = Some(packet_size),
                _ => (),
            }

            if gcongestion_enabled {
                // If the release time of next packet is different, or it can't be
                // part of a burst, start the next batch
                if let Some(initial_release_decision) = initial_release_decision {
                    match qconn.get_next_release_time() {
                        Some(release)
                            if release.can_burst() ||
                                release.time_eq(
                                    &initial_release_decision,
                                    now,
                                ) => {},
                        _ => break outcome,
                    }
                }
            }
        };

        // Pacing timestamp to attach to the outgoing send, if any.
        let tx_time = if gcongestion_enabled {
            initial_release_decision
                .filter(|_| self.pacing_enabled(qconn))
                // Return the time from the release decision if release_decision.time > now, else None.
                .and_then(|v| v.time(now))
        } else {
            send_info
                .filter(|_| self.pacing_enabled(qconn))
                .map(|v| v.at)
        };

        self.write_state.conn_established = qconn.is_established();
        self.write_state.tx_time = tx_time;
        self.write_state.segment_size =
            segment_size.unwrap_or(self.write_state.bytes_written);

        // Without gcongestion the pacing hint arrives via SendInfo; apply the
        // same "pause if scheduled too far ahead" rule as above.
        if !gcongestion_enabled {
            if let Some(time) = tx_time {
                const DEFAULT_MAX_INTO_FUTURE: Duration =
                    Duration::from_millis(1);
                if time
                    .checked_duration_since(now)
                    .map(|d| d > DEFAULT_MAX_INTO_FUTURE)
                    .unwrap_or(false)
                {
                    self.write_state.next_release_time =
                        Some(now + DEFAULT_MAX_INTO_FUTURE.mul_f32(0.8));
                    self.write_state.has_pending_data = false;
                    return Ok(0);
                }
            }
        }

        buffer_write_outcome
    }
513
514    /// Selects a network path, if none already selected.
515    ///
516    /// This will return the first path available in the write state's
517    /// `pending_paths` iterator. If that is empty a new iterator will be
518    /// created by querying quiche itself.
519    ///
520    /// Note that the connection's statically configured local address will be
521    /// used to query quiche for available paths, so this can't handle multiple
522    /// local addresses currently.
523    fn select_path(
524        &mut self, qconn: &QuicheConnection,
525    ) -> Option<(SocketAddr, SocketAddr)> {
526        if self.write_state.selected_path.is_some() {
527            return self.write_state.selected_path;
528        }
529
530        let from = self.cfg.local_addr;
531
532        // Initialize paths iterator.
533        if self.write_state.pending_paths.len() == 0 {
534            self.write_state.pending_paths = qconn.paths_iter(from);
535        }
536
537        let to = self.write_state.pending_paths.next()?;
538
539        Some((from, to))
540    }
541
    /// Whether pacing timestamps should be attached to outgoing packets:
    /// requires pacing offload in the writer config, and (without the
    /// gcongestion feature) pacing enabled on the quiche connection.
    #[cfg(not(feature = "gcongestion"))]
    fn pacing_enabled(&self, qconn: &QuicheConnection) -> bool {
        self.cfg.pacing_offload && qconn.pacing_enabled()
    }

    /// gcongestion variant: only the writer config's offload flag matters.
    #[cfg(feature = "gcongestion")]
    fn pacing_enabled(&self, _qconn: &QuicheConnection) -> bool {
        self.cfg.pacing_offload
    }
551
    /// Writes one QUIC packet from `qconn` into `send_buf` at offset
    /// `write_state.bytes_written`, capped at `segment_size` when known.
    ///
    /// Returns the packet's size, or `Ok(0)` when quiche has nothing more to
    /// send on the selected path (`has_pending_data` then reflects whether
    /// other paths remain). On any other quiche error the connection is
    /// closed (if not already) and the error is propagated.
    fn write_packet_to_buffer(
        &mut self, qconn: &mut QuicheConnection, send_buf: &mut [u8],
        send_info: &mut Option<SendInfo>, segment_size: Option<usize>,
    ) -> QuicResult<usize> {
        let mut send_buf = &mut send_buf[self.write_state.bytes_written..];
        if send_buf.len() > segment_size.unwrap_or(usize::MAX) {
            // Never let the buffer be longer than segment size, for GSO to
            // function properly.
            send_buf = &mut send_buf[..segment_size.unwrap_or(usize::MAX)];
        }

        // On the first call to `select_path()` a path will be chosen based on
        // the local address the connection initially landed on. Once a path is
        // selected following calls to `select_path()` will return it, until it
        // is reset at the start of the next write cycle.
        //
        // The path is then passed to `send_on_path()` which will only generate
        // packets meant for that path, this way a single GSO buffer will only
        // contain packets that belong to the same network path, which is
        // required because the from/to addresses for each `sendmsg()` call
        // apply to the whole GSO buffer.
        let (from, to) = self.select_path(qconn).unzip();

        match qconn.send_on_path(send_buf, from, to) {
            Ok((packet_size, info)) => {
                // Keep the SendInfo of the first packet; its addresses apply
                // to the whole batch.
                let _ = send_info.get_or_insert(info);

                self.write_state.bytes_written += packet_size;
                self.write_state.num_pkts += 1;

                let from = send_info.as_ref().map(|info| info.from);
                let to = send_info.as_ref().map(|info| info.to);

                self.write_state.selected_path = from.zip(to);

                self.write_state.has_pending_data = true;

                Ok(packet_size)
            },

            Err(QuicheError::Done) => {
                // Flush the current buffer to network. If no other path needs
                // to be flushed to the network also yield the work loop task.
                //
                // Otherwise the write loop will start again and the next path
                // will be selected.
                let has_pending_paths = self.write_state.pending_paths.len() > 0;

                // Keep writing if there are paths left to try.
                self.write_state.has_pending_data = has_pending_paths;

                Ok(0)
            },

            Err(e) => {
                // Record the CONNECTION_CLOSE error code for the audit log,
                // closing the connection ourselves if quiche hasn't already.
                let error_code = if let Some(local_error) = qconn.local_error() {
                    local_error.error_code
                } else {
                    let internal_error_code =
                        quiche::WireErrorCode::InternalError as u64;
                    let _ = qconn.close(false, internal_error_code, &[]);

                    internal_error_code
                };

                self.audit_log_stats
                    .set_sent_conn_close_transport_error_code(error_code as i64);

                Err(Box::new(e))
            },
        }
    }
624
    /// Flushes the packets staged in `send_buf` (per `write_state`) to the
    /// socket, using GSO on UDP sockets when enabled, and records
    /// partial-write / error metrics. No-op if nothing is staged.
    async fn flush_buffer_to_socket(&mut self, send_buf: &[u8]) {
        if self.write_state.bytes_written > 0 {
            let current_send_buf = &send_buf[..self.write_state.bytes_written];

            let (from, to) = self.write_state.selected_path.unzip();

            // Fall back to the configured peer address if no path was chosen.
            let to = to.unwrap_or(self.cfg.peer_addr);
            // Only pass a source address when pktinfo handling is enabled.
            let from = from.filter(|_| self.cfg.with_pktinfo);

            let send_res = if let (Some(udp_socket), true) =
                (self.socket.as_udp_socket(), self.cfg.with_gso)
            {
                // Only UDP supports GSO.
                send_to(
                    udp_socket,
                    to,
                    from,
                    current_send_buf,
                    self.write_state.segment_size,
                    self.write_state.tx_time,
                    self.metrics
                        .write_errors(labels::QuicWriteError::WouldBlock),
                    self.metrics.send_to_wouldblock_duration_s(),
                )
                .await
            } else {
                self.socket.send_to(current_send_buf, to).await
            };

            #[cfg(feature = "perf-quic-listener-metrics")]
            self.measure_complete_handshake_time();

            match send_res {
                // A short write means part of the batch was dropped.
                Ok(n) =>
                    if n < self.write_state.bytes_written {
                        self.metrics
                            .write_errors(labels::QuicWriteError::Partial)
                            .inc();
                    },

                Err(_) => {
                    self.metrics.write_errors(labels::QuicWriteError::Err).inc();
                },
            }
        }
    }
671
672    /// Process the incoming packet
673    fn process_incoming(
674        &mut self, qconn: &mut QuicheConnection, mut pkt: Incoming,
675    ) -> QuicResult<()> {
676        let recv_info = quiche::RecvInfo {
677            from: pkt.peer_addr,
678            to: pkt.local_addr,
679        };
680
681        if let Some(gro) = pkt.gro {
682            for dgram in pkt.buf.chunks_mut(gro as usize) {
683                qconn.recv(dgram, recv_info)?;
684            }
685        } else {
686            qconn.recv(&mut pkt.buf, recv_info)?;
687        }
688
689        Ok(())
690    }
691
692    /// When a connection is established, process application data, if not the
693    /// task is probably polled following a wakeup from boring, so we check
694    /// if quiche has any handshake packets to send.
695    async fn wait_for_data_or_handshake<A: ApplicationOverQuic>(
696        &mut self, qconn: &mut QuicheConnection, quic_application: &mut A,
697    ) -> QuicResult<()> {
698        if quic_application.should_act() {
699            // Poll the application to make progress.
700            //
701            // Once the connection has been established (i.e. the handshake is
702            // complete), we only poll the application.
703            //
704            // The exception is 0-RTT in TLS 1.3, where the full handshake is
705            // still in progress but we have 0-RTT keys to process early data.
706            // This means TLS callbacks might only be polled on the next timeout
707            // or when a packet is received from the peer.
708            quic_application.wait_for_data(qconn).await
709        } else {
710            // Poll quiche to make progress on handshake callbacks.
711            self.wait_for_quiche(qconn, quic_application).await
712        }
713    }
714
    /// Check if Quiche has any packets to send and flush them to socket.
    ///
    /// # Example
    ///
    /// This function can be used, for example, to drive an asynchronous TLS
    /// handshake. Each call to `gather_data_from_quiche_conn` attempts to
    /// progress the handshake via a call to `quiche::Connection.send()` -
    /// once one of the `gather_data_from_quiche_conn()` calls writes to the
    /// send buffer, we flush it to the network socket.
    async fn wait_for_quiche<App: ApplicationOverQuic>(
        &mut self, qconn: &mut QuicheConnection, app: &mut App,
    ) -> QuicResult<()> {
        // Stays Pending until gather() stages bytes in the send buffer.
        // NOTE: no local waker is registered here — progress relies on the
        // task waker installed into the SSL state in `run()`.
        let populate_send_buf = std::future::poll_fn(|_| {
            match self.gather_data_from_quiche_conn(qconn, app.buffer()) {
                Ok(bytes_written) => {
                    // We need to avoid consecutive calls to gather(), which write
                    // data to the buffer, without a flush().
                    // If we don't avoid those consecutive calls, we end
                    // up overwriting data in the buffer or unnecessarily waiting
                    // for more calls to drive_handshake()
                    // before calling the handshake complete.
                    if bytes_written == 0 && self.write_state.bytes_written == 0 {
                        Poll::Pending
                    } else {
                        Poll::Ready(Ok(()))
                    }
                },
                // Any gather failure during the handshake maps to a TLS error.
                _ => Poll::Ready(Err(quiche::Error::TlsFail)),
            }
        })
        .await;

        if populate_send_buf.is_err() {
            return Err(Box::new(quiche::Error::TlsFail));
        }

        self.flush_buffer_to_socket(app.buffer()).await;

        Ok(())
    }
755}
756
/// Worker state returned when the handshake stage finishes successfully:
/// everything needed to continue driving the connection's application stage.
pub struct Running<Tx, M, A> {
    pub(crate) params: IoWorkerParams<Tx, M>,
    pub(crate) context: ConnectionStageContext<A>,
    pub(crate) qconn: QuicheConnection,
}
762
impl<Tx, M, A> Running<Tx, M, A> {
    /// Mutable access to the connection's underlying SSL state.
    pub fn ssl(&mut self) -> &mut SslRef {
        self.qconn.as_mut()
    }
}
768
/// Worker state for a connection that is shutting down, together with the
/// work-loop result that triggered the shutdown.
pub(crate) struct Closing<Tx, M, A> {
    pub(crate) params: IoWorkerParams<Tx, M>,
    pub(crate) context: ConnectionStageContext<A>,
    // Why the work loop exited; consumed by the close stage.
    pub(crate) work_loop_result: QuicResult<()>,
    pub(crate) qconn: QuicheConnection,
}
775
/// Outcome of the handshake stage: either the connection proceeds to the
/// application stage ([`Running`]) or it is torn down ([`Closing`]).
pub enum RunningOrClosing<Tx, M, A> {
    Running(Running<Tx, M, A>),
    Closing(Closing<Tx, M, A>),
}
780
781impl<Tx, M> IoWorker<Tx, M, Handshake>
782where
783    Tx: DatagramSocketSend + Send,
784    M: Metrics,
785{
    /// Drives the handshake stage to completion and transitions the worker
    /// into either [`Running`] (handshake hooks succeeded) or [`Closing`]
    /// (work loop failed, connection closed mid-handshake, or the
    /// established hook failed).
    pub(crate) async fn run<A>(
        mut self, mut qconn: QuicheConnection, mut ctx: ConnectionStageContext<A>,
    ) -> RunningOrClosing<Tx, M, A>
    where
        A: ApplicationOverQuic,
    {
        // This makes an assumption that the waker being set in ex_data is stable
        // across the active task's lifetime. Moving a future that encompasses an
        // async callback from this task across a channel, for example, will
        // cause issues as this waker will then be stale and attempt to
        // wake the wrong task.
        std::future::poll_fn(|cx| {
            let ssl = qconn.as_mut();
            ssl.set_task_waker(Some(cx.waker().clone()));

            Poll::Ready(())
        })
        .await;

        // Capture the first packet's SO_MARK data for audit logging.
        #[cfg(target_os = "linux")]
        if let Some(incoming) = ctx.in_pkt.as_mut() {
            self.audit_log_stats
                .set_initial_so_mark_data(incoming.so_mark_data.take());
        }

        let mut work_loop_result = self.work_loop(&mut qconn, &mut ctx).await;
        // A connection closed during the handshake counts as a failed
        // handshake, even if the work loop itself exited cleanly.
        if work_loop_result.is_ok() && qconn.is_closed() {
            work_loop_result = Err(HandshakeError::ConnectionClosed.into());
        }

        if let Err(err) = &work_loop_result {
            self.metrics.failed_handshakes(err.into()).inc();

            return RunningOrClosing::Closing(Closing {
                params: self.into(),
                context: ctx,
                work_loop_result,
                qconn,
            });
        };

        match self.on_conn_established(&mut qconn, &mut ctx.application) {
            Ok(()) => RunningOrClosing::Running(Running {
                params: self.into(),
                context: ctx,
                qconn,
            }),
            Err(e) => {
                foundations::telemetry::log::warn!(
                    "Handshake stage on_connection_established failed"; "error"=>%e
                );

                RunningOrClosing::Closing(Closing {
                    params: self.into(),
                    context: ctx,
                    work_loop_result,
                    qconn,
                })
            },
        }
    }
847
848    fn on_conn_established<App: ApplicationOverQuic>(
849        &mut self, qconn: &mut QuicheConnection, driver: &mut App,
850    ) -> QuicResult<()> {
851        // Only calculate the QUIC handshake duration and call the driver's
852        // on_conn_established hook if this is the first time
853        // is_established == true.
854        if self.audit_log_stats.transport_handshake_duration_us() == -1 {
855            self.conn_stage.handshake_info.set_elapsed();
856            let handshake_info = &self.conn_stage.handshake_info;
857
858            self.audit_log_stats
859                .set_transport_handshake_duration(handshake_info.elapsed());
860
861            driver.on_conn_established(qconn, handshake_info)?;
862        }
863
864        if let Some(cid) = self.cfg.pending_cid.take() {
865            self.unmap_cid(cid);
866        }
867
868        Ok(())
869    }
870}
871
872impl<Tx, M, S> From<IoWorker<Tx, M, S>> for IoWorkerParams<Tx, M> {
873    fn from(value: IoWorker<Tx, M, S>) -> Self {
874        Self {
875            socket: value.socket,
876            shutdown_tx: value.shutdown_tx,
877            cfg: value.cfg,
878            audit_log_stats: value.audit_log_stats,
879            write_state: value.write_state,
880            conn_map_cmd_tx: value.conn_map_cmd_tx,
881            cid_generator: value.cid_generator,
882            #[cfg(feature = "perf-quic-listener-metrics")]
883            init_rx_time: value.init_rx_time,
884            metrics: value.metrics,
885        }
886    }
887}
888
889impl<Tx, M> IoWorker<Tx, M, RunningApplication>
890where
891    Tx: DatagramSocketSend + Send,
892    M: Metrics,
893{
894    pub(crate) async fn run<A: ApplicationOverQuic>(
895        mut self, mut qconn: QuicheConnection, mut ctx: ConnectionStageContext<A>,
896    ) -> Closing<Tx, M, A> {
897        // Perform a single call to process_reads()/process_writes(),
898        // unconditionally, to ensure that any application data (e.g.
899        // STREAM frames or datagrams) processed by the Handshake
900        // stage are properly passed to the application.
901        if let Err(e) = self.conn_stage.on_read(true, &mut qconn, &mut ctx) {
902            return Closing {
903                params: self.into(),
904                context: ctx,
905                work_loop_result: Err(e),
906                qconn,
907            };
908        };
909
910        let work_loop_result = self.work_loop(&mut qconn, &mut ctx).await;
911
912        Closing {
913            params: self.into(),
914            context: ctx,
915            work_loop_result,
916            qconn,
917        }
918    }
919}
920
impl<Tx, M> IoWorker<Tx, M, Close>
where
    Tx: DatagramSocketSend + Send,
    M: Metrics,
{
    /// Tears the connection down: emits final bandwidth/loss metrics,
    /// notifies the application, flushes any final outbound packets, records
    /// the connection statistics and close error codes, and unmaps all
    /// connection IDs.
    pub(crate) async fn close<A: ApplicationOverQuic>(
        mut self, qconn: &mut QuicheConnection,
        ctx: &mut ConnectionStageContext<A>,
    ) {
        // Only report bandwidth/loss on connections that ended cleanly and
        // actually observed some bandwidth.
        if self.conn_stage.work_loop_result.is_ok() &&
            self.bw_estimator.max_bandwidth > 0
        {
            let metrics = &self.metrics;

            // max_bandwidth is in bits/s; convert to Mbps.
            metrics
                .max_bandwidth_mbps()
                .observe(self.bw_estimator.max_bandwidth as f64 * 1e-6);

            // max_loss_pct is a fraction; convert to percent.
            metrics
                .max_loss_pct()
                .observe(self.bw_estimator.max_loss_pct as f64 * 100.);
        }

        if ctx.application.should_act() {
            ctx.application.on_conn_close(
                qconn,
                &self.metrics,
                &self.conn_stage.work_loop_result,
            );
        }

        // TODO: this assumes that the tidy_up operation can be completed in one
        // send (ignoring flow/congestion control constraints). We should
        // guarantee that it gets sent by doublechecking the
        // gathered/flushed byte totals and retry if they don't match.
        let _ = self.gather_data_from_quiche_conn(qconn, ctx.buffer());
        self.flush_buffer_to_socket(ctx.buffer()).await;

        // Publish the final connection statistics for external observers.
        *ctx.stats.lock().unwrap() = QuicConnectionStats::from_conn(qconn);

        // Record the CONNECTION_CLOSE error code the peer sent us, split by
        // application vs. transport error space.
        if let Some(err) = qconn.peer_error() {
            if err.is_app {
                self.audit_log_stats
                    .set_recvd_conn_close_application_error_code(
                        err.error_code as _,
                    );
            } else {
                self.audit_log_stats
                    .set_recvd_conn_close_transport_error_code(
                        err.error_code as _,
                    );
            }
        }

        // Likewise for the CONNECTION_CLOSE error code we sent.
        if let Some(err) = qconn.local_error() {
            if err.is_app {
                self.audit_log_stats
                    .set_sent_conn_close_application_error_code(
                        err.error_code as _,
                    );
            } else {
                self.audit_log_stats
                    .set_sent_conn_close_transport_error_code(
                        err.error_code as _,
                    );
            }
        }

        self.close_connection(qconn);

        // If the work loop ended in error, that error is the close reason.
        if let Err(work_loop_error) = self.conn_stage.work_loop_result {
            self.audit_log_stats
                .set_connection_close_reason(work_loop_error);
        }
    }

    /// Removes every connection ID mapping for this connection (pending,
    /// retired, and active source CIDs) and decrements the in-memory
    /// connection gauge.
    fn close_connection(&mut self, qconn: &mut QuicheConnection) {
        if let Some(cid) = self.cfg.pending_cid.take() {
            self.unmap_cid(cid);
        }
        while let Some(retired_cid) = qconn.retired_scid_next() {
            self.unmap_cid(retired_cid);
        }
        for cid in qconn.source_ids().cloned() {
            self.unmap_cid(cid.into_owned());
        }

        self.metrics.connections_in_memory().dec();
    }
}
1011
/// Returns the minimum of `v1` and `v2`, treating `None` as "no value":
/// a lone `Some` wins over a `None`, and two `None`s yield `None`.
fn min_of_some<T: Ord>(v1: Option<T>, v2: Option<T>) -> Option<T> {
    match (v1, v2) {
        // Both present: take the smaller.
        (Some(a), Some(b)) => Some(std::cmp::min(a, b)),
        // At most one present: return whichever (if any) it is.
        (only, None) | (None, only) => only,
    }
}
1020
/// A token which increments the `skipped_mid_handshake_flush_count` metric on
/// `Drop` unless it is marked complete.
struct TrackMidHandshakeFlush<M: Metrics> {
    // Set by `mark_complete()`; suppresses the metric increment on drop.
    complete: bool,
    metrics: M,
}
1027
1028impl<M: Metrics> TrackMidHandshakeFlush<M> {
1029    fn new(metrics: M) -> Self {
1030        Self {
1031            complete: false,
1032            metrics,
1033        }
1034    }
1035
1036    fn mark_complete(&mut self) {
1037        self.complete = true;
1038    }
1039}
1040
1041impl<M: Metrics> Drop for TrackMidHandshakeFlush<M> {
1042    fn drop(&mut self) {
1043        if !self.complete {
1044            self.metrics.skipped_mid_handshake_flush_count().inc();
1045        }
1046    }
1047}
1048
1049fn random_u128() -> u128 {
1050    let mut buf = [0; 16];
1051    boring::rand::rand_bytes(&mut buf).expect("boring's RAND_bytes never fails");
1052    u128::from_ne_bytes(buf)
1053}