tokio_quiche/quic/router/
mod.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27pub(crate) mod acceptor;
28pub(crate) mod connector;
29
30use super::connection::ConnectionMap;
31use super::connection::HandshakeInfo;
32use super::connection::Incoming;
33use super::connection::InitialQuicConnection;
34use super::connection::QuicConnectionParams;
35use super::io::worker::WriterConfig;
36use super::QuicheConnection;
37use crate::buf_factory::BufFactory;
38use crate::buf_factory::PooledBuf;
39use crate::metrics::labels;
40use crate::metrics::quic_expensive_metrics_ip_reduce;
41use crate::metrics::Metrics;
42use crate::settings::Config;
43
44use datagram_socket::DatagramSocketRecv;
45use datagram_socket::DatagramSocketSend;
46use foundations::telemetry::log;
47#[cfg(target_os = "linux")]
48use foundations::telemetry::metrics::Counter;
49#[cfg(target_os = "linux")]
50use foundations::telemetry::metrics::TimeHistogram;
51#[cfg(target_os = "linux")]
52use libc::sockaddr_in;
53#[cfg(target_os = "linux")]
54use libc::sockaddr_in6;
55use quiche::ConnectionId;
56use quiche::Header;
57use quiche::MAX_CONN_ID_LEN;
58use std::default::Default;
59use std::future::Future;
60use std::io;
61use std::net::SocketAddr;
62use std::pin::Pin;
63use std::sync::Arc;
64use std::task::ready;
65use std::task::Context;
66use std::task::Poll;
67use std::time::Instant;
68use std::time::SystemTime;
69use task_killswitch::spawn_with_killswitch;
70use tokio::sync::mpsc;
71
/// Receiving end of the accept queue returned by
/// `InboundPacketRouter::new`: yields each newly created
/// [`InitialQuicConnection`], or an `io::Error` the router forwarded to the
/// accept queue.
type ConnStream<Tx, M> = mpsc::Receiver<io::Result<InitialQuicConnection<Tx, M>>>;
73
#[cfg(feature = "perf-quic-listener-metrics")]
mod listener_stage_timer {
    use foundations::telemetry::metrics::TimeHistogram;
    use std::time::Instant;

    /// Drop guard that reports how long a listener stage took.
    ///
    /// When the guard goes out of scope, the time elapsed since `start` is
    /// observed on `time_hist`, expressed in nanoseconds.
    pub(super) struct ListenerStageTimer {
        start: Instant,
        time_hist: TimeHistogram,
    }

    impl ListenerStageTimer {
        /// Creates a guard measuring from `start` and reporting to
        /// `time_hist` on drop.
        pub(super) fn new(start: Instant, time_hist: TimeHistogram) -> Self {
            Self { start, time_hist }
        }
    }

