Skip to main content

tokio_quiche/quic/io/
worker.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27use std::net::SocketAddr;
28use std::ops::ControlFlow;
29use std::sync::Arc;
30use std::task::Poll;
31use std::time::Duration;
32use std::time::Instant;
33#[cfg(feature = "perf-quic-listener-metrics")]
34use std::time::SystemTime;
35
36use super::connection_stage::Close;
37use super::connection_stage::ConnectionStage;
38use super::connection_stage::ConnectionStageContext;
39use super::connection_stage::Handshake;
40use super::connection_stage::RunningApplication;
41use super::gso::*;
42use super::utilization_estimator::BandwidthReporter;
43
44use crate::metrics::labels;
45use crate::metrics::Metrics;
46use crate::quic::connection::ApplicationOverQuic;
47use crate::quic::connection::HandshakeError;
48use crate::quic::connection::Incoming;
49use crate::quic::connection::QuicConnectionStats;
50use crate::quic::connection::SharedConnectionIdGenerator;
51use crate::quic::router::ConnectionMapCommand;
52use crate::quic::QuicheConnection;
53use crate::QuicResult;
54
55use boring::ssl::SslRef;
56use datagram_socket::DatagramSocketSend;
57use datagram_socket::DatagramSocketSendExt;
58use datagram_socket::MaybeConnectedSocket;
59use datagram_socket::QuicAuditStats;
60use foundations::telemetry::log;
61use quiche::ConnectionId;
62use quiche::Error as QuicheError;
63use quiche::SendInfo;
64use tokio::select;
65use tokio::sync::mpsc;
66use tokio::time;
67
// Number of incoming packets to be buffered in the incoming channel.
pub(crate) const INCOMING_QUEUE_SIZE: usize = 2048;

// While sending, go back and drain the incoming queue once at least this many
// packets have been sent.
pub(crate) const CHECK_INCOMING_QUEUE_RATIO: usize = INCOMING_QUEUE_SIZE / 16;

// A pending release time that is closer to "now" than this threshold is
// treated as already due, so the work loop keeps sending instead of waiting.
const RELEASE_TIMER_THRESHOLD: Duration = Duration::from_micros(250);

