tokio_quiche/quic/io/
worker.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27use std::net::SocketAddr;
28use std::ops::ControlFlow;
29use std::sync::Arc;
30use std::task::Poll;
31use std::time::Duration;
32use std::time::Instant;
33#[cfg(feature = "perf-quic-listener-metrics")]
34use std::time::SystemTime;
35
36use super::connection_stage::Close;
37use super::connection_stage::ConnectionStage;
38use super::connection_stage::ConnectionStageContext;
39use super::connection_stage::Handshake;
40use super::connection_stage::RunningApplication;
41use super::gso::*;
42use super::utilization_estimator::BandwidthReporter;
43
44use crate::metrics::labels;
45use crate::metrics::Metrics;
46use crate::quic::connection::ApplicationOverQuic;
47use crate::quic::connection::HandshakeError;
48use crate::quic::connection::Incoming;
49use crate::quic::connection::QuicConnectionStats;
50use crate::quic::router::ConnectionMapCommand;
51use crate::quic::QuicheConnection;
52use crate::QuicResult;
53
54use boring::ssl::SslRef;
55use datagram_socket::DatagramSocketSend;
56use datagram_socket::DatagramSocketSendExt;
57use datagram_socket::MaybeConnectedSocket;
58use datagram_socket::QuicAuditStats;
59use foundations::telemetry::log;
60use quiche::ConnectionId;
61use quiche::Error as QuicheError;
62use quiche::SendInfo;
63use tokio::select;
64use tokio::sync::mpsc;
65use tokio::time;
66
/// Number of incoming packets to be buffered in the incoming channel.
pub(crate) const INCOMING_QUEUE_SIZE: usize = 2048;

/// While sending data, check for incoming packets again after this many
/// packets have been sent.
pub(crate) const CHECK_INCOMING_QUEUE_RATIO: usize = INCOMING_QUEUE_SIZE / 16;

/// A pending `next_release_time` that is less than this far in the future is
/// treated as releasable right away by the work loop.
const RELEASE_TIMER_THRESHOLD: Duration = Duration::from_micros(250);