    impl Drop for ListenerStageTimer {
        fn drop(&mut self) {
            let elapsed = self.start.elapsed();
            self.time_hist.observe(elapsed.as_nanos() as u64);
        }
    }
}
99
// Result of a single successful receive on the router's socket.
#[derive(Debug)]
struct PollRecvData {
    // Number of bytes that were received into the router's current buffer.
    bytes: usize,
    // The packet's source, e.g., the peer's address
    src_addr: SocketAddr,
    // The packet's original destination. This is only `Some` when an
    // original-destination control message was received AND that address
    // differs from the local listening address. If the original destination
    // is the same as the local listening address (or is unknown), this will
    // be `None`.
    dst_addr_override: Option<SocketAddr>,
    // Receive timestamp taken from the `ScmTimestampns` control message, if
    // one was delivered with the packet.
    rx_time: Option<SystemTime>,
    // Value of the `UdpGroSegments` control message, if one was delivered
    // with the packet (presumably the GRO segment size -- confirm against
    // the consumer of `Incoming::gro`).
    gro: Option<u16>,
}
111
/// A message to the listener notifying it that a mapping for a connection
/// should be removed.
pub enum ConnectionMapCommand {
    /// Handled by the router by calling `unmap_cid` on its connection map
    /// with the given connection ID.
    UnmapCid(ConnectionId<'static>),
    /// Handled by the router by calling `remove` on its connection map with
    /// the given source connection ID.
    RemoveScid(ConnectionId<'static>),
}
118
/// An `InboundPacketRouter` maintains a map of quic connections and routes
/// [`Incoming`] packets from the [recv half][rh] of a datagram socket to those
/// connections or some quic initials handler.
///
/// [rh]: datagram_socket::DatagramSocketRecv
///
/// When a packet (or batch of packets) is received, the router will either
/// route those packets to an established
/// [`QuicConnection`](super::QuicConnection) or have them handled by an
/// `InitialPacketHandler` which either acts as a quic listener or
/// quic connector, a server or client respectively.
///
/// If you only have a single connection, or if you need more control over the
/// socket, use `QuicConnection` directly instead.
pub struct InboundPacketRouter<Tx, Rx, M, I>
where
    Tx: DatagramSocketSend + Send + 'static,
    M: Metrics,
{
    // Send half of the socket; a clone of this `Arc` is handed to every
    // connection the router creates.
    socket_tx: Arc<Tx>,
    // Recv half of the socket; polled by the router for inbound datagrams.
    socket_rx: Rx,
    // Address the router considers "local". Used as the local address of
    // incoming packets unless an original-destination override was received.
    local_addr: SocketAddr,
    config: Config,
    // Maps connection IDs to the per-connection packet senders.
    conns: ConnectionMap,
    // Handles packets that do not match any entry in `conns`.
    incoming_packet_handler: I,
    // `Some` while new connections may be created; a clone is given to each
    // new connection. Set to `None` once `accept_sink` is closed.
    shutdown_tx: Option<mpsc::Sender<()>>,
    // The router future completes once polling this receiver is ready.
    shutdown_rx: mpsc::Receiver<()>,
    // Sender cloned into every new connection so it can request changes to
    // `conns`.
    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    // Drained by the router whenever the socket has no more data to read.
    conn_map_cmd_rx: mpsc::UnboundedReceiver<ConnectionMapCommand>,
    // Bounded queue (capacity `config.listen_backlog`) through which new
    // connections and unexpected errors are delivered.
    accept_sink: mpsc::Sender<io::Result<InitialQuicConnection<Tx, M>>>,
    metrics: M,
    // Last drop counter value seen in an `RxqOvfl` control message.
    #[cfg(target_os = "linux")]
    udp_drop_count: u32,

    // Control message buffer reused for every `recvmsg` call.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,

    // Buffer the next datagram is received into. It is swapped for a fresh
    // buffer each time a packet is handed off.
    current_buf: PooledBuf,

    // We keep the metrics in here, to avoid cloning them each packet
    #[cfg(target_os = "linux")]
    metrics_handshake_time_seconds: TimeHistogram,
    #[cfg(target_os = "linux")]
    metrics_udp_drop_count: Counter,
}
164
165impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
166where
167    Tx: DatagramSocketSend + Send + 'static,
168    Rx: DatagramSocketRecv,
169    M: Metrics,
170    I: InitialPacketHandler,
171{
172    pub(crate) fn new(
173        config: Config, socket_tx: Arc<Tx>, socket_rx: Rx,
174        local_addr: SocketAddr, incoming_packet_handler: I, metrics: M,
175    ) -> (Self, ConnStream<Tx, M>) {
176        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
177        let (accept_sink, accept_stream) = mpsc::channel(config.listen_backlog);
178        let (conn_map_cmd_tx, conn_map_cmd_rx) = mpsc::unbounded_channel();
179
180        (
181            InboundPacketRouter {
182                local_addr,
183                socket_tx,
184                socket_rx,
185                conns: ConnectionMap::default(),
186                incoming_packet_handler,
187                shutdown_tx: Some(shutdown_tx),
188                shutdown_rx,
189                conn_map_cmd_tx,
190                conn_map_cmd_rx,
191                accept_sink,
192                #[cfg(target_os = "linux")]
193                udp_drop_count: 0,
194                #[cfg(target_os = "linux")]
195                // Specify CMSG space for GRO, timestamp, drop count, IP_RECVORIGDSTADDR, and
196                // IPV6_RECVORIGDSTADDR. Even if they're not all currently used, the cmsg buffer
197                // may have been configured by a previous version of Tokio-Quiche with the socket
198                // re-used on graceful restart. As such, this vector should _only grow_, and care
199                // should be taken when adding new cmsgs.
200                reusable_cmsg_space: nix::cmsg_space!(u32, nix::sys::time::TimeSpec, u16, sockaddr_in, sockaddr_in6),
201                config,
202
203                current_buf: BufFactory::get_max_buf(),
204
205                #[cfg(target_os = "linux")]
206                metrics_handshake_time_seconds: metrics.handshake_time_seconds(labels::QuicHandshakeStage::QueueWaiting),
207                #[cfg(target_os = "linux")]
208                metrics_udp_drop_count: metrics.udp_drop_count(),
209
210                metrics,
211
212            },
213            accept_stream,
214        )
215    }
216
217    fn on_incoming(&mut self, mut incoming: Incoming) -> io::Result<()> {
218        #[cfg(feature = "perf-quic-listener-metrics")]
219        let start = std::time::Instant::now();
220
221        if let Some(dcid) = short_dcid(&incoming.buf) {
222            if let Some(ev_sender) = self.conns.get(&dcid) {
223                let _ = ev_sender.try_send(incoming);
224                return Ok(());
225            }
226        }
227
228        let hdr = Header::from_slice(&mut incoming.buf, MAX_CONN_ID_LEN)
229            .map_err(|e| match e {
230                quiche::Error::BufferTooShort | quiche::Error::InvalidPacket =>
231                    labels::QuicInvalidInitialPacketError::FailedToParse.into(),
232                e => io::Error::other(e),
233            })?;
234
235        if let Some(ev_sender) = self.conns.get(&hdr.dcid) {
236            let _ = ev_sender.try_send(incoming);
237            return Ok(());
238        }
239
240        #[cfg(feature = "perf-quic-listener-metrics")]
241        let _timer = listener_stage_timer::ListenerStageTimer::new(
242            start,
243            self.metrics.handshake_time_seconds(
244                labels::QuicHandshakeStage::HandshakeProtocol,
245            ),
246        );
247
248        if self.shutdown_tx.is_none() {
249            return Ok(());
250        }
251
252        let local_addr = incoming.local_addr;
253        let peer_addr = incoming.peer_addr;
254
255        #[cfg(feature = "perf-quic-listener-metrics")]
256        let init_rx_time = incoming.rx_time;
257
258        let new_connection = self.incoming_packet_handler.handle_initials(
259            incoming,
260            hdr,
261            self.config.as_mut(),
262        )?;
263
264        match new_connection {
265            Some(new_connection) => self.spawn_new_connection(
266                new_connection,
267                local_addr,
268                peer_addr,
269                #[cfg(feature = "perf-quic-listener-metrics")]
270                init_rx_time,
271            ),
272            None => Ok(()),
273        }
274    }
275
276    /// Creates a new [`QuicConnection`](super::QuicConnection) and spawns an
277    /// associated io worker.
278    fn spawn_new_connection(
279        &mut self, new_connection: NewConnection, local_addr: SocketAddr,
280        peer_addr: SocketAddr,
281        #[cfg(feature = "perf-quic-listener-metrics")] init_rx_time: Option<
282            SystemTime,
283        >,
284    ) -> io::Result<()> {
285        let NewConnection {
286            conn,
287            pending_cid,
288            handshake_start_time,
289            initial_pkt,
290        } = new_connection;
291
292        let Some(ref shutdown_tx) = self.shutdown_tx else {
293            // don't create new connections if we're shutting down.
294            return Ok(());
295        };
296        let Ok(send_permit) = self.accept_sink.try_reserve() else {
297            // drop the connection if the backlog is full. the client will retry.
298            return Err(
299                labels::QuicInvalidInitialPacketError::AcceptQueueOverflow.into(),
300            );
301        };
302
303        let scid = conn.source_id().into_owned();
304        let writer_cfg = WriterConfig {
305            peer_addr,
306            pending_cid: pending_cid.clone(),
307            with_gso: self.config.has_gso,
308            pacing_offload: self.config.pacing_offload,
309            with_pktinfo: if self.local_addr.is_ipv4() {
310                self.config.has_ippktinfo
311            } else {
312                self.config.has_ipv6pktinfo
313            },
314        };
315
316        let handshake_info = HandshakeInfo::new(
317            handshake_start_time,
318            self.config.handshake_timeout,
319        );
320
321        let conn = InitialQuicConnection::new(QuicConnectionParams {
322            writer_cfg,
323            initial_pkt,
324            shutdown_tx: shutdown_tx.clone(),
325            conn_map_cmd_tx: self.conn_map_cmd_tx.clone(),
326            scid: scid.clone(),
327            metrics: self.metrics.clone(),
328            #[cfg(feature = "perf-quic-listener-metrics")]
329            init_rx_time,
330            handshake_info,
331            quiche_conn: conn,
332            socket: Arc::clone(&self.socket_tx),
333            local_addr,
334            peer_addr,
335        });
336
337        conn.audit_log_stats
338            .set_transport_handshake_start(instant_to_system(
339                handshake_start_time,
340            ));
341
342        self.conns.insert(scid, &conn);
343
344        // Add the client-generated "pending" connection ID to the map as well.
345        //
346        // This is only required when client address validation is disabled.
347        // When validation is enabled, the client is already using the
348        // server-generated connection ID by the time we get here.
349        if let Some(pending_cid) = pending_cid {
350            self.conns.map_cid(pending_cid, &conn);
351        }
352
353        self.metrics.accepted_initial_packet_count().inc();
354        if self.config.enable_expensive_packet_count_metrics {
355            if let Some(peer_ip) =
356                quic_expensive_metrics_ip_reduce(conn.peer_addr().ip())
357            {
358                self.metrics
359                    .expensive_accepted_initial_packet_count(peer_ip)
360                    .inc();
361            }
362        }
363
364        send_permit.send(Ok(conn));
365        Ok(())
366    }
367}
368
369impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
370where
371    Tx: DatagramSocketSend + Send + Sync + 'static,
372    Rx: DatagramSocketRecv,
373    M: Metrics,
374    I: InitialPacketHandler,
375{
376    /// [`InboundPacketRouter::poll_recv_from`] should be used if the underlying
377    /// system or socket does not support rx_time nor GRO.
378    fn poll_recv_from(
379        &mut self, cx: &mut Context<'_>,
380    ) -> Poll<io::Result<PollRecvData>> {
381        let mut buf = tokio::io::ReadBuf::new(&mut self.current_buf);
382        let addr = ready!(self.socket_rx.poll_recv_from(cx, &mut buf))?;
383        Poll::Ready(Ok(PollRecvData {
384            bytes: buf.filled().len(),
385            src_addr: addr,
386            rx_time: None,
387            gro: None,
388            dst_addr_override: None,
389        }))
390    }
391
392    fn poll_recv_and_rx_time(
393        &mut self, cx: &mut Context<'_>,
394    ) -> Poll<io::Result<PollRecvData>> {
395        #[cfg(not(target_os = "linux"))]
396        {
397            self.poll_recv_from(cx)
398        }
399
400        #[cfg(target_os = "linux")]
401        {
402            use nix::errno::Errno;
403            use nix::sys::socket::*;
404            use std::net::SocketAddrV4;
405            use std::net::SocketAddrV6;
406            use std::os::fd::AsRawFd;
407            use tokio::io::Interest;
408
409            let Some(udp_socket) = self.socket_rx.as_udp_socket() else {
410                // the given socket is not a UDP socket, fall back to the
411                // simple poll_recv_from.
412                return self.poll_recv_from(cx);
413            };
414
415            self.reusable_cmsg_space.clear();
416
417            loop {
418                let iov_s = &mut [io::IoSliceMut::new(&mut self.current_buf)];
419                match udp_socket.try_io(Interest::READABLE, || {
420                    recvmsg::<SockaddrStorage>(
421                        udp_socket.as_raw_fd(),
422                        iov_s,
423                        Some(&mut self.reusable_cmsg_space),
424                        MsgFlags::empty(),
425                    )
426                    .map_err(|x| x.into())
427                }) {
428                    Ok(r) => {
429                        let bytes = r.bytes;
430
431                        let address = match r.address {
432                            Some(inner) => inner,
433                            _ => return Poll::Ready(Err(Errno::EINVAL.into())),
434                        };
435
436                        let peer_addr = match address.family() {
437                            Some(AddressFamily::Inet) => SocketAddrV4::from(
438                                *address.as_sockaddr_in().unwrap(),
439                            )
440                            .into(),
441                            Some(AddressFamily::Inet6) => SocketAddrV6::from(
442                                *address.as_sockaddr_in6().unwrap(),
443                            )
444                            .into(),
445                            _ => {
446                                return Poll::Ready(Err(Errno::EINVAL.into()));
447                            },
448                        };
449
450                        let mut rx_time = None;
451                        let mut gro = None;
452                        let mut dst_addr_override = None;
453
454                        for cmsg in r.cmsgs() {
455                            match cmsg {
456                                ControlMessageOwned::RxqOvfl(c) => {
457                                    if c != self.udp_drop_count {
458                                        self.metrics_udp_drop_count.inc_by(
459                                            (c - self.udp_drop_count) as u64,
460                                        );
461                                        self.udp_drop_count = c;
462                                    }
463                                },
464                                ControlMessageOwned::ScmTimestampns(val) => {
465                                    rx_time = SystemTime::UNIX_EPOCH
466                                        .checked_add(val.into());
467                                    if let Some(delta) =
468                                        rx_time.and_then(|rx_time| {
469                                            rx_time.elapsed().ok()
470                                        })
471                                    {
472                                        self.metrics_handshake_time_seconds
473                                            .observe(delta.as_nanos() as u64);
474                                    }
475                                },
476                                ControlMessageOwned::UdpGroSegments(val) =>
477                                    gro = Some(val),
478                                ControlMessageOwned::Ipv4OrigDstAddr(val) => {
479                                    let source_addr = std::net::Ipv4Addr::from(
480                                        u32::to_be(val.sin_addr.s_addr),
481                                    );
482                                    let source_port = u16::to_be(val.sin_port);
483
484                                    let parsed_addr =
485                                        SocketAddr::V4(SocketAddrV4::new(
486                                            source_addr,
487                                            source_port,
488                                        ));
489
490                                    dst_addr_override = resolve_dst_addr(
491                                        &self.local_addr,
492                                        &parsed_addr,
493                                    );
494                                },
495                                ControlMessageOwned::Ipv6OrigDstAddr(val) => {
496                                    // Don't have to flip IPv6 bytes since it's a
497                                    // byte array, not a
498                                    // series of bytes parsed as a u32 as in the
499                                    // IPv4 case
500                                    let source_addr = std::net::Ipv6Addr::from(
501                                        val.sin6_addr.s6_addr,
502                                    );
503                                    let source_port = u16::to_be(val.sin6_port);
504                                    let source_flowinfo =
505                                        u32::to_be(val.sin6_flowinfo);
506                                    let source_scope =
507                                        u32::to_be(val.sin6_scope_id);
508
509                                    let parsed_addr =
510                                        SocketAddr::V6(SocketAddrV6::new(
511                                            source_addr,
512                                            source_port,
513                                            source_flowinfo,
514                                            source_scope,
515                                        ));
516
517                                    dst_addr_override = resolve_dst_addr(
518                                        &self.local_addr,
519                                        &parsed_addr,
520                                    );
521                                },
522                                ControlMessageOwned::Ipv4PacketInfo(_) |
523                                ControlMessageOwned::Ipv6PacketInfo(_) => {
524                                    // We only want the destination address from
525                                    // IP_RECVORIGDSTADDR, but we'll get these
526                                    // messages because
527                                    // we set IP_PKTINFO on the socket.
528                                },
529                                _ => {
530                                    return Poll::Ready(
531                                        Err(Errno::EINVAL.into()),
532                                    );
533                                },
534                            };
535                        }
536
537                        return Poll::Ready(Ok(PollRecvData {
538                            bytes,
539                            src_addr: peer_addr,
540                            dst_addr_override,
541                            rx_time,
542                            gro,
543                        }));
544                    },
545                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
546                        // NOTE: we manually poll the socket here to register
547                        // interest in the socket to become
548                        // writable for the given `cx`. Under the hood, tokio's
549                        // implementation just checks for
550                        // EWOULDBLOCK and if socket is busy registers provided
551                        // waker to be invoked when the
552                        // socket is free and consequently drive the event loop.
553                        ready!(udp_socket.poll_recv_ready(cx))?
554                    },
555                    Err(e) => return Poll::Ready(Err(e)),
556                }
557            }
558        }
559    }
560
561    fn handle_conn_map_commands(&mut self) {
562        while let Ok(req) = self.conn_map_cmd_rx.try_recv() {
563            match req {
564                ConnectionMapCommand::UnmapCid(cid) => self.conns.unmap_cid(&cid),
565                ConnectionMapCommand::RemoveScid(scid) =>
566                    self.conns.remove(&scid),
567            }
568        }
569    }
570}
571
572// Quickly extract the connection id of a short quic packet without allocating
573fn short_dcid(buf: &[u8]) -> Option<ConnectionId<'_>> {
574    let is_short_dcid = buf.first()? >> 7 == 0;
575
576    if is_short_dcid {
577        buf.get(1..1 + MAX_CONN_ID_LEN).map(ConnectionId::from_ref)
578    } else {
579        None
580    }
581}
582
/// Translates an [`Instant`] into the [`SystemTime`] it corresponds to, using
/// the offset between the two clocks as sampled right now.
///
/// Works for instants in the past as well as in the future.
fn instant_to_system(ts: Instant) -> SystemTime {
    let now = Instant::now();
    let system_now = SystemTime::now();

    match now.checked_duration_since(ts) {
        // `ts` is now or in the past: step back from the current wall clock.
        Some(age) => system_now - age,
        // `ts` lies in the future: step forward instead.
        None => {
            let ahead = ts.checked_duration_since(now).expect("now < ts");
            system_now + ahead
        },
    }
}
595
/// Decides whether the destination address parsed from a
/// [`ControlMessageOwned`](nix::sys::socket::ControlMessageOwned) should be
/// stored for a packet.
///
/// A packet that was originally addressed to `local` must not get an
/// override, as that would cause us to incorrectly address packets when
/// sending.
///
/// Returns `Some(*parsed)` when `parsed` differs from `local`, and `None`
/// when both addresses are equal.
#[cfg(target_os = "linux")]
fn resolve_dst_addr(
    local: &SocketAddr, parsed: &SocketAddr,
) -> Option<SocketAddr> {
    (local != parsed).then_some(*parsed)
}
615
616impl<Tx, Rx, M, I> Future for InboundPacketRouter<Tx, Rx, M, I>
617where
618    Tx: DatagramSocketSend + Send + Sync + 'static,
619    Rx: DatagramSocketRecv + Unpin,
620    M: Metrics,
621    I: InitialPacketHandler + Unpin,
622{
623    type Output = io::Result<()>;
624
625    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
626        let server_addr = self.local_addr;
627
628        loop {
629            if let Err(error) = self.incoming_packet_handler.update(cx) {
630                // This is so rare that it's easier to spawn a separate task
631                let sender = self.accept_sink.clone();
632                spawn_with_killswitch(async move {
633                    let _ = sender.send(Err(error)).await;
634                });
635            }
636
637            match self.poll_recv_and_rx_time(cx) {
638                Poll::Ready(Ok(PollRecvData {
639                    bytes,
640                    src_addr: peer_addr,
641                    dst_addr_override,
642                    rx_time,
643                    gro,
644                })) => {
645                    let mut buf = std::mem::replace(
646                        &mut self.current_buf,
647                        BufFactory::get_max_buf(),
648                    );
649                    buf.truncate(bytes);
650
651                    let send_from = if let Some(dst_addr) = dst_addr_override {
652                        log::trace!("overriding local address"; "actual_local" => format!("{:?}", dst_addr), "configured_local" => format!("{:?}", server_addr));
653                        dst_addr
654                    } else {
655                        server_addr
656                    };
657
658                    let res = self.on_incoming(Incoming {
659                        peer_addr,
660                        local_addr: send_from,
661                        buf,
662                        rx_time,
663                        gro,
664                    });
665
666                    if let Err(e) = res {
667                        let err_type = initial_packet_error_type(&e);
668                        self.metrics
669                            .rejected_initial_packet_count(err_type.clone())
670                            .inc();
671
672                        if self.config.enable_expensive_packet_count_metrics {
673                            if let Some(peer_ip) =
674                                quic_expensive_metrics_ip_reduce(peer_addr.ip())
675                            {
676                                self.metrics
677                                    .expensive_rejected_initial_packet_count(
678                                        err_type.clone(),
679                                        peer_ip,
680                                    )
681                                    .inc();
682                            }
683                        }
684
685                        if matches!(
686                            err_type,
687                            labels::QuicInvalidInitialPacketError::Unexpected
688                        ) {
689                            // don't block packet routing on errors
690                            let _ = self.accept_sink.try_send(Err(e));
691                        }
692                    }
693                },
694
695                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
696
697                Poll::Pending => {
698                    // Check whether any connections are still active
699                    if self.shutdown_tx.is_some() && self.accept_sink.is_closed()
700                    {
701                        self.shutdown_tx = None;
702                    }
703
704                    if self.shutdown_rx.poll_recv(cx).is_ready() {
705                        return Poll::Ready(Ok(()));
706                    }
707
708                    // Process any incoming connection map signals and handle them
709                    self.handle_conn_map_commands();
710
711                    return Poll::Pending;
712                },
713            }
714        }
715    }
716}
717
718/// Categorizes errors that are returned when handling packets which are not
719/// associated with an established connection. The purpose is to suppress
720/// logging of 'expected' errors (e.g. junk data sent to the UDP socket) to
721/// prevent DoS.
722fn initial_packet_error_type(
723    e: &io::Error,
724) -> labels::QuicInvalidInitialPacketError {
725    Some(e)
726        .filter(|e| e.kind() == io::ErrorKind::Other)
727        .and_then(io::Error::get_ref)
728        .and_then(|e| e.downcast_ref())
729        .map_or(
730            labels::QuicInvalidInitialPacketError::Unexpected,
731            Clone::clone,
732        )
733}
734
/// An [`InitialPacketHandler`] handles unknown quic initials and processes
/// them; generally accepting new connections (acting as a server), or
/// establishing a connection to a server (acting as a client). An
/// [`InboundPacketRouter`] holds an instance of this trait and routes
/// [`Incoming`] packets to it when it receives initials.
///
/// The handler produces [`quiche::Connection`]s which are then turned into
/// [`QuicConnection`](super::QuicConnection), IoWorker pair.
pub trait InitialPacketHandler {
    /// Hook invoked by the router at the start of every iteration of its
    /// poll loop, before the socket is polled. The default implementation
    /// does nothing.
    ///
    /// An error returned from here is forwarded by the router to its accept
    /// queue.
    fn update(&mut self, _ctx: &mut Context<'_>) -> io::Result<()> {
        Ok(())
    }

