tokio_quiche/quic/router/
mod.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27pub(crate) mod acceptor;
28pub(crate) mod connector;
29
30use super::connection::ConnectionMap;
31use super::connection::HandshakeInfo;
32use super::connection::Incoming;
33use super::connection::InitialQuicConnection;
34use super::connection::QuicConnectionParams;
35use super::io::worker::WriterConfig;
36use super::QuicheConnection;
37use crate::buf_factory::BufFactory;
38use crate::buf_factory::PooledBuf;
39use crate::metrics::labels;
40use crate::metrics::quic_expensive_metrics_ip_reduce;
41use crate::metrics::Metrics;
42use crate::settings::Config;
43
44use datagram_socket::DatagramSocketRecv;
45use datagram_socket::DatagramSocketSend;
46use foundations::telemetry::log;
47#[cfg(target_os = "linux")]
48use foundations::telemetry::metrics::Counter;
49#[cfg(target_os = "linux")]
50use foundations::telemetry::metrics::TimeHistogram;
51#[cfg(target_os = "linux")]
52use libc::sockaddr_in;
53#[cfg(target_os = "linux")]
54use libc::sockaddr_in6;
55use quiche::ConnectionId;
56use quiche::Header;
57use quiche::MAX_CONN_ID_LEN;
58use std::default::Default;
59use std::future::Future;
60use std::io;
61use std::net::SocketAddr;
62use std::pin::Pin;
63use std::sync::Arc;
64use std::task::ready;
65use std::task::Context;
66use std::task::Poll;
67use std::time::Instant;
68use std::time::SystemTime;
69use task_killswitch::spawn_with_killswitch;
70use tokio::sync::mpsc;
71
/// Receiving half of the router's accept queue. It yields each newly created
/// [`InitialQuicConnection`], or an error the router decided to surface.
type ConnStream<Tx, M> = mpsc::Receiver<io::Result<InitialQuicConnection<Tx, M>>>;
73
#[cfg(feature = "perf-quic-listener-metrics")]
mod listener_stage_timer {
    use foundations::telemetry::metrics::TimeHistogram;
    use std::time::Instant;

    /// Drop guard that records the time elapsed since `start` into
    /// `time_hist`. The observation is made when the guard goes out of scope.
    pub(super) struct ListenerStageTimer {
        start: Instant,
        time_hist: TimeHistogram,
    }

    impl ListenerStageTimer {
        /// Builds a guard measuring from `start`. Nothing is recorded until
        /// the guard is dropped.
        pub(super) fn new(start: Instant, time_hist: TimeHistogram) -> Self {
            Self { start, time_hist }
        }
    }