/// Stop queuing GSO packets, if packet size is below this threshold.
const GSO_THRESHOLD: usize = 1_000;
78
/// Configuration for the write side of an `IoWorker`.
pub struct WriterConfig {
    /// Connection ID for which a `ConnectionMapCommand::UnmapCid` is sent
    /// once the handshake stage reports the connection as established.
    pub pending_cid: Option<ConnectionId<'static>>,
    /// Destination used when flushing a buffer for which no network path was
    /// selected.
    pub peer_addr: SocketAddr,
    /// Local address used to query quiche for the available network paths.
    pub local_addr: SocketAddr,
    /// When `false`, every packet is flushed on its own; when `true`, packets
    /// are batched and sent through the GSO-capable send path if the socket
    /// is a UDP socket.
    pub with_gso: bool,
    /// Gates pacing: when `false`, no `tx_time` is attached to outgoing
    /// buffers.
    pub pacing_offload: bool,
    /// When `false`, the selected source address is not passed to the socket
    /// on send.
    pub with_pktinfo: bool,
}
87
/// Mutable bookkeeping for the write cycle of an `IoWorker`.
///
/// The per-buffer fields (`bytes_written`, `num_pkts`, `selected_path`) are
/// reset at the start of every `gather_data_from_quiche_conn()` call.
#[derive(Default)]
pub(crate) struct WriteState {
    /// Value of `qconn.is_established()` recorded at the end of the most
    /// recent gather pass.
    conn_established: bool,
    /// Number of bytes staged in the send buffer and not yet flushed.
    bytes_written: usize,
    /// Segment size passed to the GSO send path; falls back to
    /// `bytes_written` when no segment size was established for the buffer.
    segment_size: usize,
    /// Number of packets staged in the send buffer.
    num_pkts: usize,
    /// Release time handed to the GSO send path together with the staged
    /// buffer, if any.
    tx_time: Option<Instant>,
    /// Whether the write loop should keep gathering packets.
    has_pending_data: bool,
    /// If pacer schedules packets too far into the future, we want to pause
    /// sending, until the future arrives
    next_release_time: Option<Instant>,
    /// The selected source and destination addresses for the current write
    /// cycle.
    selected_path: Option<(SocketAddr, SocketAddr)>,
    /// Iterator over the network paths that haven't been flushed yet.
    pending_paths: quiche::SocketAddrIter,
}
105
/// The stage-independent parts of an [`IoWorker`].
///
/// An [`IoWorker`] is built from these parameters plus a connection stage
/// (see [`IoWorker::new`]) and can be converted back into them through the
/// `From<IoWorker<..>>` impl, which lets the same state be carried from one
/// stage to the next.
pub(crate) struct IoWorkerParams<Tx, M> {
    /// Socket that outgoing packets are written to.
    pub(crate) socket: MaybeConnectedSocket<Tx>,
    /// Sender kept alive for the worker's lifetime; see the field of the same
    /// name on [`IoWorker`].
    pub(crate) shutdown_tx: mpsc::Sender<()>,
    /// Write-side configuration.
    pub(crate) cfg: WriterConfig,
    /// Shared audit statistics updated by the worker.
    pub(crate) audit_log_stats: Arc<QuicAuditStats>,
    /// Write-cycle bookkeeping.
    pub(crate) write_state: WriteState,
    /// Channel used to send [`ConnectionMapCommand`]s.
    pub(crate) conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    /// Start time for the handshake-response measurement; consumed by the
    /// first flush that records the metric.
    #[cfg(feature = "perf-quic-listener-metrics")]
    pub(crate) init_rx_time: Option<SystemTime>,
    /// Metrics sink.
    pub(crate) metrics: M,
}
117
/// Drives a single QUIC connection: feeds received packets into quiche,
/// gathers outgoing packets from it and writes them to the socket.
///
/// `S` is the connection stage the worker is currently in; the stage hooks
/// (`on_read`, `on_flush`, `post_wait`, `wait_deadline`) are invoked from the
/// work loop.
pub(crate) struct IoWorker<Tx, M, S> {
    /// Socket that outgoing packets are written to.
    socket: MaybeConnectedSocket<Tx>,
    /// A field that signals to the listener task that the connection has gone
    /// away (nothing is sent here, listener task just detects the sender
    /// has dropped)
    shutdown_tx: mpsc::Sender<()>,
    /// Write-side configuration.
    cfg: WriterConfig,
    /// Shared audit statistics updated by the worker.
    audit_log_stats: Arc<QuicAuditStats>,
    /// Write-cycle bookkeeping.
    write_state: WriteState,
    /// Channel used to send [`ConnectionMapCommand`]s.
    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    /// Start time for the handshake-response measurement; taken (and thus
    /// cleared) the first time the metric is recorded.
    #[cfg(feature = "perf-quic-listener-metrics")]
    init_rx_time: Option<SystemTime>,
    /// Metrics sink.
    metrics: M,
    /// The connection stage whose hooks are driven by the work loop.
    conn_stage: S,
    /// Bandwidth/loss reporter, updated once per work loop iteration.
    bw_estimator: BandwidthReporter,
}
134
135impl<Tx, M, S> IoWorker<Tx, M, S>
136where
137    Tx: DatagramSocketSend + Send,
138    M: Metrics,
139    S: ConnectionStage,
140{
141    pub(crate) fn new(params: IoWorkerParams<Tx, M>, conn_stage: S) -> Self {
142        let bw_estimator =
143            BandwidthReporter::new(params.metrics.utilized_bandwidth());
144
145        log::trace!("Creating IoWorker with stage: {conn_stage:?}");
146
147        Self {
148            socket: params.socket,
149            shutdown_tx: params.shutdown_tx,
150            cfg: params.cfg,
151            audit_log_stats: params.audit_log_stats,
152            write_state: params.write_state,
153            conn_map_cmd_tx: params.conn_map_cmd_tx,
154            #[cfg(feature = "perf-quic-listener-metrics")]
155            init_rx_time: params.init_rx_time,
156            metrics: params.metrics,
157            conn_stage,
158            bw_estimator,
159        }
160    }
161
162    async fn work_loop<A: ApplicationOverQuic>(
163        &mut self, qconn: &mut QuicheConnection,
164        ctx: &mut ConnectionStageContext<A>,
165    ) -> QuicResult<()> {
166        const DEFAULT_SLEEP: Duration = Duration::from_secs(60);
167        let mut current_deadline: Option<Instant> = None;
168        let sleep = time::sleep(DEFAULT_SLEEP);
169        tokio::pin!(sleep);
170
171        loop {
172            let now = Instant::now();
173
174            self.write_state.has_pending_data = true;
175
176            while self.write_state.has_pending_data {
177                let mut packets_sent = 0;
178
179                // Try to clear all received packets every so often, because
180                // incoming packets contain acks, and because the
181                // receive queue has a very limited size, once it is full incoming
182                // packets get stalled indefinitely
183                let mut did_recv = false;
184                while let Some(pkt) = ctx
185                    .in_pkt
186                    .take()
187                    .or_else(|| ctx.incoming_pkt_receiver.try_recv().ok())
188                {
189                    self.process_incoming(qconn, pkt)?;
190                    did_recv = true;
191                }
192
193                self.conn_stage.on_read(did_recv, qconn, ctx)?;
194
195                let can_release = match self.write_state.next_release_time {
196                    None => true,
197                    Some(next_release) =>
198                        next_release
199                            .checked_duration_since(now)
200                            .unwrap_or_default() <
201                            RELEASE_TIMER_THRESHOLD,
202                };
203
204                self.write_state.has_pending_data &= can_release;
205
206                while self.write_state.has_pending_data &&
207                    packets_sent < CHECK_INCOMING_QUEUE_RATIO
208                {
209                    self.gather_data_from_quiche_conn(qconn, ctx.buffer())?;
210
211                    // Break if the connection is closed
212                    if qconn.is_closed() {
213                        return Ok(());
214                    }
215
216                    self.flush_buffer_to_socket(ctx.buffer()).await;
217                    packets_sent += self.write_state.num_pkts;
218
219                    if let ControlFlow::Break(reason) =
220                        self.conn_stage.on_flush(qconn, ctx)
221                    {
222                        return reason;
223                    }
224                }
225            }
226
227            self.bw_estimator.update(qconn, now);
228
229            self.audit_log_stats
230                .set_max_bandwidth(self.bw_estimator.max_bandwidth);
231            self.audit_log_stats.set_max_loss_pct(
232                (self.bw_estimator.max_loss_pct * 100_f32).round() as u8,
233            );
234
235            let new_deadline = min_of_some(
236                qconn.timeout_instant(),
237                self.write_state.next_release_time,
238            );
239            let new_deadline =
240                min_of_some(new_deadline, self.conn_stage.wait_deadline());
241
242            if new_deadline != current_deadline {
243                current_deadline = new_deadline;
244
245                sleep
246                    .as_mut()
247                    .reset(new_deadline.unwrap_or(now + DEFAULT_SLEEP).into());
248            }
249
250            let incoming_recv = &mut ctx.incoming_pkt_receiver;
251            let application = &mut ctx.application;
252            select! {
253                biased;
254                () = &mut sleep => {
255                    // It's very important that we keep the timeout arm at the top of this loop so
256                    // that we poll it every time we need to. Since this is a biased `select!`, if
257                    // we put this behind another arm, we could theoretically starve the sleep arm
258                    // and hang connections.
259                    //
260                    // See https://docs.rs/tokio/latest/tokio/macro.select.html#fairness for more
261                    qconn.on_timeout();
262
263                    self.write_state.next_release_time = None;
264                    current_deadline = None;
265                    sleep.as_mut().reset((now + DEFAULT_SLEEP).into());
266                }
267                Some(pkt) = incoming_recv.recv() => ctx.in_pkt = Some(pkt),
268                // TODO(erittenhouse): would be nice to decouple wait_for_data from the
269                // application, but wait_for_quiche relies on IOW methods, so we can't write a
270                // default implementation for ConnectionStage
271                status = self.wait_for_data_or_handshake(qconn, application) => status?,
272            };
273
274            if let ControlFlow::Break(reason) = self.conn_stage.post_wait(qconn) {
275                return reason;
276            }
277        }
278    }
279
280    #[cfg(feature = "perf-quic-listener-metrics")]
281    fn measure_complete_handshake_time(&mut self) {
282        if let Some(init_rx_time) = self.init_rx_time.take() {
283            if let Ok(delta) = init_rx_time.elapsed() {
284                self.metrics
285                    .handshake_time_seconds(
286                        labels::QuicHandshakeStage::HandshakeResponse,
287                    )
288                    .observe(delta.as_nanos() as u64);
289            }
290        }
291    }
292
    /// Stages outgoing QUIC packets from `qconn` into `send_buf` and records
    /// the batch metadata (`bytes_written`, `num_pkts`, `segment_size`,
    /// `tx_time`, `selected_path`, ...) in `self.write_state`.
    ///
    /// Without GSO a single packet is staged per call. With GSO, packets
    /// keep being appended until the buffer is full, a packet of a
    /// different size (or smaller than `GSO_THRESHOLD`) is produced, the
    /// pacer asks for a new batch, or quiche has nothing more to send.
    ///
    /// If pacing schedules the data too far into the future,
    /// `next_release_time` is set and `has_pending_data` is cleared so that
    /// the work loop pauses sending.
    ///
    /// # Returns
    ///
    /// The outcome of the last packet write: the size of the last packet
    /// staged, or `0` if quiche had nothing to send or sending was deferred
    /// by pacing. Note that this is not the total number of bytes staged;
    /// that is available in `self.write_state.bytes_written`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `write_packet_to_buffer()`.
    fn gather_data_from_quiche_conn(
        &mut self, qconn: &mut QuicheConnection, send_buf: &mut [u8],
    ) -> QuicResult<usize> {
        let mut segment_size = None;
        let mut send_info = None;

        // Start a fresh batch.
        self.write_state.num_pkts = 0;
        self.write_state.bytes_written = 0;

        self.write_state.selected_path = None;

        let now = Instant::now();

        // Never stage more than a single GSO send can carry.
        let send_buf = {
            let trunc = UDP_MAX_GSO_PACKET_SIZE.min(send_buf.len());
            &mut send_buf[..trunc]
        };

        #[cfg(feature = "gcongestion")]
        let gcongestion_enabled = true;

        #[cfg(not(feature = "gcongestion"))]
        let gcongestion_enabled = qconn.gcongestion_enabled().unwrap_or(false);

        let initial_release_decision = if gcongestion_enabled {
            let initial_release_decision = qconn
                .get_next_release_time()
                .filter(|_| self.pacing_enabled(qconn));

            if let Some(future_release_time) =
                initial_release_decision.as_ref().and_then(|v| v.time(now))
            {
                let max_into_fut = qconn.max_release_into_future();

                // The pacer wants to release too far ahead: pause sending
                // and come back slightly before the allowed horizon.
                if future_release_time.duration_since(now) >= max_into_fut {
                    self.write_state.next_release_time =
                        Some(now + max_into_fut.mul_f32(0.8));
                    self.write_state.has_pending_data = false;
                    return Ok(0);
                }
            }

            initial_release_decision
        } else {
            None
        };

        let buffer_write_outcome = loop {
            let outcome = self.write_packet_to_buffer(
                qconn,
                send_buf,
                &mut send_info,
                segment_size,
            );

            let packet_size = match outcome {
                Ok(0) => break Ok(0),

                Ok(bytes_written) => bytes_written,

                Err(e) => break Err(e),
            };

            // Flush to network after generating a single packet when GSO
            // is disabled.
            if !self.cfg.with_gso {
                break outcome;
            }

            #[cfg(not(feature = "gcongestion"))]
            let max_send_size = if !gcongestion_enabled {
                // Only call qconn.send_quantum when !gcongestion_enabled.
                tune_max_send_size(
                    segment_size,
                    qconn.send_quantum(),
                    send_buf.len(),
                )
            } else {
                usize::MAX
            };

            #[cfg(feature = "gcongestion")]
            let max_send_size = usize::MAX;

            // If segment_size is known, update the maximum of
            // GSO sender buffer size to the multiple of
            // segment_size.
            let buffer_is_full = self.write_state.num_pkts ==
                UDP_MAX_SEGMENT_COUNT ||
                self.write_state.bytes_written >= max_send_size;

            if buffer_is_full {
                break outcome;
            }

            // Flush to network when the newly generated packet size is
            // different from previously written packet, as GSO needs packets
            // to have the same size, except for the last one in the buffer.
            // The last packet may be smaller than the previous size.
            match segment_size {
                Some(size)
                    if packet_size != size || packet_size < GSO_THRESHOLD =>
                    break outcome,
                // The first packet of the batch fixes the segment size.
                None => segment_size = Some(packet_size),
                _ => (),
            }

            if gcongestion_enabled {
                // If the release time of next packet is different, or it can't be
                // part of a burst, start the next batch
                if let Some(initial_release_decision) = initial_release_decision {
                    match qconn.get_next_release_time() {
                        Some(release)
                            if release.can_burst() ||
                                release.time_eq(
                                    &initial_release_decision,
                                    now,
                                ) => {},
                        _ => break outcome,
                    }
                }
            }
        };

        // The release time comes from the pacer's release decision when
        // gcongestion is in use, and from the first packet's `SendInfo`
        // otherwise. It is `None` whenever pacing is not enabled.
        let tx_time = if gcongestion_enabled {
            initial_release_decision
                .filter(|_| self.pacing_enabled(qconn))
                // Return the time from the release decision if release_decision.time > now, else None.
                .and_then(|v| v.time(now))
        } else {
            send_info
                .filter(|_| self.pacing_enabled(qconn))
                .map(|v| v.at)
        };

        self.write_state.conn_established = qconn.is_established();
        self.write_state.tx_time = tx_time;
        // With no established segment size, the whole staged buffer is a
        // single segment.
        self.write_state.segment_size =
            segment_size.unwrap_or(self.write_state.bytes_written);

        if !gcongestion_enabled {
            if let Some(time) = tx_time {
                const DEFAULT_MAX_INTO_FUTURE: Duration =
                    Duration::from_millis(1);
                // Same pause as in the gcongestion branch above, with a
                // fixed horizon. Anything already staged stays recorded in
                // `write_state`.
                if time
                    .checked_duration_since(now)
                    .map(|d| d > DEFAULT_MAX_INTO_FUTURE)
                    .unwrap_or(false)
                {
                    self.write_state.next_release_time =
                        Some(now + DEFAULT_MAX_INTO_FUTURE.mul_f32(0.8));
                    self.write_state.has_pending_data = false;
                    return Ok(0);
                }
            }
        }

        buffer_write_outcome
    }
452
453    /// Selects a network path, if none already selected.
454    ///
455    /// This will return the first path available in the write state's
456    /// `pending_paths` iterator. If that is empty a new iterator will be
457    /// created by querying quiche itself.
458    ///
459    /// Note that the connection's statically configured local address will be
460    /// used to query quiche for available paths, so this can't handle multiple
461    /// local addresses currently.
462    fn select_path(
463        &mut self, qconn: &QuicheConnection,
464    ) -> Option<(SocketAddr, SocketAddr)> {
465        if self.write_state.selected_path.is_some() {
466            return self.write_state.selected_path;
467        }
468
469        let from = self.cfg.local_addr;
470
471        // Initialize paths iterator.
472        if self.write_state.pending_paths.len() == 0 {
473            self.write_state.pending_paths = qconn.paths_iter(from);
474        }
475
476        let to = self.write_state.pending_paths.next()?;
477
478        Some((from, to))
479    }
480
481    #[cfg(not(feature = "gcongestion"))]
482    fn pacing_enabled(&self, qconn: &QuicheConnection) -> bool {
483        self.cfg.pacing_offload && qconn.pacing_enabled()
484    }
485
486    #[cfg(feature = "gcongestion")]
487    fn pacing_enabled(&self, _qconn: &QuicheConnection) -> bool {
488        self.cfg.pacing_offload
489    }
490
491    fn write_packet_to_buffer(
492        &mut self, qconn: &mut QuicheConnection, send_buf: &mut [u8],
493        send_info: &mut Option<SendInfo>, segment_size: Option<usize>,
494    ) -> QuicResult<usize> {
495        let mut send_buf = &mut send_buf[self.write_state.bytes_written..];
496        if send_buf.len() > segment_size.unwrap_or(usize::MAX) {
497            // Never let the buffer be longer than segment size, for GSO to
498            // function properly.
499            send_buf = &mut send_buf[..segment_size.unwrap_or(usize::MAX)];
500        }
501
502        // On the first call to `select_path()` a path will be chosen based on
503        // the local address the connection initially landed on. Once a path is
504        // selected following calls to `select_path()` will return it, until it
505        // is reset at the start of the next write cycle.
506        //
507        // The path is then passed to `send_on_path()` which will only generate
508        // packets meant for that path, this way a single GSO buffer will only
509        // contain packets that belong to the same network path, which is
510        // required because the from/to addresses for each `sendmsg()` call
511        // apply to the whole GSO buffer.
512        let (from, to) = self.select_path(qconn).unzip();
513
514        match qconn.send_on_path(send_buf, from, to) {
515            Ok((packet_size, info)) => {
516                let _ = send_info.get_or_insert(info);
517
518                self.write_state.bytes_written += packet_size;
519                self.write_state.num_pkts += 1;
520
521                let from = send_info.as_ref().map(|info| info.from);
522                let to = send_info.as_ref().map(|info| info.to);
523
524                self.write_state.selected_path = from.zip(to);
525
526                self.write_state.has_pending_data = true;
527
528                Ok(packet_size)
529            },
530
531            Err(QuicheError::Done) => {
532                // Flush the current buffer to network. If no other path needs
533                // to be flushed to the network also yield the work loop task.
534                //
535                // Otherwise the write loop will start again and the next path
536                // will be selected.
537                let has_pending_paths = self.write_state.pending_paths.len() > 0;
538
539                // Keep writing if there are paths left to try.
540                self.write_state.has_pending_data = has_pending_paths;
541
542                Ok(0)
543            },
544
545            Err(e) => {
546                let error_code = if let Some(local_error) = qconn.local_error() {
547                    local_error.error_code
548                } else {
549                    let internal_error_code =
550                        quiche::WireErrorCode::InternalError as u64;
551                    let _ = qconn.close(false, internal_error_code, &[]);
552
553                    internal_error_code
554                };
555
556                self.audit_log_stats
557                    .set_sent_conn_close_transport_error_code(error_code as i64);
558
559                Err(Box::new(e))
560            },
561        }
562    }
563
564    async fn flush_buffer_to_socket(&mut self, send_buf: &[u8]) {
565        if self.write_state.bytes_written > 0 {
566            let current_send_buf = &send_buf[..self.write_state.bytes_written];
567
568            let (from, to) = self.write_state.selected_path.unzip();
569
570            let to = to.unwrap_or(self.cfg.peer_addr);
571            let from = from.filter(|_| self.cfg.with_pktinfo);
572
573            let send_res = if let (Some(udp_socket), true) =
574                (self.socket.as_udp_socket(), self.cfg.with_gso)
575            {
576                // Only UDP supports GSO.
577                send_to(
578                    udp_socket,
579                    to,
580                    from,
581                    current_send_buf,
582                    self.write_state.segment_size,
583                    self.write_state.tx_time,
584                    self.metrics
585                        .write_errors(labels::QuicWriteError::WouldBlock),
586                )
587                .await
588            } else {
589                self.socket.send_to(current_send_buf, to).await
590            };
591
592            #[cfg(feature = "perf-quic-listener-metrics")]
593            self.measure_complete_handshake_time();
594
595            match send_res {
596                Ok(n) =>
597                    if n < self.write_state.bytes_written {
598                        self.metrics
599                            .write_errors(labels::QuicWriteError::Partial)
600                            .inc();
601                    },
602
603                Err(_) => {
604                    self.metrics.write_errors(labels::QuicWriteError::Err).inc();
605                },
606            }
607        }
608    }
609
610    /// Process the incoming packet
611    fn process_incoming(
612        &mut self, qconn: &mut QuicheConnection, mut pkt: Incoming,
613    ) -> QuicResult<()> {
614        let recv_info = quiche::RecvInfo {
615            from: pkt.peer_addr,
616            to: pkt.local_addr,
617        };
618
619        if let Some(gro) = pkt.gro {
620            for dgram in pkt.buf.chunks_mut(gro as usize) {
621                qconn.recv(dgram, recv_info)?;
622            }
623        } else {
624            qconn.recv(&mut pkt.buf, recv_info)?;
625        }
626
627        Ok(())
628    }
629
630    /// When a connection is established, process application data, if not the
631    /// task is probably polled following a wakeup from boring, so we check
632    /// if quiche has any handshake packets to send.
633    async fn wait_for_data_or_handshake<A: ApplicationOverQuic>(
634        &mut self, qconn: &mut QuicheConnection, quic_application: &mut A,
635    ) -> QuicResult<()> {
636        if quic_application.should_act() {
637            quic_application.wait_for_data(qconn).await
638        } else {
639            self.wait_for_quiche(qconn, quic_application).await
640        }
641    }
642
643    /// Check if Quiche has any packets to send and flush them to socket.
644    ///
645    /// # Example
646    ///
647    /// This function can be used, for example, to drive an asynchronous TLS
648    /// handshake. Each call to `gather_data_from_quiche_conn` attempts to
649    /// progress the handshake via a call to `quiche::Connection.send()` -
650    /// once one of the `gather_data_from_quiche_conn()` calls writes to the
651    /// send buffer, we flush it to the network socket.
652    async fn wait_for_quiche<App: ApplicationOverQuic>(
653        &mut self, qconn: &mut QuicheConnection, app: &mut App,
654    ) -> QuicResult<()> {
655        let populate_send_buf = std::future::poll_fn(|_| {
656            match self.gather_data_from_quiche_conn(qconn, app.buffer()) {
657                Ok(bytes_written) => {
658                    // We need to avoid consecutive calls to gather(), which write
659                    // data to the buffer, without a flush().
660                    // If we don't avoid those consecutive calls, we end
661                    // up overwriting data in the buffer or unnecessarily waiting
662                    // for more calls to drive_handshake()
663                    // before calling the handshake complete.
664                    if bytes_written == 0 && self.write_state.bytes_written == 0 {
665                        Poll::Pending
666                    } else {
667                        Poll::Ready(Ok(()))
668                    }
669                },
670                _ => Poll::Ready(Err(quiche::Error::TlsFail)),
671            }
672        })
673        .await;
674
675        if populate_send_buf.is_err() {
676            return Err(Box::new(quiche::Error::TlsFail));
677        }
678
679        self.flush_buffer_to_socket(app.buffer()).await;
680
681        Ok(())
682    }
683}
684
/// State of a connection whose handshake stage completed successfully.
///
/// Produced by the handshake stage's `run()` once the application's
/// `on_conn_established` hook has returned `Ok`.
pub struct Running<Tx, M, A> {
    /// Worker parameters carried over from the handshake stage.
    pub(crate) params: IoWorkerParams<Tx, M>,
    /// Stage context, including the application.
    pub(crate) context: ConnectionStageContext<A>,
    /// The underlying quiche connection.
    pub(crate) qconn: QuicheConnection,
}
690
impl<Tx, M, A> Running<Tx, M, A> {
    /// Returns a mutable reference to the TLS state ([`SslRef`]) of the
    /// underlying quiche connection.
    pub fn ssl(&mut self) -> &mut SslRef {
        self.qconn.as_mut()
    }
}
696
/// State of a connection that left its work loop and is about to be closed.
pub(crate) struct Closing<Tx, M, A> {
    /// Worker parameters carried over from the previous stage.
    pub(crate) params: IoWorkerParams<Tx, M>,
    /// Stage context, including the application.
    pub(crate) context: ConnectionStageContext<A>,
    /// Result the previous stage ended with.
    pub(crate) work_loop_result: QuicResult<()>,
    /// The underlying quiche connection.
    pub(crate) qconn: QuicheConnection,
}
703
/// Outcome of the handshake stage.
pub enum RunningOrClosing<Tx, M, A> {
    /// The handshake succeeded; the connection can move on to running the
    /// application.
    Running(Running<Tx, M, A>),
    /// The handshake stage failed or the connection was closed; the
    /// connection moves straight to closing.
    Closing(Closing<Tx, M, A>),
}
708
709impl<Tx, M> IoWorker<Tx, M, Handshake>
710where
711    Tx: DatagramSocketSend + Send,
712    M: Metrics,
713{
714    pub(crate) async fn run<A>(
715        mut self, mut qconn: QuicheConnection, mut ctx: ConnectionStageContext<A>,
716    ) -> RunningOrClosing<Tx, M, A>
717    where
718        A: ApplicationOverQuic,
719    {
720        // This makes an assumption that the waker being set in ex_data is stable
721        // accross the active task's lifetime. Moving a future that encompasses an
722        // async callback from this task accross a channel, for example, will
723        // cause issues as this waker will then be stale and attempt to
724        // wake the wrong task.
725        std::future::poll_fn(|cx| {
726            let ssl = qconn.as_mut();
727            ssl.set_task_waker(Some(cx.waker().clone()));
728
729            Poll::Ready(())
730        })
731        .await;
732
733        let mut work_loop_result = self.work_loop(&mut qconn, &mut ctx).await;
734        if work_loop_result.is_ok() && qconn.is_closed() {
735            work_loop_result = Err(HandshakeError::ConnectionClosed.into());
736        }
737
738        if let Err(err) = &work_loop_result {
739            self.metrics.failed_handshakes(err.into()).inc();
740
741            return RunningOrClosing::Closing(Closing {
742                params: self.into(),
743                context: ctx,
744                work_loop_result,
745                qconn,
746            });
747        };
748
749        match self.on_conn_established(&mut qconn, &mut ctx.application) {
750            Ok(()) => RunningOrClosing::Running(Running {
751                params: self.into(),
752                context: ctx,
753                qconn,
754            }),
755            Err(e) => {
756                foundations::telemetry::log::warn!(
757                    "Handshake stage on_connection_established failed"; "error"=>%e
758                );
759
760                RunningOrClosing::Closing(Closing {
761                    params: self.into(),
762                    context: ctx,
763                    work_loop_result,
764                    qconn,
765                })
766            },
767        }
768    }
769
770    fn on_conn_established<App: ApplicationOverQuic>(
771        &mut self, qconn: &mut QuicheConnection, driver: &mut App,
772    ) -> QuicResult<()> {
773        // Only calculate the QUIC handshake duration and call the driver's
774        // on_conn_established hook if this is the first time
775        // is_established == true.
776        if self.audit_log_stats.transport_handshake_duration_us() == -1 {
777            self.conn_stage.handshake_info.set_elapsed();
778            let handshake_info = &self.conn_stage.handshake_info;
779
780            self.audit_log_stats
781                .set_transport_handshake_duration(handshake_info.elapsed());
782
783            driver.on_conn_established(qconn, handshake_info)?;
784        }
785
786        if let Some(cid) = self.cfg.pending_cid.take() {
787            let _ = self
788                .conn_map_cmd_tx
789                .send(ConnectionMapCommand::UnmapCid(cid));
790        }
791
792        Ok(())
793    }
794}
795
796impl<Tx, M, S> From<IoWorker<Tx, M, S>> for IoWorkerParams<Tx, M> {
797    fn from(value: IoWorker<Tx, M, S>) -> Self {
798        Self {
799            socket: value.socket,
800            shutdown_tx: value.shutdown_tx,
801            cfg: value.cfg,
802            audit_log_stats: value.audit_log_stats,
803            write_state: value.write_state,
804            conn_map_cmd_tx: value.conn_map_cmd_tx,
805            #[cfg(feature = "perf-quic-listener-metrics")]
806            init_rx_time: value.init_rx_time,
807            metrics: value.metrics,
808        }
809    }
810}
811
812impl<Tx, M> IoWorker<Tx, M, RunningApplication>
813where
814    Tx: DatagramSocketSend + Send,
815    M: Metrics,
816{
817    pub(crate) async fn run<A: ApplicationOverQuic>(
818        mut self, mut qconn: QuicheConnection, mut ctx: ConnectionStageContext<A>,
819    ) -> Closing<Tx, M, A> {
820        // Perform a single call to process_reads()/process_writes(),
821        // unconditionally, to ensure that any application data (e.g.
822        // STREAM frames or datagrams) processed by the Handshake
823        // stage are properly passed to the application.
824        if let Err(e) = self.conn_stage.on_read(true, &mut qconn, &mut ctx) {
825            return Closing {
826                params: self.into(),
827                context: ctx,
828                work_loop_result: Err(e),
829                qconn,
830            };
831        };
832
833        let work_loop_result = self.work_loop(&mut qconn, &mut ctx).await;
834
835        Closing {
836            params: self.into(),
837            context: ctx,
838            work_loop_result,
839            qconn,
840        }
841    }
842}
843
844impl<Tx, M> IoWorker<Tx, M, Close>
845where
846    Tx: DatagramSocketSend + Send,
847    M: Metrics,
848{
849    pub(crate) async fn close<A: ApplicationOverQuic>(
850        mut self, qconn: &mut QuicheConnection,
851        ctx: &mut ConnectionStageContext<A>,
852    ) {
853        if self.conn_stage.work_loop_result.is_ok() &&
854            self.bw_estimator.max_bandwidth > 0
855        {
856            let metrics = &self.metrics;
857
858            metrics
859                .max_bandwidth_mbps()
860                .observe(self.bw_estimator.max_bandwidth as f64 * 1e-6);
861
862            metrics
863                .max_loss_pct()
864                .observe(self.bw_estimator.max_loss_pct as f64 * 100.);
865        }
866
867        if ctx.application.should_act() {
868            ctx.application.on_conn_close(
869                qconn,
870                &self.metrics,
871                &self.conn_stage.work_loop_result,
872            );
873        }
874
875        // TODO: this assumes that the tidy_up operation can be completed in one
876        // send (ignoring flow/congestion control constraints). We should
877        // guarantee that it gets sent by doublechecking the
878        // gathered/flushed byte totals and retry if they don't match.
879        let _ = self.gather_data_from_quiche_conn(qconn, ctx.buffer());
880        self.flush_buffer_to_socket(ctx.buffer()).await;
881
882        *ctx.stats.lock().unwrap() = QuicConnectionStats::from_conn(qconn);
883
884        if let Some(err) = qconn.peer_error() {
885            if err.is_app {
886                self.audit_log_stats
887                    .set_recvd_conn_close_application_error_code(
888                        err.error_code as _,
889                    );
890            } else {
891                self.audit_log_stats
892                    .set_recvd_conn_close_transport_error_code(
893                        err.error_code as _,
894                    );
895            }
896        }
897
898        self.close_connection(qconn);
899
900        if let Err(work_loop_error) = self.conn_stage.work_loop_result {
901            self.audit_log_stats
902                .set_connection_close_reason(work_loop_error);
903        }
904    }
905
906    fn close_connection(&mut self, qconn: &QuicheConnection) {
907        let scid = qconn.source_id().into_owned();
908
909        if let Some(cid) = self.cfg.pending_cid.take() {
910            let _ = self
911                .conn_map_cmd_tx
912                .send(ConnectionMapCommand::UnmapCid(cid));
913        }
914
915        let _ = self
916            .conn_map_cmd_tx
917            .send(ConnectionMapCommand::RemoveScid(scid));
918
919        self.metrics.connections_in_memory().dec();
920    }
921}
922
/// Returns the smaller of the two values when both are present, the one
/// that is present when only one is, and `None` when neither is.
fn min_of_some<T: Ord>(v1: Option<T>, v2: Option<T>) -> Option<T> {
    match v1 {
        // Nothing on the left: the right side (present or not) is the answer.
        None => v2,
        Some(lhs) => Some(match v2 {
            Some(rhs) => lhs.min(rhs),
            None => lhs,
        }),
    }
}