    /// Handles a packet whose destination connection ID does not belong to
    /// any connection known to the router.
    ///
    /// `hdr` is the already parsed header of `incoming`.
    ///
    /// Returns `Ok(Some(_))` when a new connection was created and should be
    /// spawned by the router, `Ok(None)` when there is nothing to spawn, and
    /// `Err` when the packet is rejected.
    fn handle_initials(
        &mut self, incoming: Incoming, hdr: Header<'static>,
        quiche_config: &mut quiche::Config,
    ) -> io::Result<Option<NewConnection>>;
}
753
/// A [`NewConnection`] describes a new [`quiche::Connection`] that can be
/// driven by an io worker.
pub struct NewConnection {
    /// The underlying quiche connection.
    conn: QuicheConnection,
    /// The client-generated "pending" connection ID, if any. When set, the
    /// router maps it to the new connection in addition to the connection's
    /// own source connection ID.
    pending_cid: Option<ConnectionId<'static>>,
    /// Optional packet handed to the new connection as its initial packet.
    initial_pkt: Option<Incoming>,
    /// When the handshake started. Should be captured before
    /// [`quiche::accept`] or [`quiche::connect`] is called.
    handshake_start_time: Instant,
}
764
// TODO: the router module is private so we can't move these to /tests
// TODO: Rewrite tests to be Windows compatible
#[cfg(all(test, unix))]
mod tests {
    use super::acceptor::ConnectionAcceptor;
    use super::acceptor::ConnectionAcceptorConfig;
    use super::*;