/// Stop queuing GSO packets, if packet size is below this threshold.
const GSO_THRESHOLD: usize = 1_000;
79
/// Per-connection settings that control how the I/O worker writes packets to
/// the socket.
pub struct WriterConfig {
    /// NOTE(review): not read anywhere in this part of the file; presumably a
    /// connection ID that is still awaiting registration. Confirm against the
    /// callers.
    pub pending_cid: Option<ConnectionId<'static>>,
    /// Destination used for a flush when the write cycle has no selected
    /// network path.
    pub peer_addr: SocketAddr,
    /// Local address used when querying quiche for the available network
    /// paths.
    pub local_addr: SocketAddr,
    /// When `true`, several packets are gathered into one buffer and sent
    /// through the GSO send path (only if the socket is a UDP socket).
    pub with_gso: bool,
    /// Gates `pacing_enabled()`, which decides whether a transmit time is
    /// attached to each flushed batch.
    pub pacing_offload: bool,
    /// When `false`, the source address of the selected path is not passed to
    /// the GSO send call.
    pub with_pktinfo: bool,
}
88
/// Bookkeeping for the current write cycle, i.e. one gather of packets from
/// quiche followed by one flush to the socket.
#[derive(Default)]
pub(crate) struct WriteState {
    // Snapshot of `qconn.is_established()` taken at the end of the last gather.
    conn_established: bool,
    // Number of bytes staged in the send buffer during the current cycle.
    bytes_written: usize,
    // Segment size handed to the GSO send call: the size of the first packet of
    // the batch, or `bytes_written` when no segment size was recorded.
    segment_size: usize,
    // Number of packets staged in the send buffer during the current cycle.
    num_pkts: usize,
    // Transmit time handed to the GSO send call, if pacing applies.
    tx_time: Option<Instant>,
    // Whether the work loop should keep gathering and flushing.
    has_pending_data: bool,
    // If pacer schedules packets too far into the future, we want to pause
    // sending, until the future arrives
    next_release_time: Option<Instant>,
    // The selected source and destination addresses for the current write
    // cycle.
    selected_path: Option<(SocketAddr, SocketAddr)>,
    // Iterator over the network paths that haven't been flushed yet.
    pending_paths: quiche::SocketAddrIter,
}
106
/// Everything an [`IoWorker`] is built from apart from its connection stage;
/// consumed field by field in [`IoWorker::new`].
pub(crate) struct IoWorkerParams<Tx, M> {
    /// Socket the worker flushes outgoing packets to.
    pub(crate) socket: MaybeConnectedSocket<Tx>,
    /// Never sent on by the worker in this file; see the field of the same
    /// name on [`IoWorker`].
    pub(crate) shutdown_tx: mpsc::Sender<()>,
    /// Write-path configuration (addresses, GSO, pacing, pktinfo).
    pub(crate) cfg: WriterConfig,
    /// Shared audit statistics updated by the worker (bandwidth, loss, close
    /// error code).
    pub(crate) audit_log_stats: Arc<QuicAuditStats>,
    /// Write-cycle bookkeeping carried into the worker.
    pub(crate) write_state: WriteState,
    /// Channel used to map and unmap connection IDs in the connection map.
    pub(crate) conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    /// Generator for additional source connection IDs; when `None`, no extra
    /// IDs are issued.
    pub(crate) cid_generator: Option<SharedConnectionIdGenerator>,
    /// Start time for the handshake-response latency metric; taken (and thus
    /// observed) at most once, after a flush.
    #[cfg(feature = "perf-quic-listener-metrics")]
    pub(crate) init_rx_time: Option<SystemTime>,
    /// Metrics sink.
    pub(crate) metrics: M,
}
119
/// Drives the I/O of a single QUIC connection: feeds incoming packets to
/// quiche, gathers outgoing packets from it and flushes them to the socket.
///
/// `S` is the connection stage whose hooks (`on_read`, `on_flush`,
/// `wait_deadline`, `post_wait`) are invoked from the work loop.
pub(crate) struct IoWorker<Tx, M, S> {
    /// Socket outgoing packets are flushed to.
    socket: MaybeConnectedSocket<Tx>,
    /// A field that signals to the listener task that the connection has gone
    /// away (nothing is sent here, listener task just detects the sender
    /// has dropped)
    shutdown_tx: mpsc::Sender<()>,
    /// Write-path configuration (addresses, GSO, pacing, pktinfo).
    cfg: WriterConfig,
    /// Shared audit statistics updated from the work loop and on send errors.
    audit_log_stats: Arc<QuicAuditStats>,
    /// Bookkeeping for the current gather/flush cycle.
    write_state: WriteState,
    /// Channel used to map and unmap connection IDs in the connection map.
    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    /// Generator for additional source connection IDs, if any.
    cid_generator: Option<SharedConnectionIdGenerator>,
    /// Start time for the handshake-response latency metric; taken at most
    /// once.
    #[cfg(feature = "perf-quic-listener-metrics")]
    init_rx_time: Option<SystemTime>,
    /// Metrics sink.
    metrics: M,
    /// Stage-specific behavior plugged into the work loop.
    conn_stage: S,
    /// Bandwidth reporter created from `metrics.utilized_bandwidth()` and
    /// updated once per work-loop iteration.
    bw_estimator: BandwidthReporter,
}
137
138impl<Tx, M, S> IoWorker<Tx, M, S>
139where
140    Tx: DatagramSocketSend + Send,
141    M: Metrics,
142    S: ConnectionStage,
143{
144    pub(crate) fn new(params: IoWorkerParams<Tx, M>, conn_stage: S) -> Self {
145        let bw_estimator =
146            BandwidthReporter::new(params.metrics.utilized_bandwidth());
147
148        log::trace!("Creating IoWorker with stage: {conn_stage:?}");
149
150        Self {
151            socket: params.socket,
152            shutdown_tx: params.shutdown_tx,
153            cfg: params.cfg,
154            audit_log_stats: params.audit_log_stats,
155            write_state: params.write_state,
156            conn_map_cmd_tx: params.conn_map_cmd_tx,
157            cid_generator: params.cid_generator,
158            #[cfg(feature = "perf-quic-listener-metrics")]
159            init_rx_time: params.init_rx_time,
160            metrics: params.metrics,
161            conn_stage,
162            bw_estimator,
163        }
164    }
165
166    fn fill_available_scids(&self, qconn: &mut QuicheConnection) {
167        if qconn.scids_left() == 0 {
168            return;
169        }
170        let Some(cid_generator) = self.cid_generator.as_deref() else {
171            return;
172        };
173
174        let current_cid = qconn.source_id().into_owned();
175        for _ in 0..qconn.scids_left() {
176            // We don't emit stateless resets, so any unguessable value is fine
177            let reset_token = random_u128();
178            let new_cid = cid_generator.new_connection_id();
179
180            if self
181                .conn_map_cmd_tx
182                .send(ConnectionMapCommand::MapCid {
183                    existing_cid: current_cid.clone(),
184                    new_cid: new_cid.clone(),
185                })
186                .is_err()
187            {
188                // Can't do anything if the connection map is gone
189                return;
190            }
191
192            if qconn.new_scid(&new_cid, reset_token, false).is_err() {
193                // This only fails if we have reached the CID limit already
194                return;
195            }
196        }
197    }
198
199    fn unmap_cid(&self, cid: ConnectionId<'static>) {
200        // If the connection map is gone, the ID is already "unmapped"
201        let _ = self
202            .conn_map_cmd_tx
203            .send(ConnectionMapCommand::UnmapCid(cid));
204    }
205
206    fn refresh_connection_ids(&self, qconn: &mut QuicheConnection) {
207        // Top up the connection's active CIDs
208        self.fill_available_scids(qconn);
209
210        // Remove retired CIDs from the ingress router
211        while let Some(retired_cid) = qconn.retired_scid_next() {
212            self.unmap_cid(retired_cid);
213        }
214    }
215
    /// Runs the connection's I/O loop for the current stage.
    ///
    /// Each iteration alternates between draining received packets into quiche
    /// and gathering/flushing outgoing packets, then waits for the next event:
    /// a timer deadline, an incoming packet, or the application/handshake
    /// future.
    ///
    /// Returns `Ok(())` when the connection is found closed after a gather,
    /// the value carried by a `ControlFlow::Break` from the stage's `on_flush`
    /// or `post_wait` hooks, or the first error propagated with `?`.
    async fn work_loop<A: ApplicationOverQuic>(
        &mut self, qconn: &mut QuicheConnection,
        ctx: &mut ConnectionStageContext<A>,
    ) -> QuicResult<()> {
        // Timer duration used whenever no concrete deadline is pending.
        const DEFAULT_SLEEP: Duration = Duration::from_secs(60);
        let mut current_deadline: Option<Instant> = None;
        let sleep = time::sleep(DEFAULT_SLEEP);
        tokio::pin!(sleep);

        loop {
            // Sampled once per outer iteration; also used below for the
            // deadline fallback and in the timeout arm.
            let now = Instant::now();

            self.write_state.has_pending_data = true;

            while self.write_state.has_pending_data {
                let mut packets_sent = 0;

                // Try to clear all received packets every so often, because
                // incoming packets contain acks, and because the
                // receive queue has a very limited size, once it is full incoming
                // packets get stalled indefinitely
                let mut did_recv = false;
                while let Some(pkt) = ctx
                    .in_pkt
                    .take()
                    .or_else(|| ctx.incoming_pkt_receiver.try_recv().ok())
                {
                    self.process_incoming(qconn, pkt)?;
                    did_recv = true;
                }

                self.conn_stage.on_read(did_recv, qconn, ctx)?;
                self.refresh_connection_ids(qconn);

                // Sending may proceed when no release time is pending, or when
                // the pending one is less than RELEASE_TIMER_THRESHOLD away
                // (a release time in the past counts as zero).
                let can_release = match self.write_state.next_release_time {
                    None => true,
                    Some(next_release) =>
                        next_release
                            .checked_duration_since(now)
                            .unwrap_or_default() <
                            RELEASE_TIMER_THRESHOLD,
                };

                self.write_state.has_pending_data &= can_release;

                // Send until there is nothing left to gather or enough packets
                // went out that the incoming queue should be drained again.
                while self.write_state.has_pending_data &&
                    packets_sent < CHECK_INCOMING_QUEUE_RATIO
                {
                    self.gather_data_from_quiche_conn(qconn, ctx.buffer())?;

                    // Break if the connection is closed
                    if qconn.is_closed() {
                        return Ok(());
                    }

                    let mut flush_operation_token =
                        TrackMidHandshakeFlush::new(self.metrics.clone());

                    self.flush_buffer_to_socket(ctx.buffer()).await;

                    flush_operation_token.mark_complete();

                    packets_sent += self.write_state.num_pkts;

                    if let ControlFlow::Break(reason) =
                        self.conn_stage.on_flush(qconn, ctx)
                    {
                        return reason;
                    }
                }
            }

            self.bw_estimator.update(qconn, now);

            self.audit_log_stats
                .set_max_bandwidth(self.bw_estimator.max_bandwidth);
            self.audit_log_stats.set_max_loss_pct(
                (self.bw_estimator.max_loss_pct * 100_f32).round() as u8,
            );

            // Combine quiche's timeout, the pending release time and the
            // stage's own wait deadline into one deadline.
            // NOTE(review): `min_of_some` is defined elsewhere; presumably it
            // yields the earliest of the instants that are present.
            let new_deadline = min_of_some(
                qconn.timeout_instant(),
                self.write_state.next_release_time,
            );
            let new_deadline =
                min_of_some(new_deadline, self.conn_stage.wait_deadline());

            // Only re-arm the timer when the deadline actually changed.
            if new_deadline != current_deadline {
                current_deadline = new_deadline;

                sleep
                    .as_mut()
                    .reset(new_deadline.unwrap_or(now + DEFAULT_SLEEP).into());
            }

            // Borrow the two context fields separately so the `select!` arms
            // below can use them independently of `ctx.in_pkt`.
            let incoming_recv = &mut ctx.incoming_pkt_receiver;
            let application = &mut ctx.application;

            select! {
                biased;
                () = &mut sleep => {
                    // It's very important that we keep the timeout arm at the top of this loop so
                    // that we poll it every time we need to. Since this is a biased `select!`, if
                    // we put this behind another arm, we could theoretically starve the sleep arm
                    // and hang connections.
                    //
                    // See https://docs.rs/tokio/latest/tokio/macro.select.html#fairness for more
                    qconn.on_timeout();

                    self.write_state.next_release_time = None;
                    current_deadline = None;
                    sleep.as_mut().reset((now + DEFAULT_SLEEP).into());
                }
                // Stash the packet; it is processed at the top of the next
                // iteration.
                Some(pkt) = incoming_recv.recv() => ctx.in_pkt = Some(pkt),
                directive = self.wait_for_data_or_handshake(qconn, application) => {
                    match directive? {
                        WaitForDataOrHandshakeDirective::Flush => {
                            self.flush_buffer_to_socket(application.buffer()).await;
                        }
                        WaitForDataOrHandshakeDirective::Noop => {}
                    }
                },
            };

            if let ControlFlow::Break(reason) = self.conn_stage.post_wait(qconn) {
                return reason;
            }
        }
    }
345
346    #[cfg(feature = "perf-quic-listener-metrics")]
347    fn measure_complete_handshake_time(&mut self) {
348        if let Some(init_rx_time) = self.init_rx_time.take() {
349            if let Ok(delta) = init_rx_time.elapsed() {
350                self.metrics
351                    .handshake_time_seconds(
352                        labels::QuicHandshakeStage::HandshakeResponse,
353                    )
354                    .observe(delta.as_nanos() as u64);
355            }
356        }
357    }
358
    /// Gathers one batch of outgoing packets from quiche into `send_buf` and
    /// records the batch in `self.write_state` for the following flush.
    ///
    /// Without GSO a single packet is gathered. With GSO, packets are appended
    /// until the batch is full, a packet's size breaks the batch, or (with
    /// gcongestion) the pacer's next release decision no longer fits the batch.
    ///
    /// The returned value is the outcome of the *last* write attempt: the size
    /// of the last packet written, `Ok(0)` when quiche had nothing more to
    /// send, or the write error. The total number of staged bytes is in
    /// `write_state.bytes_written`. `Ok(0)` is also returned when pacing defers
    /// sending; in that case `next_release_time` is set and `has_pending_data`
    /// is cleared.
    fn gather_data_from_quiche_conn(
        &mut self, qconn: &mut QuicheConnection, send_buf: &mut [u8],
    ) -> QuicResult<usize> {
        let mut segment_size = None;
        let mut send_info = None;

        // Start a fresh write cycle.
        self.write_state.num_pkts = 0;
        self.write_state.bytes_written = 0;

        self.write_state.selected_path = None;

        let now = Instant::now();

        // Never stage more than UDP_MAX_GSO_PACKET_SIZE bytes in one batch.
        let send_buf = {
            let trunc = UDP_MAX_GSO_PACKET_SIZE.min(send_buf.len());
            &mut send_buf[..trunc]
        };

        #[cfg(feature = "gcongestion")]
        let gcongestion_enabled = true;

        #[cfg(not(feature = "gcongestion"))]
        let gcongestion_enabled = qconn.gcongestion_enabled().unwrap_or(false);

        let initial_release_decision = if gcongestion_enabled {
            // The release decision only matters when pacing is enabled.
            let initial_release_decision = qconn
                .get_next_release_time()
                .filter(|_| self.pacing_enabled(qconn));

            if let Some(future_release_time) =
                initial_release_decision.as_ref().and_then(|v| v.time(now))
            {
                let max_into_fut = qconn.max_release_into_future();

                // The pacer wants to release too far in the future: write
                // nothing now and ask the work loop to come back after 80% of
                // the allowed look-ahead.
                if future_release_time.duration_since(now) >= max_into_fut {
                    self.write_state.next_release_time =
                        Some(now + max_into_fut.mul_f32(0.8));
                    self.write_state.has_pending_data = false;
                    return Ok(0);
                }
            }

            initial_release_decision
        } else {
            None
        };

        let buffer_write_outcome = loop {
            let outcome = self.write_packet_to_buffer(
                qconn,
                send_buf,
                &mut send_info,
                segment_size,
            );

            let packet_size = match outcome {
                Ok(0) => break Ok(0),

                Ok(bytes_written) => bytes_written,

                Err(e) => break Err(e),
            };

            // Flush to network after generating a single packet when GSO
            // is disabled.
            if !self.cfg.with_gso {
                break outcome;
            }

            #[cfg(not(feature = "gcongestion"))]
            let max_send_size = if !gcongestion_enabled {
                // Only call qconn.send_quantum when !gcongestion_enabled.
                tune_max_send_size(
                    segment_size,
                    qconn.send_quantum(),
                    send_buf.len(),
                )
            } else {
                usize::MAX
            };

            #[cfg(feature = "gcongestion")]
            let max_send_size = usize::MAX;

            // If segment_size is known, update the maximum of
            // GSO sender buffer size to the multiple of
            // segment_size.
            let buffer_is_full = self.write_state.num_pkts ==
                UDP_MAX_SEGMENT_COUNT ||
                self.write_state.bytes_written >= max_send_size;

            if buffer_is_full {
                break outcome;
            }

            // Flush to network when the newly generated packet size is
            // different from previously written packet, as GSO needs packets
            // to have the same size, except for the last one in the buffer.
            // The last packet may be smaller than the previous size.
            match segment_size {
                Some(size)
                    if packet_size != size || packet_size < GSO_THRESHOLD =>
                    break outcome,
                // The first packet of the batch defines the segment size.
                None => segment_size = Some(packet_size),
                _ => (),
            }

            if gcongestion_enabled {
                // If the release time of next packet is different, or it can't be
                // part of a burst, start the next batch
                if let Some(initial_release_decision) = initial_release_decision {
                    match qconn.get_next_release_time() {
                        Some(release)
                            if release.can_burst() ||
                                release.time_eq(
                                    &initial_release_decision,
                                    now,
                                ) => {},
                        _ => break outcome,
                    }
                }
            }
        };

        // With gcongestion the transmit time comes from the release decision
        // taken before the batch was written; otherwise from the send info
        // recorded for the batch. Both only apply when pacing is enabled.
        let tx_time = if gcongestion_enabled {
            initial_release_decision
                .filter(|_| self.pacing_enabled(qconn))
                // Return the time from the release decision if release_decision.time > now, else None.
                .and_then(|v| v.time(now))
        } else {
            send_info
                .filter(|_| self.pacing_enabled(qconn))
                .map(|v| v.at)
        };

        self.write_state.conn_established = qconn.is_established();
        self.write_state.tx_time = tx_time;
        // When no segment size was recorded, the whole batch is one segment.
        self.write_state.segment_size =
            segment_size.unwrap_or(self.write_state.bytes_written);

        if !gcongestion_enabled {
            if let Some(time) = tx_time {
                const DEFAULT_MAX_INTO_FUTURE: Duration =
                    Duration::from_millis(1);
                // The batch is scheduled more than 1ms ahead: stop gathering
                // and set a release time. `bytes_written` is left untouched
                // here, so whatever was staged is still visible to the flush
                // that follows.
                if time
                    .checked_duration_since(now)
                    .map(|d| d > DEFAULT_MAX_INTO_FUTURE)
                    .unwrap_or(false)
                {
                    self.write_state.next_release_time =
                        Some(now + DEFAULT_MAX_INTO_FUTURE.mul_f32(0.8));
                    self.write_state.has_pending_data = false;
                    return Ok(0);
                }
            }
        }

        buffer_write_outcome
    }
518
519    /// Selects a network path, if none already selected.
520    ///
521    /// This will return the first path available in the write state's
522    /// `pending_paths` iterator. If that is empty a new iterator will be
523    /// created by querying quiche itself.
524    ///
525    /// Note that the connection's statically configured local address will be
526    /// used to query quiche for available paths, so this can't handle multiple
527    /// local addresses currently.
528    fn select_path(
529        &mut self, qconn: &QuicheConnection,
530    ) -> Option<(SocketAddr, SocketAddr)> {
531        if self.write_state.selected_path.is_some() {
532            return self.write_state.selected_path;
533        }
534
535        let from = self.cfg.local_addr;
536
537        // Initialize paths iterator.
538        if self.write_state.pending_paths.len() == 0 {
539            self.write_state.pending_paths = qconn.paths_iter(from);
540        }
541
542        let to = self.write_state.pending_paths.next()?;
543
544        Some((from, to))
545    }
546
547    #[cfg(not(feature = "gcongestion"))]
548    fn pacing_enabled(&self, qconn: &QuicheConnection) -> bool {
549        self.cfg.pacing_offload && qconn.pacing_enabled()
550    }
551
552    #[cfg(feature = "gcongestion")]
553    fn pacing_enabled(&self, _qconn: &QuicheConnection) -> bool {
554        self.cfg.pacing_offload
555    }
556
557    fn write_packet_to_buffer(
558        &mut self, qconn: &mut QuicheConnection, send_buf: &mut [u8],
559        send_info: &mut Option<SendInfo>, segment_size: Option<usize>,
560    ) -> QuicResult<usize> {
561        let mut send_buf = &mut send_buf[self.write_state.bytes_written..];
562        if send_buf.len() > segment_size.unwrap_or(usize::MAX) {
563            // Never let the buffer be longer than segment size, for GSO to
564            // function properly.
565            send_buf = &mut send_buf[..segment_size.unwrap_or(usize::MAX)];
566        }
567
568        // On the first call to `select_path()` a path will be chosen based on
569        // the local address the connection initially landed on. Once a path is
570        // selected following calls to `select_path()` will return it, until it
571        // is reset at the start of the next write cycle.
572        //
573        // The path is then passed to `send_on_path()` which will only generate
574        // packets meant for that path, this way a single GSO buffer will only
575        // contain packets that belong to the same network path, which is
576        // required because the from/to addresses for each `sendmsg()` call
577        // apply to the whole GSO buffer.
578        let (from, to) = self.select_path(qconn).unzip();
579
580        match qconn.send_on_path(send_buf, from, to) {
581            Ok((packet_size, info)) => {
582                let _ = send_info.get_or_insert(info);
583
584                self.write_state.bytes_written += packet_size;
585                self.write_state.num_pkts += 1;
586
587                let from = send_info.as_ref().map(|info| info.from);
588                let to = send_info.as_ref().map(|info| info.to);
589
590                self.write_state.selected_path = from.zip(to);
591
592                self.write_state.has_pending_data = true;
593
594                Ok(packet_size)
595            },
596
597            Err(QuicheError::Done) => {
598                // Flush the current buffer to network. If no other path needs
599                // to be flushed to the network also yield the work loop task.
600                //
601                // Otherwise the write loop will start again and the next path
602                // will be selected.
603                let has_pending_paths = self.write_state.pending_paths.len() > 0;
604
605                // Keep writing if there are paths left to try.
606                self.write_state.has_pending_data = has_pending_paths;
607
608                Ok(0)
609            },
610
611            Err(e) => {
612                let error_code = if let Some(local_error) = qconn.local_error() {
613                    local_error.error_code
614                } else {
615                    let internal_error_code =
616                        quiche::WireErrorCode::InternalError as u64;
617                    let _ = qconn.close(false, internal_error_code, &[]);
618
619                    internal_error_code
620                };
621
622                self.audit_log_stats
623                    .set_sent_conn_close_transport_error_code(error_code as i64);
624
625                Err(Box::new(e))
626            },
627        }
628    }
629
630    async fn flush_buffer_to_socket(&mut self, send_buf: &[u8]) {
631        if self.write_state.bytes_written > 0 {
632            let current_send_buf = &send_buf[..self.write_state.bytes_written];
633
634            let (from, to) = self.write_state.selected_path.unzip();
635
636            let to = to.unwrap_or(self.cfg.peer_addr);
637            let from = from.filter(|_| self.cfg.with_pktinfo);
638
639            let send_res = if let (Some(udp_socket), true) =
640                (self.socket.as_udp_socket(), self.cfg.with_gso)
641            {
642                // Only UDP supports GSO.
643                send_to(
644                    udp_socket,
645                    to,
646                    from,
647                    current_send_buf,
648                    self.write_state.segment_size,
649                    self.write_state.tx_time,
650                    self.metrics
651                        .write_errors(labels::QuicWriteError::WouldBlock),
652                    self.metrics.send_to_wouldblock_duration_s(),
653                )
654                .await
655            } else {
656                self.socket.send_to(current_send_buf, to).await
657            };
658
659            #[cfg(feature = "perf-quic-listener-metrics")]
660            self.measure_complete_handshake_time();
661
662            match send_res {
663                Ok(n) =>
664                    if n < self.write_state.bytes_written {
665                        self.metrics
666                            .write_errors(labels::QuicWriteError::Partial)
667                            .inc();
668                    },
669
670                Err(_) => {
671                    self.metrics.write_errors(labels::QuicWriteError::Err).inc();
672                },
673            }
674        }
675    }
676
677    /// Process the incoming packet
678    fn process_incoming(
679        &mut self, qconn: &mut QuicheConnection, mut pkt: Incoming,
680    ) -> QuicResult<()> {
681        let recv_info = quiche::RecvInfo {
682            from: pkt.peer_addr,
683            to: pkt.local_addr,
684        };
685
686        if let Some(gro) = pkt.gro {
687            for dgram in pkt.buf.chunks_mut(gro as usize) {
688                qconn.recv(dgram, recv_info)?;
689            }
690        } else {
691            qconn.recv(&mut pkt.buf, recv_info)?;
692        }
693
694        Ok(())
695    }
696
697    // When a connection is established, process application data, if not the task
698    // is probably polled following a wakeup from boring, so we check if quiche
699    // has any handshake packets to send.
700    //
701    // TODO(erittenhouse): would be nice to decouple wait_for_data from the
702    // application, but wait_for_quiche relies on IOW methods, so we can't write a
703    // default implementation for ConnectionStage
704    async fn wait_for_data_or_handshake<A: ApplicationOverQuic>(
705        &mut self, qconn: &mut QuicheConnection, quic_application: &mut A,
706    ) -> QuicResult<WaitForDataOrHandshakeDirective> {
707        if quic_application.should_act() {
708            // Poll the application to make progress.
709            //
710            // Once the connection has been established (i.e. the handshake is
711            // complete), we only poll the application.
712            //
713            // The exception is 0-RTT in TLS 1.3, where the full handshake is
714            // still in progress but we have 0-RTT keys to process early data.
715            // This means TLS callbacks might only be polled on the next timeout
716            // or when a packet is received from the peer.
717            quic_application.wait_for_data(qconn).await?;
718            Ok(WaitForDataOrHandshakeDirective::Noop)
719        } else {
720            // Poll quiche to make progress on handshake callbacks.
721            self.wait_for_quiche(qconn, quic_application.buffer())
722                .await?;
723            Ok(WaitForDataOrHandshakeDirective::Flush)
724        }
725    }
726
727    /// Check if Quiche has any packets to send
728    ///
729    /// If yes: fills buffer and updates self.write_state.bytes_written
730    /// If no: Poll::Pending
731    ///
732    /// # Example
733    ///
734    /// This function can be used, for example, to drive an asynchronous TLS
735    /// handshake. Each call to `gather_data_from_quiche_conn` attempts to
736    /// progress the handshake via a call to `quiche::Connection.send()` -
737    /// once one of the `gather_data_from_quiche_conn()` calls writes to the
738    /// send buffer, we signal to the caller which has to take care of flushing
739    async fn wait_for_quiche(
740        &mut self, qconn: &mut QuicheConnection, buffer: &mut [u8],
741    ) -> QuicResult<()> {
742        std::future::poll_fn(|_| {
743            match self.gather_data_from_quiche_conn(qconn, buffer) {
744                Ok(bytes_written) => {
745                    // We need to avoid consecutive calls to gather(), which write
746                    // data to the buffer, without a flush().
747                    // If we don't avoid those consecutive calls, we end
748                    // up overwriting data in the buffer or unnecessarily waiting
749                    // for more calls to drive_handshake()
750                    // before calling the handshake complete.
751                    if bytes_written == 0 && self.write_state.bytes_written == 0 {
752                        Poll::Pending
753                    } else {
754                        Poll::Ready(Ok(()))
755                    }
756                },
757                _ => Poll::Ready(Err(quiche::Error::TlsFail)),
758            }
759        })
760        .await?;
761        Ok(())
762    }
763}
764
/// Whether the caller of `wait_for_data_or_handshake` is required to call
/// `flush_buffer_to_socket` afterwards.
#[must_use]
enum WaitForDataOrHandshakeDirective {
    /// The application was polled; nothing needs to be flushed.
    Noop,
    /// quiche wrote packets into the buffer; the caller must flush them.
    Flush,
}
772
/// State handed over when the handshake stage finished successfully: the
/// worker parameters, the stage context and the quiche connection.
pub struct Running<Tx, M, A> {
    /// Parameters recovered from the worker that ran the handshake.
    pub(crate) params: IoWorkerParams<Tx, M>,
    /// Connection stage context carried over to the next stage.
    pub(crate) context: ConnectionStageContext<A>,
    /// See [`QuicConnectionParams::quiche_conn`].
    pub(crate) qconn: Box<QuicheConnection>,
}
779
780impl<Tx, M, A> Running<Tx, M, A> {
781    pub fn ssl(&mut self) -> &mut SslRef {
782        // Deref to pick `Connection::as_mut` over `Box::as_mut`.
783        (*self.qconn).as_mut()
784    }
785}
786
/// State handed over when the connection has to be closed, e.g. because the
/// handshake work loop returned an error.
pub(crate) struct Closing<Tx, M, A> {
    /// Parameters recovered from the worker that ran the previous stage.
    pub(crate) params: IoWorkerParams<Tx, M>,
    /// Connection stage context carried over to the closing stage.
    pub(crate) context: ConnectionStageContext<A>,
    /// Result of the work loop that led to closing.
    pub(crate) work_loop_result: QuicResult<()>,
    /// See [`QuicConnectionParams::quiche_conn`].
    pub(crate) qconn: Box<QuicheConnection>,
}
794
/// Outcome of running the handshake stage: the connection either continues
/// running or is being closed.
pub enum RunningOrClosing<Tx, M, A> {
    /// The handshake succeeded and the connection can keep running.
    Running(Running<Tx, M, A>),
    /// The connection is to be closed.
    Closing(Closing<Tx, M, A>),
}
799
800impl<Tx, M> IoWorker<Tx, M, Handshake>
801where
802    Tx: DatagramSocketSend + Send,
803    M: Metrics,
804{
805    pub(crate) async fn run<A>(
806        mut self, mut qconn: Box<QuicheConnection>,
807        mut ctx: ConnectionStageContext<A>,
808    ) -> RunningOrClosing<Tx, M, A>
809    where
810        A: ApplicationOverQuic,
811    {
812        // This makes an assumption that the waker being set in ex_data is stable
813        // across the active task's lifetime. Moving a future that encompasses an
814        // async callback from this task across a channel, for example, will
815        // cause issues as this waker will then be stale and attempt to
816        // wake the wrong task.
817        std::future::poll_fn(|cx| {
818            // Deref to pick `Connection::as_mut` over `Box::as_mut`.
819            let ssl = (*qconn).as_mut();
820            ssl.set_task_waker(Some(cx.waker().clone()));
821
822            Poll::Ready(())
823        })
824        .await;
825
826        #[cfg(target_os = "linux")]
827        if let Some(incoming) = ctx.in_pkt.as_mut() {
828            self.audit_log_stats
829                .set_initial_so_mark_data(incoming.so_mark_data.take());
830        }
831
832        let mut work_loop_result = self.work_loop(&mut qconn, &mut ctx).await;
833        if work_loop_result.is_ok() && qconn.is_closed() {
834            work_loop_result = Err(HandshakeError::ConnectionClosed.into());
835        }
836
837        if let Err(err) = &work_loop_result {
838            self.metrics.failed_handshakes(err.into()).inc();
839
840            return RunningOrClosing::Closing(Closing {
841                params: self.into(),
842                context: ctx,
843                work_loop_result,
844                qconn,
845            });
846        };
847
848        match self.on_conn_established(&mut qconn, &mut ctx.application) {
849            Ok(()) => RunningOrClosing::Running(Running {
850                params: self.into(),
851                context: ctx,
852                qconn,
853            }),
854            Err(e) => {
855                foundations::telemetry::log::warn!(
856                    "Handshake stage on_connection_established failed"; "error"=>%e
857                );
858
859                RunningOrClosing::Closing(Closing {
860                    params: self.into(),
861                    context: ctx,
862                    work_loop_result,
863                    qconn,
864                })
865            },
866        }
867    }
868
869    fn on_conn_established<App: ApplicationOverQuic>(
870        &mut self, qconn: &mut QuicheConnection, driver: &mut App,
871    ) -> QuicResult<()> {
872        // Only calculate the QUIC handshake duration and call the driver's
873        // on_conn_established hook if this is the first time
874        // is_established == true.
875        if self.audit_log_stats.transport_handshake_duration_us() == -1 {
876            self.conn_stage.handshake_info.set_elapsed();
877            let handshake_info = &self.conn_stage.handshake_info;
878
879            self.audit_log_stats
880                .set_transport_handshake_duration(handshake_info.elapsed());
881
882            driver.on_conn_established(qconn, handshake_info)?;
883        }
884
885        if let Some(cid) = self.cfg.pending_cid.take() {
886            self.unmap_cid(cid);
887        }
888
889        Ok(())
890    }
891}
892
893impl<Tx, M, S> From<IoWorker<Tx, M, S>> for IoWorkerParams<Tx, M> {
894    fn from(value: IoWorker<Tx, M, S>) -> Self {
895        Self {
896            socket: value.socket,
897            shutdown_tx: value.shutdown_tx,
898            cfg: value.cfg,
899            audit_log_stats: value.audit_log_stats,
900            write_state: value.write_state,
901            conn_map_cmd_tx: value.conn_map_cmd_tx,
902            cid_generator: value.cid_generator,
903            #[cfg(feature = "perf-quic-listener-metrics")]
904            init_rx_time: value.init_rx_time,
905            metrics: value.metrics,
906        }
907    }
908}
909
910impl<Tx, M> IoWorker<Tx, M, RunningApplication>
911where
912    Tx: DatagramSocketSend + Send,
913    M: Metrics,
914{
915    pub(crate) async fn run<A: ApplicationOverQuic>(
916        mut self, mut qconn: Box<QuicheConnection>,
917        mut ctx: ConnectionStageContext<A>,
918    ) -> Closing<Tx, M, A> {
919        // Perform a single call to process_reads()/process_writes(),
920        // unconditionally, to ensure that any application data (e.g.
921        // STREAM frames or datagrams) processed by the Handshake
922        // stage are properly passed to the application.
923        if let Err(e) = self.conn_stage.on_read(true, &mut qconn, &mut ctx) {
924            return Closing {
925                params: self.into(),
926                context: ctx,
927                work_loop_result: Err(e),
928                qconn,
929            };
930        };
931
932        let work_loop_result = self.work_loop(&mut qconn, &mut ctx).await;
933
934        Closing {
935            params: self.into(),
936            context: ctx,
937            work_loop_result,
938            qconn,
939        }
940    }
941}
942
943impl<Tx, M> IoWorker<Tx, M, Close>
944where
945    Tx: DatagramSocketSend + Send,
946    M: Metrics,
947{
948    pub(crate) async fn close<A: ApplicationOverQuic>(
949        mut self, qconn: &mut QuicheConnection,
950        ctx: &mut ConnectionStageContext<A>,
951    ) {
952        if self.conn_stage.work_loop_result.is_ok() &&
953            self.bw_estimator.max_bandwidth > 0
954        {
955            let metrics = &self.metrics;
956
957            metrics
958                .max_bandwidth_mbps()
959                .observe(self.bw_estimator.max_bandwidth as f64 * 1e-6);
960
961            metrics
962                .max_loss_pct()
963                .observe(self.bw_estimator.max_loss_pct as f64 * 100.);
964        }
965
966        if ctx.application.should_act() {
967            ctx.application.on_conn_close(
968                qconn,
969                &self.metrics,
970                &self.conn_stage.work_loop_result,
971            );
972        }
973
974        // TODO: this assumes that the tidy_up operation can be completed in one
975        // send (ignoring flow/congestion control constraints). We should
976        // guarantee that it gets sent by doublechecking the
977        // gathered/flushed byte totals and retry if they don't match.
978        let _ = self.gather_data_from_quiche_conn(qconn, ctx.buffer());
979        self.flush_buffer_to_socket(ctx.buffer()).await;
980
981        *ctx.stats.lock().unwrap() = QuicConnectionStats::from_conn(qconn);
982
983        if let Some(err) = qconn.peer_error() {
984            if err.is_app {
985                self.audit_log_stats
986                    .set_recvd_conn_close_application_error_code(
987                        err.error_code as _,
988                    );
989            } else {
990                self.audit_log_stats
991                    .set_recvd_conn_close_transport_error_code(
992                        err.error_code as _,
993                    );
994            }
995        }
996
997        if let Some(err) = qconn.local_error() {
998            if err.is_app {
999                self.audit_log_stats
1000                    .set_sent_conn_close_application_error_code(
1001                        err.error_code as _,
1002                    );
1003            } else {
1004                self.audit_log_stats
1005                    .set_sent_conn_close_transport_error_code(
1006                        err.error_code as _,
1007                    );
1008            }
1009        }
1010
1011        self.close_connection(qconn);
1012
1013        if let Err(work_loop_error) = self.conn_stage.work_loop_result {
1014            self.audit_log_stats
1015                .set_connection_close_reason(work_loop_error);
1016        }
1017    }
1018
1019    fn close_connection(&mut self, qconn: &mut QuicheConnection) {
1020        if let Some(cid) = self.cfg.pending_cid.take() {
1021            self.unmap_cid(cid);
1022        }
1023        while let Some(retired_cid) = qconn.retired_scid_next() {
1024            self.unmap_cid(retired_cid);
1025        }
1026        for cid in qconn.source_ids().cloned() {
1027            self.unmap_cid(cid.into_owned());
1028        }
1029
1030        self.metrics.connections_in_memory().dec();
1031    }
1032}
1033
/// Picks the smaller of two optional values.
///
/// A `None` never wins over a `Some`: if only one side holds a value, that
/// value is returned. `None` is returned only if both sides are `None`.
fn min_of_some<T: Ord>(v1: Option<T>, v2: Option<T>) -> Option<T> {
    match v1 {
        None => v2,
        Some(a) => Some(match v2 {
            Some(b) => a.min(b),
            None => a,
        }),
    }
}
1042
/// A token that increments the `skipped_mid_handshake_flush_count` metric on
/// `Drop` unless it has been marked complete.
struct TrackMidHandshakeFlush<M: Metrics> {
    /// Whether the token was marked complete; if so, dropping it does not
    /// touch the metric.
    complete: bool,
    /// Metrics handle that is used to increment the counter on drop.
    metrics: M,
}
1049
1050impl<M: Metrics> TrackMidHandshakeFlush<M> {
1051    fn new(metrics: M) -> Self {
1052        Self {
1053            complete: false,
1054            metrics,
1055        }
1056    }
1057
1058    fn mark_complete(&mut self) {
1059        self.complete = true;
1060    }
1061}
1062
1063impl<M: Metrics> Drop for TrackMidHandshakeFlush<M> {
1064    fn drop(&mut self) {
1065        if !self.complete {
1066            self.metrics.skipped_mid_handshake_flush_count().inc();
1067        }
1068    }
1069}
1070
1071fn random_u128() -> u128 {
1072    let mut buf = [0; 16];
1073    boring::rand::rand_bytes(&mut buf).expect("boring's RAND_bytes never fails");
1074    u128::from_ne_bytes(buf)
1075}