    impl Drop for ListenerStageTimer {
        fn drop(&mut self) {
            // The histogram is fed nanoseconds.
            let elapsed = Instant::now() - self.start;
            self.time_hist.observe(elapsed.as_nanos() as u64);
        }
    }
}
99
/// Outcome of receiving a single datagram into the router's current buffer.
#[derive(Debug)]
struct PollRecvData {
    // Number of bytes written into the receive buffer.
    bytes: usize,
    // The packet's source, i.e. the peer's address.
    src_addr: SocketAddr,
    // The packet's original destination, set only when it differs from the
    // router's configured local address. It is `None` when the destination
    // equals the local address, or when no original-destination control
    // message was received.
    dst_addr_override: Option<SocketAddr>,
    // Kernel receive timestamp, when the socket reports one.
    rx_time: Option<SystemTime>,
    // GRO segment size from the `UdpGroSegments` control message, if any.
    gro: Option<u16>,
}
111
/// A message to the router (listener) requesting that a connection-ID
/// mapping be removed from its connection map.
pub enum ConnectionMapCommand {
    /// Unmap a single connection ID from its connection.
    UnmapCid(ConnectionId<'static>),
    /// Remove the connection registered under this source connection ID.
    RemoveScid(ConnectionId<'static>),
}
118
/// An `InboundPacketRouter` maintains a map of quic connections and routes
/// [`Incoming`] packets from the [recv half][rh] of a datagram socket to those
/// connections or some quic initials handler.
///
/// [rh]: datagram_socket::DatagramSocketRecv
///
/// When a packet (or batch of packets) is received, the router will either
/// route those packets to an established
/// [`QuicConnection`](super::QuicConnection) or have them handled by an
/// `InitialPacketHandler`, which acts either as a quic listener or as a
/// quic connector (a server or client, respectively).
///
/// If you only have a single connection, or if you need more control over the
/// socket, use `QuicConnection` directly instead.
pub struct InboundPacketRouter<Tx, Rx, M, I>
where
    Tx: DatagramSocketSend + Send + 'static,
    M: Metrics,
{
    // Send half of the socket, shared (via `Arc`) with every spawned
    // connection.
    socket_tx: Arc<Tx>,
    // Receive half of the socket; every inbound datagram is read from here.
    socket_rx: Rx,
    // Configured local address. It is used as the packet's local address
    // unless a per-packet original destination overrides it.
    local_addr: SocketAddr,
    config: Config,
    // Maps connection IDs to established connections.
    conns: ConnectionMap,
    incoming_packet_handler: I,
    // Cloned into every new connection. Set to `None` once the accept stream
    // is closed; no new connections are created after that.
    shutdown_tx: Option<mpsc::Sender<()>>,
    // When this becomes ready, the router future completes.
    shutdown_rx: mpsc::Receiver<()>,
    // Cloned into every new connection so it can ask the router to update
    // `conns`; drained by the router while idle.
    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    conn_map_cmd_rx: mpsc::UnboundedReceiver<ConnectionMapCommand>,
    // Producer side of the accept stream: new connections and surfaced
    // errors.
    accept_sink: mpsc::Sender<io::Result<InitialQuicConnection<Tx, M>>>,
    metrics: M,
    // Last drop-counter value seen in an `RxqOvfl` control message.
    #[cfg(target_os = "linux")]
    udp_drop_count: u32,

    // Control-message buffer reused across `recvmsg` calls.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,

    // Buffer the next datagram is received into; swapped for a fresh one
    // after each received packet.
    current_buf: PooledBuf,

    // We keep the metrics in here, to avoid cloning them each packet
    #[cfg(target_os = "linux")]
    metrics_handshake_time_seconds: TimeHistogram,
    #[cfg(target_os = "linux")]
    metrics_udp_drop_count: Counter,
}
164
impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
where
    Tx: DatagramSocketSend + Send + 'static,
    Rx: DatagramSocketRecv,
    M: Metrics,
    I: InitialPacketHandler,
{
    /// Creates a router together with the stream on which it yields newly
    /// created [`InitialQuicConnection`]s (or errors).
    ///
    /// The accept stream is bounded by `config.listen_backlog`. When it is
    /// full, `spawn_new_connection` rejects new connections with
    /// `AcceptQueueOverflow`.
    pub(crate) fn new(
        config: Config, socket_tx: Arc<Tx>, socket_rx: Rx,
        local_addr: SocketAddr, incoming_packet_handler: I, metrics: M,
    ) -> (Self, ConnStream<Tx, M>) {
        // `shutdown_tx` clones are handed to every connection; the router
        // finishes once `shutdown_rx` becomes ready (presumably when all
        // clones have been dropped - no message is sent in this file).
        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
        let (accept_sink, accept_stream) = mpsc::channel(config.listen_backlog);
        let (conn_map_cmd_tx, conn_map_cmd_rx) = mpsc::unbounded_channel();

        (
            InboundPacketRouter {
                local_addr,
                socket_tx,
                socket_rx,
                conns: ConnectionMap::default(),
                incoming_packet_handler,
                shutdown_tx: Some(shutdown_tx),
                shutdown_rx,
                conn_map_cmd_tx,
                conn_map_cmd_rx,
                accept_sink,
                #[cfg(target_os = "linux")]
                udp_drop_count: 0,
                #[cfg(target_os = "linux")]
                // Specify CMSG space for GRO, timestamp, drop count, IP_RECVORIGDSTADDR, and
                // IPV6_RECVORIGDSTADDR. Even if they're not all currently used, the cmsg buffer
                // may have been configured by a previous version of Tokio-Quiche with the socket
                // re-used on graceful restart. As such, this vector should _only grow_, and care
                // should be taken when adding new cmsgs.
                reusable_cmsg_space: nix::cmsg_space!(u32, nix::sys::time::TimeSpec, u16, sockaddr_in, sockaddr_in6),
                config,

                current_buf: BufFactory::get_max_buf(),

                #[cfg(target_os = "linux")]
                metrics_handshake_time_seconds: metrics.handshake_time_seconds(labels::QuicHandshakeStage::QueueWaiting),
                #[cfg(target_os = "linux")]
                metrics_udp_drop_count: metrics.udp_drop_count(),

                metrics,

            },
            accept_stream,
        )
    }

    /// Routes one received datagram.
    ///
    /// If the packet's destination connection ID is registered in `conns`, the
    /// packet is forwarded to that connection. Otherwise, unless the router is
    /// shutting down, it is handed to the [`InitialPacketHandler`], and any
    /// resulting [`NewConnection`] is spawned.
    ///
    /// # Errors
    ///
    /// - Header parse failures (`BufferTooShort` / `InvalidPacket`) become
    ///   `QuicInvalidInitialPacketError::FailedToParse`; other quiche errors
    ///   are wrapped with `io::Error::other`.
    /// - Errors from the initial packet handler and from
    ///   `spawn_new_connection` are propagated.
    fn on_incoming(&mut self, mut incoming: Incoming) -> io::Result<()> {
        #[cfg(feature = "perf-quic-listener-metrics")]
        let start = std::time::Instant::now();

        // Fast path: look up a short-header packet's DCID without parsing
        // the full header.
        if let Some(dcid) = short_dcid(&incoming.buf) {
            if let Some(ev_sender) = self.conns.get(&dcid) {
                // Best effort: the `try_send` result is ignored, so the packet
                // is dropped if it cannot be enqueued.
                let _ = ev_sender.try_send(incoming);
                return Ok(());
            }
        }

        let hdr = Header::from_slice(&mut incoming.buf, MAX_CONN_ID_LEN)
            .map_err(|e| match e {
                quiche::Error::BufferTooShort | quiche::Error::InvalidPacket =>
                    labels::QuicInvalidInitialPacketError::FailedToParse.into(),
                e => io::Error::other(e),
            })?;

        // Slow path: look up the DCID taken from the fully parsed header.
        if let Some(ev_sender) = self.conns.get(&hdr.dcid) {
            let _ = ev_sender.try_send(incoming);
            return Ok(());
        }

        #[cfg(feature = "perf-quic-listener-metrics")]
        let _timer = listener_stage_timer::ListenerStageTimer::new(
            start,
            self.metrics.handshake_time_seconds(
                labels::QuicHandshakeStage::HandshakeProtocol,
            ),
        );

        // Once the accept stream has been closed, unknown packets are
        // silently ignored.
        if self.shutdown_tx.is_none() {
            return Ok(());
        }

        let local_addr = incoming.local_addr;
        let peer_addr = incoming.peer_addr;

        #[cfg(feature = "perf-quic-listener-metrics")]
        let init_rx_time = incoming.rx_time;

        let new_connection = self.incoming_packet_handler.handle_initials(
            incoming,
            hdr,
            self.config.as_mut(),
        )?;

        match new_connection {
            Some(new_connection) => self.spawn_new_connection(
                new_connection,
                local_addr,
                peer_addr,
                #[cfg(feature = "perf-quic-listener-metrics")]
                init_rx_time,
            ),
            None => Ok(()),
        }
    }

    /// Creates a new [`QuicConnection`](super::QuicConnection) and spawns an
    /// associated io worker.
    ///
    /// The connection is registered in `conns` under its source connection ID
    /// and, if present, under the client's pending connection ID. It is then
    /// delivered through the accept stream.
    ///
    /// # Errors
    ///
    /// Returns `QuicInvalidInitialPacketError::AcceptQueueOverflow` if the
    /// accept queue has no free slot. Returns `Ok(())` without doing anything
    /// if the router is shutting down.
    fn spawn_new_connection(
        &mut self, new_connection: NewConnection, local_addr: SocketAddr,
        peer_addr: SocketAddr,
        #[cfg(feature = "perf-quic-listener-metrics")] init_rx_time: Option<
            SystemTime,
        >,
    ) -> io::Result<()> {
        let NewConnection {
            conn,
            pending_cid,
            handshake_start_time,
            initial_pkt,
        } = new_connection;

        let Some(ref shutdown_tx) = self.shutdown_tx else {
            // don't create new connections if we're shutting down.
            return Ok(());
        };
        // The accept slot is reserved before any routing state is created, so
        // a full backlog rejects the connection without touching `conns`.
        let Ok(send_permit) = self.accept_sink.try_reserve() else {
            // drop the connection if the backlog is full. the client will retry.
            return Err(
                labels::QuicInvalidInitialPacketError::AcceptQueueOverflow.into(),
            );
        };

        let scid = conn.source_id().into_owned();
        let writer_cfg = WriterConfig {
            peer_addr,
            local_addr,
            pending_cid: pending_cid.clone(),
            with_gso: self.config.has_gso,
            pacing_offload: self.config.pacing_offload,
            // Packet-info support is chosen by the address family of the
            // router's configured socket address.
            with_pktinfo: if self.local_addr.is_ipv4() {
                self.config.has_ippktinfo
            } else {
                self.config.has_ipv6pktinfo
            },
        };

        let handshake_info = HandshakeInfo::new(
            handshake_start_time,
            self.config.handshake_timeout,
        );

        let conn = InitialQuicConnection::new(QuicConnectionParams {
            writer_cfg,
            initial_pkt,
            shutdown_tx: shutdown_tx.clone(),
            conn_map_cmd_tx: self.conn_map_cmd_tx.clone(),
            scid: scid.clone(),
            metrics: self.metrics.clone(),
            #[cfg(feature = "perf-quic-listener-metrics")]
            init_rx_time,
            handshake_info,
            quiche_conn: conn,
            socket: Arc::clone(&self.socket_tx),
            local_addr,
            peer_addr,
        });

        conn.audit_log_stats
            .set_transport_handshake_start(instant_to_system(
                handshake_start_time,
            ));

        self.conns.insert(scid, &conn);

        // Add the client-generated "pending" connection ID to the map as well.
        //
        // This is only required when client address validation is disabled.
        // When validation is enabled, the client is already using the
        // server-generated connection ID by the time we get here.
        if let Some(pending_cid) = pending_cid {
            self.conns.map_cid(pending_cid, &conn);
        }

        self.metrics.accepted_initial_packet_count().inc();
        if self.config.enable_expensive_packet_count_metrics {
            if let Some(peer_ip) =
                quic_expensive_metrics_ip_reduce(conn.peer_addr().ip())
            {
                self.metrics
                    .expensive_accepted_initial_packet_count(peer_ip)
                    .inc();
            }
        }

        send_permit.send(Ok(conn));
        Ok(())
    }
}
369
370impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
371where
372    Tx: DatagramSocketSend + Send + Sync + 'static,
373    Rx: DatagramSocketRecv,
374    M: Metrics,
375    I: InitialPacketHandler,
376{
377    /// [`InboundPacketRouter::poll_recv_from`] should be used if the underlying
378    /// system or socket does not support rx_time nor GRO.
379    fn poll_recv_from(
380        &mut self, cx: &mut Context<'_>,
381    ) -> Poll<io::Result<PollRecvData>> {
382        let mut buf = tokio::io::ReadBuf::new(&mut self.current_buf);
383        let addr = ready!(self.socket_rx.poll_recv_from(cx, &mut buf))?;
384        Poll::Ready(Ok(PollRecvData {
385            bytes: buf.filled().len(),
386            src_addr: addr,
387            rx_time: None,
388            gro: None,
389            dst_addr_override: None,
390        }))
391    }
392
393    fn poll_recv_and_rx_time(
394        &mut self, cx: &mut Context<'_>,
395    ) -> Poll<io::Result<PollRecvData>> {
396        #[cfg(not(target_os = "linux"))]
397        {
398            self.poll_recv_from(cx)
399        }
400
401        #[cfg(target_os = "linux")]
402        {
403            use nix::errno::Errno;
404            use nix::sys::socket::*;
405            use std::net::SocketAddrV4;
406            use std::net::SocketAddrV6;
407            use std::os::fd::AsRawFd;
408            use tokio::io::Interest;
409
410            let Some(udp_socket) = self.socket_rx.as_udp_socket() else {
411                // the given socket is not a UDP socket, fall back to the
412                // simple poll_recv_from.
413                return self.poll_recv_from(cx);
414            };
415
416            self.reusable_cmsg_space.clear();
417
418            loop {
419                let iov_s = &mut [io::IoSliceMut::new(&mut self.current_buf)];
420                match udp_socket.try_io(Interest::READABLE, || {
421                    recvmsg::<SockaddrStorage>(
422                        udp_socket.as_raw_fd(),
423                        iov_s,
424                        Some(&mut self.reusable_cmsg_space),
425                        MsgFlags::empty(),
426                    )
427                    .map_err(|x| x.into())
428                }) {
429                    Ok(r) => {
430                        let bytes = r.bytes;
431
432                        let address = match r.address {
433                            Some(inner) => inner,
434                            _ => return Poll::Ready(Err(Errno::EINVAL.into())),
435                        };
436
437                        let peer_addr = match address.family() {
438                            Some(AddressFamily::Inet) => SocketAddrV4::from(
439                                *address.as_sockaddr_in().unwrap(),
440                            )
441                            .into(),
442                            Some(AddressFamily::Inet6) => SocketAddrV6::from(
443                                *address.as_sockaddr_in6().unwrap(),
444                            )
445                            .into(),
446                            _ => {
447                                return Poll::Ready(Err(Errno::EINVAL.into()));
448                            },
449                        };
450
451                        let mut rx_time = None;
452                        let mut gro = None;
453                        let mut dst_addr_override = None;
454
455                        for cmsg in r.cmsgs() {
456                            match cmsg {
457                                ControlMessageOwned::RxqOvfl(c) => {
458                                    if c != self.udp_drop_count {
459                                        self.metrics_udp_drop_count.inc_by(
460                                            (c - self.udp_drop_count) as u64,
461                                        );
462                                        self.udp_drop_count = c;
463                                    }
464                                },
465                                ControlMessageOwned::ScmTimestampns(val) => {
466                                    rx_time = SystemTime::UNIX_EPOCH
467                                        .checked_add(val.into());
468                                    if let Some(delta) =
469                                        rx_time.and_then(|rx_time| {
470                                            rx_time.elapsed().ok()
471                                        })
472                                    {
473                                        self.metrics_handshake_time_seconds
474                                            .observe(delta.as_nanos() as u64);
475                                    }
476                                },
477                                ControlMessageOwned::UdpGroSegments(val) =>
478                                    gro = Some(val),
479                                ControlMessageOwned::Ipv4OrigDstAddr(val) => {
480                                    let source_addr = std::net::Ipv4Addr::from(
481                                        u32::to_be(val.sin_addr.s_addr),
482                                    );
483                                    let source_port = u16::to_be(val.sin_port);
484
485                                    let parsed_addr =
486                                        SocketAddr::V4(SocketAddrV4::new(
487                                            source_addr,
488                                            source_port,
489                                        ));
490
491                                    dst_addr_override = resolve_dst_addr(
492                                        &self.local_addr,
493                                        &parsed_addr,
494                                    );
495                                },
496                                ControlMessageOwned::Ipv6OrigDstAddr(val) => {
497                                    // Don't have to flip IPv6 bytes since it's a
498                                    // byte array, not a
499                                    // series of bytes parsed as a u32 as in the
500                                    // IPv4 case
501                                    let source_addr = std::net::Ipv6Addr::from(
502                                        val.sin6_addr.s6_addr,
503                                    );
504                                    let source_port = u16::to_be(val.sin6_port);
505                                    let source_flowinfo =
506                                        u32::to_be(val.sin6_flowinfo);
507                                    let source_scope =
508                                        u32::to_be(val.sin6_scope_id);
509
510                                    let parsed_addr =
511                                        SocketAddr::V6(SocketAddrV6::new(
512                                            source_addr,
513                                            source_port,
514                                            source_flowinfo,
515                                            source_scope,
516                                        ));
517
518                                    dst_addr_override = resolve_dst_addr(
519                                        &self.local_addr,
520                                        &parsed_addr,
521                                    );
522                                },
523                                ControlMessageOwned::Ipv4PacketInfo(_) |
524                                ControlMessageOwned::Ipv6PacketInfo(_) => {
525                                    // We only want the destination address from
526                                    // IP_RECVORIGDSTADDR, but we'll get these
527                                    // messages because
528                                    // we set IP_PKTINFO on the socket.
529                                },
530                                _ => {
531                                    return Poll::Ready(
532                                        Err(Errno::EINVAL.into()),
533                                    );
534                                },
535                            };
536                        }
537
538                        return Poll::Ready(Ok(PollRecvData {
539                            bytes,
540                            src_addr: peer_addr,
541                            dst_addr_override,
542                            rx_time,
543                            gro,
544                        }));
545                    },
546                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
547                        // NOTE: we manually poll the socket here to register
548                        // interest in the socket to become
549                        // writable for the given `cx`. Under the hood, tokio's
550                        // implementation just checks for
551                        // EWOULDBLOCK and if socket is busy registers provided
552                        // waker to be invoked when the
553                        // socket is free and consequently drive the event loop.
554                        ready!(udp_socket.poll_recv_ready(cx))?
555                    },
556                    Err(e) => return Poll::Ready(Err(e)),
557                }
558            }
559        }
560    }
561
562    fn handle_conn_map_commands(&mut self) {
563        while let Ok(req) = self.conn_map_cmd_rx.try_recv() {
564            match req {
565                ConnectionMapCommand::UnmapCid(cid) => self.conns.unmap_cid(&cid),
566                ConnectionMapCommand::RemoveScid(scid) =>
567                    self.conns.remove(&scid),
568            }
569        }
570    }
571}
572
573// Quickly extract the connection id of a short quic packet without allocating
574fn short_dcid(buf: &[u8]) -> Option<ConnectionId<'_>> {
575    let is_short_dcid = buf.first()? >> 7 == 0;
576
577    if is_short_dcid {
578        buf.get(1..1 + MAX_CONN_ID_LEN).map(ConnectionId::from_ref)
579    } else {
580        None
581    }
582}
583
/// Converts an [`Instant`] to a [`SystemTime`] by applying the offset between
/// `ts` and "now" on the monotonic clock to "now" on the wall clock.
///
/// Instants in the past map to earlier wall-clock times, and future instants
/// map to later ones.
fn instant_to_system(ts: Instant) -> SystemTime {
    let (mono_now, wall_now) = (Instant::now(), SystemTime::now());

    match mono_now.checked_duration_since(ts) {
        Some(age) => wall_now - age,
        None => {
            let lead = ts.checked_duration_since(mono_now).expect("now < ts");
            wall_now + lead
        },
    }
}
596
/// Decides whether a destination address parsed from a
/// [`ControlMessageOwned`](nix::sys::socket::ControlMessageOwned) should be
/// kept as a per-packet override.
///
/// A packet originally addressed to `local` needs no override. Storing one
/// anyway would cause us to address outgoing packets incorrectly.
///
/// Returns `Some(parsed)` when it differs from `local`, otherwise `None`.
#[cfg(target_os = "linux")]
fn resolve_dst_addr(
    local: &SocketAddr, parsed: &SocketAddr,
) -> Option<SocketAddr> {
    (local != parsed).then_some(*parsed)
}
616
impl<Tx, Rx, M, I> Future for InboundPacketRouter<Tx, Rx, M, I>
where
    Tx: DatagramSocketSend + Send + Sync + 'static,
    Rx: DatagramSocketRecv + Unpin,
    M: Metrics,
    I: InitialPacketHandler + Unpin,
{
    type Output = io::Result<()>;

    /// Drives the router.
    ///
    /// Each loop iteration does the following:
    /// 1. Lets the initial packet handler update itself; its errors are
    ///    forwarded to the accept stream from a separately spawned task.
    /// 2. Receives one datagram and routes it via `on_incoming`.
    ///
    /// Routing errors are counted as rejected initial packets, and only
    /// `Unexpected` ones are surfaced on the accept stream.
    ///
    /// When the socket has nothing more to read, the router:
    /// - stops accepting new connections if the accept stream was closed;
    /// - completes with `Ok(())` if `shutdown_rx` is ready;
    /// - otherwise applies pending connection-map commands and returns
    ///   `Pending`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let server_addr = self.local_addr;

        loop {
            if let Err(error) = self.incoming_packet_handler.update(cx) {
                // This is so rare that it's easier to spawn a separate task
                let sender = self.accept_sink.clone();
                spawn_with_killswitch(async move {
                    let _ = sender.send(Err(error)).await;
                });
            }

            match self.poll_recv_and_rx_time(cx) {
                Poll::Ready(Ok(PollRecvData {
                    bytes,
                    src_addr: peer_addr,
                    dst_addr_override,
                    rx_time,
                    gro,
                })) => {
                    // Hand the filled buffer to the packet and install a fresh
                    // one for the next receive.
                    let mut buf = std::mem::replace(
                        &mut self.current_buf,
                        BufFactory::get_max_buf(),
                    );
                    buf.truncate(bytes);

                    // Reply from the packet's original destination when it
                    // differs from the configured local address.
                    let send_from = if let Some(dst_addr) = dst_addr_override {
                        log::trace!("overriding local address"; "actual_local" => format!("{:?}", dst_addr), "configured_local" => format!("{:?}", server_addr));
                        dst_addr
                    } else {
                        server_addr
                    };

                    let res = self.on_incoming(Incoming {
                        peer_addr,
                        local_addr: send_from,
                        buf,
                        rx_time,
                        gro,
                    });

                    if let Err(e) = res {
                        let err_type = initial_packet_error_type(&e);
                        self.metrics
                            .rejected_initial_packet_count(err_type.clone())
                            .inc();

                        if self.config.enable_expensive_packet_count_metrics {
                            if let Some(peer_ip) =
                                quic_expensive_metrics_ip_reduce(peer_addr.ip())
                            {
                                self.metrics
                                    .expensive_rejected_initial_packet_count(
                                        err_type.clone(),
                                        peer_ip,
                                    )
                                    .inc();
                            }
                        }

                        // Only uncategorised errors are surfaced to the
                        // consumer; expected ones are just counted.
                        if matches!(
                            err_type,
                            labels::QuicInvalidInitialPacketError::Unexpected
                        ) {
                            // don't block packet routing on errors
                            let _ = self.accept_sink.try_send(Err(e));
                        }
                    }
                },

                // NOTE(review): the receive is retried immediately. A socket
                // error that persists would make this loop spin without
                // yielding - confirm this is acceptable.
                Poll::Ready(Err(e)) => {
                    log::error!("Incoming packet router encountered recvmsg error"; "error" => e);
                    continue;
                },

                Poll::Pending => {
                    // Check whether any connections are still active
                    // (dropping our `shutdown_tx` stops new connections from
                    // being created once nobody is accepting them).
                    if self.shutdown_tx.is_some() && self.accept_sink.is_closed()
                    {
                        self.shutdown_tx = None;
                    }

                    if self.shutdown_rx.poll_recv(cx).is_ready() {
                        return Poll::Ready(Ok(()));
                    }

                    // Process any incoming connection map signals and handle them.
                    // Note that this only happens once the socket has no more
                    // data ready.
                    self.handle_conn_map_commands();

                    return Poll::Pending;
                },
            }
        }
    }
}
721
722/// Categorizes errors that are returned when handling packets which are not
723/// associated with an established connection. The purpose is to suppress
724/// logging of 'expected' errors (e.g. junk data sent to the UDP socket) to
725/// prevent DoS.
726fn initial_packet_error_type(
727    e: &io::Error,
728) -> labels::QuicInvalidInitialPacketError {
729    Some(e)
730        .filter(|e| e.kind() == io::ErrorKind::Other)
731        .and_then(io::Error::get_ref)
732        .and_then(|e| e.downcast_ref())
733        .map_or(
734            labels::QuicInvalidInitialPacketError::Unexpected,
735            Clone::clone,
736        )
737}
738
/// An [`InitialPacketHandler`] handles unknown quic initials and processes
/// them: generally it either accepts new connections (acting as a server) or
/// establishes a connection to a server (acting as a client). An
/// [`InboundPacketRouter`] holds an instance of this trait and routes
/// [`Incoming`] packets to it when it receives initials.
///
/// The handler produces [`quiche::Connection`]s, each of which is turned into
/// a [`QuicConnection`](super::QuicConnection) / IoWorker pair.
pub trait InitialPacketHandler {
    /// Called at the top of every iteration of the router's poll loop,
    /// before a packet is received.
    ///
    /// An `Err` is forwarded to the router's accept stream, and the router
    /// keeps running. The default implementation does nothing.
    fn update(&mut self, _ctx: &mut Context<'_>) -> io::Result<()> {
        Ok(())
    }