    use crate::http3::settings::Http3Settings;
    use crate::metrics::DefaultMetrics;
    use crate::quic::connection::SimpleConnectionIdGenerator;
    use crate::settings::Config;
    use crate::settings::Hooks;
    use crate::settings::QuicSettings;
    use crate::settings::TlsCertificatePaths;
    use crate::socket::SocketCapabilities;
    use crate::ConnectionParams;
    use crate::ServerH3Driver;

    use datagram_socket::MAX_DATAGRAM_SIZE;
    use h3i::actions::h3::Action;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::net::UdpSocket;
    use tokio::time;

    /// Certificate used by the test server, taken from the quiche examples
    /// directory next to this crate.
    const TEST_CERT_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.crt"
    );
    /// Private key matching [`TEST_CERT_FILE`].
    const TEST_KEY_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.key"
    );

    /// Connects an h3i sync client to `host_port` and has it send an
    /// application-level `CONNECTION_CLOSE` with `NoError` as its only
    /// action.
    ///
    /// Peer verification is disabled because the server uses the example
    /// certificate. The outcome of the connection attempt is deliberately
    /// ignored: the caller only needs the server side to observe a
    /// connection.
    fn test_connect(host_port: String) {
        let h3i_config = h3i::config::Config::new()
            .with_host_port("test.com".to_string())
            .with_idle_timeout(2000)
            .with_connect_to(host_port)
            .verify_peer(false)
            .build()
            .unwrap();

        let conn_close = h3i::quiche::ConnectionError {
            is_app: true,
            error_code: h3i::quiche::WireErrorCode::NoError as _,
            reason: Vec::new(),
        };
        let actions = [Action::ConnectionClose { error: conn_close }];

        // Best effort: the client's result is irrelevant to the test.
        let _ = h3i::client::sync_client::connect(h3i_config, &actions, None);
    }

    /// Smoke test: a server-side connection must be torn down once the
    /// client has gone away and the idle timeout has elapsed.
    #[tokio::test]
    async fn test_timeout() {
        // Configure a short idle timeout to speed up connection reclamation as
        // quiche doesn't support time mocking
        let quic_settings = QuicSettings {
            max_idle_timeout: Some(Duration::from_millis(1)),
            max_recv_udp_payload_size: MAX_DATAGRAM_SIZE,
            max_send_udp_payload_size: MAX_DATAGRAM_SIZE,
            ..Default::default()
        };

        let tls_cert_settings = TlsCertificatePaths {
            cert: &TEST_CERT_FILE,
            private_key: &TEST_KEY_FILE,
            kind: crate::settings::CertificateKind::X509,
        };

        let params = ConnectionParams::new_server(
            quic_settings,
            tls_cert_settings,
            Hooks::default(),
        );
        let config = Config::new(&params, SocketCapabilities::default()).unwrap();

        // Bind to an ephemeral loopback port; the same socket is shared for
        // sending and receiving.
        let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let local_addr = socket.local_addr().unwrap();
        let host_port = local_addr.to_string();
        let socket_tx = Arc::new(socket);
        let socket_rx = Arc::clone(&socket_tx);

        let acceptor = ConnectionAcceptor::new(
            ConnectionAcceptorConfig {
                disable_client_ip_validation: config.disable_client_ip_validation,
                qlog_dir: config.qlog_dir.clone(),
                keylog_file: config
                    .keylog_file
                    .as_ref()
                    .and_then(|f| f.try_clone().ok()),
                #[cfg(target_os = "linux")]
                with_pktinfo: false,
            },
            Arc::clone(&socket_tx),
            0,
            Default::default(),
            Box::new(SimpleConnectionIdGenerator),
            DefaultMetrics,
        );

        let (socket_driver, mut incoming) = InboundPacketRouter::new(
            config,
            socket_tx,
            socket_rx,
            local_addr,
            acceptor,
            DefaultMetrics,
        );
        tokio::spawn(socket_driver);

        // Start a request and drop it after connection establishment. The
        // client is synchronous, so it runs on its own OS thread.
        std::thread::spawn(move || test_connect(host_port));

        // Freeze tokio's clock so the large `advance` below does not cost
        // wall-clock time.
        time::pause();

        // Wait for a new connection
        let (h3_driver, _) = ServerH3Driver::new(Http3Settings::default());
        let conn = incoming.recv().await.unwrap().unwrap();
        // Keep a clone of the sender so its closure can be observed after
        // `conn` has been consumed by `start`.
        let drop_check = conn.incoming_ev_sender.clone();
        let _conn = conn.start(h3_driver);

        // Jump tokio's clock well past the idle timeout, then let time flow
        // normally again.
        time::advance(Duration::new(30, 0)).await;
        time::resume();

        // NOTE: this is a smoke test - in case of issues the `closed()` future
        // will never resolve, hanging the test.
        drop_check.closed().await;
    }
}