    /// Processes a packet whose destination connection ID matched no known
    /// connection.
    ///
    /// - `hdr` is the already-parsed header of `incoming`.
    /// - `quiche_config` is obtained from the router's `Config`.
    ///
    /// Returning `Ok(Some(_))` makes the router spawn that connection.
    /// `Ok(None)` means no connection is created for this packet. An `Err` is
    /// counted as a rejected initial packet; it is surfaced on the accept
    /// stream only when it is not one of the categorised
    /// `QuicInvalidInitialPacketError`s.
    fn handle_initials(
        &mut self, incoming: Incoming, hdr: Header<'static>,
        quiche_config: &mut quiche::Config,
    ) -> io::Result<Option<NewConnection>>;
}
757
/// A [`NewConnection`] describes a new [`quiche::Connection`] that can be
/// driven by an io worker.
pub struct NewConnection {
    /// The quiche connection produced by the [`InitialPacketHandler`].
    conn: QuicheConnection,
    /// A client-generated ("pending") connection ID that should also be
    /// routed to this connection, if any. It is needed when client address
    /// validation is disabled.
    pending_cid: Option<ConnectionId<'static>>,
    /// An initial packet passed through to the new connection, if any.
    initial_pkt: Option<Incoming>,
    /// When the handshake started. This timestamp should be captured before
    /// calling [`quiche::accept`] or [`quiche::connect`].
    handshake_start_time: Instant,
}
768
// TODO: the router module is private so we can't move these to /tests
// TODO: Rewrite tests to be Windows compatible
#[cfg(all(test, unix))]
mod tests {
    use super::acceptor::ConnectionAcceptor;
    use super::acceptor::ConnectionAcceptorConfig;
    use super::*;

    use crate::http3::settings::Http3Settings;
    use crate::metrics::DefaultMetrics;
    use crate::quic::connection::SimpleConnectionIdGenerator;
    use crate::settings::Config;
    use crate::settings::Hooks;
    use crate::settings::QuicSettings;
    use crate::settings::TlsCertificatePaths;
    use crate::socket::SocketCapabilities;
    use crate::ConnectionParams;
    use crate::ServerH3Driver;

    use datagram_socket::MAX_DATAGRAM_SIZE;
    use h3i::actions::h3::Action;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::net::UdpSocket;
    use tokio::time;

    /// Test certificate taken from the sibling `quiche` crate's examples.
    const TEST_CERT_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.crt"
    );
    /// Private key matching [`TEST_CERT_FILE`].
    const TEST_KEY_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.key"
    );

    /// Connects to `host_port` with the synchronous h3i client (host
    /// `test.com`, peer verification disabled) and then sends a single
    /// application-level connection close with `NoError`.
    ///
    /// The result of the connection attempt is ignored; the test only
    /// observes the server side.
    fn test_connect(host_port: String) {
        let h3i_config = h3i::config::Config::new()
            .with_host_port("test.com".to_string())
            .with_idle_timeout(2000)
            .with_connect_to(host_port)
            .verify_peer(false)
            .build()
            .unwrap();

        let conn_close = h3i::quiche::ConnectionError {
            is_app: true,
            error_code: h3i::quiche::WireErrorCode::NoError as _,
            reason: Vec::new(),
        };
        let actions = vec![Action::ConnectionClose { error: conn_close }];

        let _ = h3i::client::sync_client::connect(h3i_config, actions, None);
    }

    /// Smoke test: a server-side connection accepted by the router is
    /// eventually dropped after the client closes it and the (very short)
    /// idle timeout elapses.
    #[tokio::test]
    async fn test_timeout() {
        // Configure a short idle timeout to speed up connection reclamation as
        // quiche doesn't support time mocking
        let quic_settings = QuicSettings {
            max_idle_timeout: Some(Duration::from_millis(1)),
            max_recv_udp_payload_size: MAX_DATAGRAM_SIZE,
            max_send_udp_payload_size: MAX_DATAGRAM_SIZE,
            ..Default::default()
        };

        let tls_cert_settings = TlsCertificatePaths {
            cert: TEST_CERT_FILE,
            private_key: TEST_KEY_FILE,
            kind: crate::settings::CertificateKind::X509,
        };

        let params = ConnectionParams::new_server(
            quic_settings,
            tls_cert_settings,
            Hooks::default(),
        );
        let config = Config::new(&params, SocketCapabilities::default()).unwrap();

        // A single UDP socket, shared by the send and receive halves.
        let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let local_addr = socket.local_addr().unwrap();
        let host_port = local_addr.to_string();
        let socket_tx = Arc::new(socket);
        let socket_rx = Arc::clone(&socket_tx);

        let acceptor = ConnectionAcceptor::new(
            ConnectionAcceptorConfig {
                disable_client_ip_validation: config.disable_client_ip_validation,
                qlog_dir: config.qlog_dir.clone(),
                keylog_file: config
                    .keylog_file
                    .as_ref()
                    .and_then(|f| f.try_clone().ok()),
                #[cfg(target_os = "linux")]
                with_pktinfo: false,
            },
            Arc::clone(&socket_tx),
            0,
            Default::default(),
            Box::new(SimpleConnectionIdGenerator),
            DefaultMetrics,
        );

        let (socket_driver, mut incoming) = InboundPacketRouter::new(
            config,
            socket_tx,
            socket_rx,
            local_addr,
            acceptor,
            DefaultMetrics,
        );
        tokio::spawn(socket_driver);

        // Start a request and drop it after connection establishment. The
        // synchronous h3i client runs on its own OS thread.
        std::thread::spawn(move || test_connect(host_port));

        // Pause tokio's clock so the time jump below is under our control.
        time::pause();

        // Wait for a new connection and start driving it with an H3 driver.
        let (h3_driver, _) = ServerH3Driver::new(Http3Settings::default());
        let conn = incoming.recv().await.unwrap().unwrap();
        // Keep a clone of the connection's event sender so we can observe when
        // the connection goes away. NOTE(review): presumably a tokio mpsc
        // sender whose `closed()` resolves once the receiving half is dropped.
        let drop_check = conn.incoming_ev_sender.clone();
        let _conn = conn.start(h3_driver);

        // Jump the paused clock far past the 1ms idle timeout, then resume
        // real time before waiting for the connection to be dropped.
        time::advance(Duration::new(30, 0)).await;
        time::resume();

        // NOTE: this is a smoke test - in case of issues the `closed()` future
        // will never resolve, hanging the test.
        drop_check.closed().await;
    }
}