Skip to main content

tokio_quiche/quic/router/
mod.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27pub(crate) mod acceptor;
28pub(crate) mod connector;
29
30use super::connection::ConnectionMap;
31use super::connection::HandshakeInfo;
32use super::connection::Incoming;
33use super::connection::InitialQuicConnection;
34use super::connection::QuicConnectionParams;
35use super::io::worker::WriterConfig;
36use super::QuicheConnection;
37use crate::buf_factory::BufFactory;
38use crate::buf_factory::PooledBuf;
39use crate::metrics::labels;
40use crate::metrics::quic_expensive_metrics_ip_reduce;
41use crate::metrics::Metrics;
42use crate::quic::connection::SharedConnectionIdGenerator;
43use crate::settings::Config;
44use datagram_socket::DatagramSocketRecv;
45use datagram_socket::DatagramSocketSend;
46use foundations::telemetry::log;
47use quiche::ConnectionId;
48use quiche::Header;
49use quiche::MAX_CONN_ID_LEN;
50use std::default::Default;
51use std::future::Future;
52use std::io;
53use std::net::SocketAddr;
54use std::pin::Pin;
55use std::sync::Arc;
56use std::task::ready;
57use std::task::Context;
58use std::task::Poll;
59use std::time::Instant;
60use std::time::SystemTime;
61use task_killswitch::spawn_with_killswitch;
62use tokio::sync::mpsc;
63
64#[cfg(target_os = "linux")]
65use foundations::telemetry::metrics::Counter;
66#[cfg(target_os = "linux")]
67use foundations::telemetry::metrics::TimeHistogram;
68#[cfg(target_os = "linux")]
69use libc::sockaddr_in;
70#[cfg(target_os = "linux")]
71use libc::sockaddr_in6;
72
/// Receiving end of the router's accept queue. Each item is either a newly
/// created connection or an error the router forwarded to the acceptor.
type ConnStream<Tx, M> = mpsc::Receiver<io::Result<InitialQuicConnection<Tx, M>>>;
74
#[cfg(feature = "perf-quic-listener-metrics")]
mod listener_stage_timer {
    use foundations::telemetry::metrics::TimeHistogram;
    use std::time::Instant;

    /// Scope guard that, when dropped, records the nanoseconds elapsed since
    /// `start` into `time_hist`.
    pub(super) struct ListenerStageTimer {
        start: Instant,
        time_hist: TimeHistogram,
    }

    impl ListenerStageTimer {
        /// Arms a timer measuring from `start`. The sample is taken when the
        /// returned guard goes out of scope.
        pub(super) fn new(start: Instant, time_hist: TimeHistogram) -> Self {
            Self { start, time_hist }
        }
    }

    impl Drop for ListenerStageTimer {
        fn drop(&mut self) {
            let elapsed_ns = self.start.elapsed().as_nanos() as u64;
            self.time_hist.observe(elapsed_ns);
        }
    }
}
100
/// Result of receiving one datagram into the router's `current_buf`, plus any
/// ancillary data that was extracted alongside it.
#[derive(Debug)]
struct PollRecvData {
    /// Number of bytes written into the receive buffer.
    bytes: usize,
    /// The packet's source, e.g., the peer's address.
    src_addr: SocketAddr,
    /// The packet's original destination, as reported by the kernel, but only
    /// when it differs from the local listening address. This is `None` when
    /// the original destination equals the local listening address or when
    /// it was not reported at all.
    dst_addr_override: Option<SocketAddr>,
    /// Kernel receive timestamp, if one was reported.
    rx_time: Option<SystemTime>,
    /// GRO segment size from the UDP GRO control message, if one was reported.
    gro: Option<i32>,
    /// Raw bytes of the `SO_MARK` control message, if one was reported.
    #[cfg(target_os = "linux")]
    so_mark_data: Option<[u8; 4]>,
}
114
/// A message to the router asking it to add or remove a connection ID mapping
/// in its connection map.
pub enum ConnectionMapCommand {
    /// Route packets addressed to `new_cid` to the same connection that
    /// `existing_cid` already routes to.
    MapCid {
        existing_cid: ConnectionId<'static>,
        new_cid: ConnectionId<'static>,
    },
    /// Remove the mapping for the given connection ID.
    UnmapCid(ConnectionId<'static>),
}
124
/// An `InboundPacketRouter` maintains a map of quic connections and routes
/// [`Incoming`] packets from the [recv half][rh] of a datagram socket to those
/// connections or some quic initials handler.
///
/// [rh]: datagram_socket::DatagramSocketRecv
///
/// When a packet (or batch of packets) is received, the router will either
/// route those packets to an established
/// [`QuicConnection`](super::QuicConnection) or have them handled by an
/// `InitialPacketHandler`, which acts either as a quic listener (server) or
/// as a quic connector (client).
///
/// If you only have a single connection, or if you need more control over the
/// socket, use `QuicConnection` directly instead.
pub struct InboundPacketRouter<Tx, Rx, M, I>
where
    Tx: DatagramSocketSend + Send + 'static,
    M: Metrics,
{
    /// Send half of the socket; an `Arc` clone is handed to every new
    /// connection.
    socket_tx: Arc<Tx>,
    /// Receive half of the socket that packets are read from.
    socket_rx: Rx,
    /// Configured local address. Used as a packet's local address unless the
    /// kernel reports a different original destination.
    local_addr: SocketAddr,
    config: Config,
    /// Connection-ID lookup table for established connections.
    conns: ConnectionMap,
    /// Receives packets whose DCID matches no known connection.
    incoming_packet_handler: I,
    /// Cloned into every new connection. Set to `None` once the accept stream
    /// has been dropped, which stops the router from creating connections.
    shutdown_tx: Option<mpsc::Sender<()>>,
    /// The router future completes once this channel reports ready.
    shutdown_rx: mpsc::Receiver<()>,
    /// Cloned into connections so they can ask the router to (un)map CIDs.
    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    conn_map_cmd_rx: mpsc::UnboundedReceiver<ConnectionMapCommand>,
    /// Sending half of the accept queue returned from `new`.
    accept_sink: mpsc::Sender<io::Result<InitialQuicConnection<Tx, M>>>,
    metrics: M,
    /// Last cumulative drop count reported via `SO_RXQ_OVFL`.
    #[cfg(target_os = "linux")]
    udp_drop_count: u32,

    /// Control-message buffer reused across `recvmsg` calls.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,

    /// Buffer the next datagram is received into. It is swapped for a fresh
    /// buffer each time a received packet is handed off.
    current_buf: PooledBuf,

    // We keep the metrics in here, to avoid cloning them each packet
    #[cfg(target_os = "linux")]
    metrics_handshake_time_seconds: TimeHistogram,
    #[cfg(target_os = "linux")]
    metrics_udp_drop_count: Counter,
}
170
impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
where
    Tx: DatagramSocketSend + Send + 'static,
    Rx: DatagramSocketRecv,
    M: Metrics,
    I: InitialPacketHandler,
{
    /// Creates a router together with the stream of connections it produces.
    ///
    /// The returned router is a future that must be polled (e.g. spawned) to
    /// do any work. The returned [`ConnStream`] yields new connections or
    /// forwarded errors. Its capacity is `config.listen_backlog`.
    pub(crate) fn new(
        config: Config, socket_tx: Arc<Tx>, socket_rx: Rx,
        local_addr: SocketAddr, incoming_packet_handler: I, metrics: M,
    ) -> (Self, ConnStream<Tx, M>) {
        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
        let (accept_sink, accept_stream) = mpsc::channel(config.listen_backlog);
        let (conn_map_cmd_tx, conn_map_cmd_rx) = mpsc::unbounded_channel();

        (
            InboundPacketRouter {
                local_addr,
                socket_tx,
                socket_rx,
                conns: ConnectionMap::default(),
                incoming_packet_handler,
                shutdown_tx: Some(shutdown_tx),
                shutdown_rx,
                conn_map_cmd_tx,
                conn_map_cmd_rx,
                accept_sink,
                #[cfg(target_os = "linux")]
                udp_drop_count: 0,
                #[cfg(target_os = "linux")]
                // Specify CMSG space. Even if they're not all currently used, the cmsg buffer may
                // have been configured by a previous version of Tokio-Quiche with the socket
                // re-used on graceful restart. As such, this vector should _only grow_, and care
                // should be taken when adding new cmsgs.
                reusable_cmsg_space: nix::cmsg_space!(
                    u32, // GRO
                    nix::sys::time::TimeSpec, // timestamp
                    u16, // drop count
                    sockaddr_in, // IP_RECVORIGDSTADDR
                    sockaddr_in6, // IPV6_RECVORIGDSTADDR
                    u32 // SO_MARK
                ),
                config,

                current_buf: BufFactory::get_max_buf(),

                // The per-packet metric handles are derived from `metrics`
                // before it is moved into the struct below.
                #[cfg(target_os = "linux")]
                metrics_handshake_time_seconds: metrics.handshake_time_seconds(labels::QuicHandshakeStage::QueueWaiting),
                #[cfg(target_os = "linux")]
                metrics_udp_drop_count: metrics.udp_drop_count(),

                metrics,

            },
            accept_stream,
        )
    }

    /// Routes one received packet.
    ///
    /// Packets for a known connection ID are forwarded to that connection.
    /// All others go to the [`InitialPacketHandler`], unless the router is
    /// shutting down, in which case they are silently ignored. If the handler
    /// yields a connection, it is spawned.
    ///
    /// An `Err` means the packet was rejected. The caller classifies it with
    /// `initial_packet_error_type`.
    fn on_incoming(&mut self, mut incoming: Incoming) -> io::Result<()> {
        #[cfg(feature = "perf-quic-listener-metrics")]
        let start = std::time::Instant::now();

        // Fast path: for short-header packets, look the DCID up straight from
        // the raw bytes without parsing the header.
        if let Some(dcid) = short_dcid(&incoming.buf) {
            if let Some(ev_sender) = self.conns.get(&dcid) {
                // The send result is ignored: if the packet can't be queued for
                // the connection, it is dropped.
                let _ = ev_sender.try_send(incoming);
                return Ok(());
            }
        }

        // Malformed headers are classified as `FailedToParse` so they are
        // counted as rejected packets rather than reported as `Unexpected`.
        let hdr = Header::from_slice(&mut incoming.buf, MAX_CONN_ID_LEN)
            .map_err(|e| match e {
                quiche::Error::BufferTooShort | quiche::Error::InvalidPacket =>
                    labels::QuicInvalidInitialPacketError::FailedToParse.into(),
                e => io::Error::other(e),
            })?;

        if let Some(ev_sender) = self.conns.get(&hdr.dcid) {
            let _ = ev_sender.try_send(incoming);
            return Ok(());
        }

        // Records the time from function entry until this function returns,
        // only for packets that reach the initials path.
        #[cfg(feature = "perf-quic-listener-metrics")]
        let _timer = listener_stage_timer::ListenerStageTimer::new(
            start,
            self.metrics.handshake_time_seconds(
                labels::QuicHandshakeStage::HandshakeProtocol,
            ),
        );

        // Shutting down: don't hand unknown packets to the initials handler.
        if self.shutdown_tx.is_none() {
            return Ok(());
        }

        let local_addr = incoming.local_addr;
        let peer_addr = incoming.peer_addr;

        #[cfg(feature = "perf-quic-listener-metrics")]
        let init_rx_time = incoming.rx_time;

        let new_connection = self.incoming_packet_handler.handle_initials(
            incoming,
            hdr,
            self.config.as_mut(),
        )?;

        match new_connection {
            Some(new_connection) => self.spawn_new_connection(
                new_connection,
                local_addr,
                peer_addr,
                #[cfg(feature = "perf-quic-listener-metrics")]
                init_rx_time,
            ),
            None => Ok(()),
        }
    }

    /// Creates a new [`QuicConnection`](super::QuicConnection) and spawns an
    /// associated io worker.
    ///
    /// The connection is registered in the connection map under its source
    /// ID (and the pending CID, if any) and pushed onto the accept queue.
    /// Returns `AcceptQueueOverflow` if no accept-queue slot can be reserved.
    /// Does nothing if the router is shutting down.
    fn spawn_new_connection(
        &mut self, new_connection: NewConnection, local_addr: SocketAddr,
        peer_addr: SocketAddr,
        #[cfg(feature = "perf-quic-listener-metrics")] init_rx_time: Option<
            SystemTime,
        >,
    ) -> io::Result<()> {
        let NewConnection {
            conn,
            pending_cid,
            cid_generator,
            handshake_start_time,
            initial_pkt,
        } = new_connection;

        let Some(ref shutdown_tx) = self.shutdown_tx else {
            // don't create new connections if we're shutting down.
            return Ok(());
        };
        // Reserve the accept-queue slot up front, so that a full backlog
        // rejects the connection before anything is registered.
        // NOTE(review): `try_reserve` also fails when the receiver is closed;
        // that case is reported as `AcceptQueueOverflow` too.
        let Ok(send_permit) = self.accept_sink.try_reserve() else {
            // drop the connection if the backlog is full. the client will retry.
            return Err(
                labels::QuicInvalidInitialPacketError::AcceptQueueOverflow.into(),
            );
        };

        let scid = conn.source_id().into_owned();
        let writer_cfg = WriterConfig {
            peer_addr,
            local_addr,
            pending_cid: pending_cid.clone(),
            with_gso: self.config.has_gso,
            pacing_offload: self.config.pacing_offload,
            // Chosen by the family of the configured local address.
            with_pktinfo: if self.local_addr.is_ipv4() {
                self.config.has_ippktinfo
            } else {
                self.config.has_ipv6pktinfo
            },
        };

        let handshake_info = HandshakeInfo::new(
            handshake_start_time,
            self.config.handshake_timeout,
        );

        let conn = InitialQuicConnection::new(QuicConnectionParams {
            writer_cfg,
            initial_pkt,
            shutdown_tx: shutdown_tx.clone(),
            conn_map_cmd_tx: self.conn_map_cmd_tx.clone(),
            scid: scid.clone(),
            cid_generator,
            metrics: self.metrics.clone(),
            #[cfg(feature = "perf-quic-listener-metrics")]
            init_rx_time,
            handshake_info,
            quiche_conn: conn,
            socket: Arc::clone(&self.socket_tx),
            local_addr,
            peer_addr,
        });

        conn.audit_log_stats
            .set_transport_handshake_start(instant_to_system(
                handshake_start_time,
            ));

        self.conns.insert(&scid, &conn);

        // Add the client-generated "pending" connection ID to the map as well.
        // This is only required for QUIC servers, because clients can send
        // Initial packets with arbitrary DCIDs to servers.
        if let Some(pending_cid) = pending_cid {
            self.conns.map_cid(&scid, &pending_cid);
        }

        self.metrics.accepted_initial_packet_count().inc();
        if self.config.enable_expensive_packet_count_metrics {
            if let Some(peer_ip) =
                quic_expensive_metrics_ip_reduce(conn.peer_addr().ip())
            {
                self.metrics
                    .expensive_accepted_initial_packet_count(peer_ip)
                    .inc();
            }
        }

        send_permit.send(Ok(conn));
        Ok(())
    }
}
381
382impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
383where
384    Tx: DatagramSocketSend + Send + Sync + 'static,
385    Rx: DatagramSocketRecv,
386    M: Metrics,
387    I: InitialPacketHandler,
388{
389    /// [`InboundPacketRouter::poll_recv_from`] should be used if the underlying
390    /// system or socket does not support rx_time nor GRO.
391    fn poll_recv_from(
392        &mut self, cx: &mut Context<'_>,
393    ) -> Poll<io::Result<PollRecvData>> {
394        let mut buf = tokio::io::ReadBuf::new(&mut self.current_buf);
395        let addr = ready!(self.socket_rx.poll_recv_from(cx, &mut buf))?;
396        Poll::Ready(Ok(PollRecvData {
397            bytes: buf.filled().len(),
398            src_addr: addr,
399            rx_time: None,
400            gro: None,
401            dst_addr_override: None,
402            #[cfg(target_os = "linux")]
403            so_mark_data: None,
404        }))
405    }
406
407    fn poll_recv_and_rx_time(
408        &mut self, cx: &mut Context<'_>,
409    ) -> Poll<io::Result<PollRecvData>> {
410        #[cfg(not(target_os = "linux"))]
411        {
412            self.poll_recv_from(cx)
413        }
414
415        #[cfg(target_os = "linux")]
416        {
417            use libc::SOL_SOCKET;
418            use libc::SO_MARK;
419            use nix::errno::Errno;
420            use nix::sys::socket::*;
421            use std::net::SocketAddrV4;
422            use std::net::SocketAddrV6;
423            use std::os::fd::AsRawFd;
424            use tokio::io::Interest;
425
426            let Some(udp_socket) = self.socket_rx.as_udp_socket() else {
427                // the given socket is not a UDP socket, fall back to the
428                // simple poll_recv_from.
429                return self.poll_recv_from(cx);
430            };
431
432            loop {
433                let iov_s = &mut [io::IoSliceMut::new(&mut self.current_buf)];
434                match udp_socket.try_io(Interest::READABLE, || {
435                    recvmsg::<SockaddrStorage>(
436                        udp_socket.as_raw_fd(),
437                        iov_s,
438                        Some(&mut self.reusable_cmsg_space),
439                        MsgFlags::empty(),
440                    )
441                    .map_err(|x| x.into())
442                }) {
443                    Ok(r) => {
444                        let bytes = r.bytes;
445
446                        let address = match r.address {
447                            Some(inner) => inner,
448                            _ => return Poll::Ready(Err(Errno::EINVAL.into())),
449                        };
450
451                        let peer_addr = match address.family() {
452                            Some(AddressFamily::Inet) => SocketAddrV4::from(
453                                *address.as_sockaddr_in().unwrap(),
454                            )
455                            .into(),
456                            Some(AddressFamily::Inet6) => SocketAddrV6::from(
457                                *address.as_sockaddr_in6().unwrap(),
458                            )
459                            .into(),
460                            _ => {
461                                return Poll::Ready(Err(Errno::EINVAL.into()));
462                            },
463                        };
464
465                        let mut rx_time = None;
466                        let mut gro = None;
467                        let mut dst_addr_override = None;
468                        let mut mark_bytes: Option<[u8; 4]> = None;
469
470                        let Ok(cmsgs) = r.cmsgs() else {
471                            // Best-effort if we can't read cmsgs.
472                            return Poll::Ready(Ok(PollRecvData {
473                                bytes,
474                                src_addr: peer_addr,
475                                dst_addr_override,
476                                rx_time,
477                                gro,
478                                so_mark_data: mark_bytes,
479                            }));
480                        };
481
482                        for cmsg in cmsgs {
483                            match cmsg {
484                                ControlMessageOwned::RxqOvfl(c) => {
485                                    if c != self.udp_drop_count {
486                                        self.metrics_udp_drop_count.inc_by(
487                                            (c - self.udp_drop_count) as u64,
488                                        );
489                                        self.udp_drop_count = c;
490                                    }
491                                },
492                                ControlMessageOwned::ScmTimestampns(val) => {
493                                    rx_time = SystemTime::UNIX_EPOCH
494                                        .checked_add(val.into());
495                                    if let Some(delta) =
496                                        rx_time.and_then(|rx_time| {
497                                            rx_time.elapsed().ok()
498                                        })
499                                    {
500                                        self.metrics_handshake_time_seconds
501                                            .observe(delta.as_nanos() as u64);
502                                    }
503                                },
504                                ControlMessageOwned::UdpGroSegments(val) =>
505                                    gro = Some(val),
506                                ControlMessageOwned::Ipv4OrigDstAddr(val) => {
507                                    let source_addr = std::net::Ipv4Addr::from(
508                                        u32::to_be(val.sin_addr.s_addr),
509                                    );
510                                    let source_port = u16::to_be(val.sin_port);
511
512                                    let parsed_addr =
513                                        SocketAddr::V4(SocketAddrV4::new(
514                                            source_addr,
515                                            source_port,
516                                        ));
517
518                                    dst_addr_override = resolve_dst_addr(
519                                        &self.local_addr,
520                                        &parsed_addr,
521                                    );
522                                },
523                                ControlMessageOwned::Ipv6OrigDstAddr(val) => {
524                                    // Don't have to flip IPv6 bytes since it's a
525                                    // byte array, not a
526                                    // series of bytes parsed as a u32 as in the
527                                    // IPv4 case
528                                    let source_addr = std::net::Ipv6Addr::from(
529                                        val.sin6_addr.s6_addr,
530                                    );
531                                    let source_port = u16::to_be(val.sin6_port);
532                                    let source_flowinfo =
533                                        u32::to_be(val.sin6_flowinfo);
534                                    let source_scope =
535                                        u32::to_be(val.sin6_scope_id);
536
537                                    let parsed_addr =
538                                        SocketAddr::V6(SocketAddrV6::new(
539                                            source_addr,
540                                            source_port,
541                                            source_flowinfo,
542                                            source_scope,
543                                        ));
544
545                                    dst_addr_override = resolve_dst_addr(
546                                        &self.local_addr,
547                                        &parsed_addr,
548                                    );
549                                },
550                                ControlMessageOwned::Ipv4PacketInfo(_) |
551                                ControlMessageOwned::Ipv6PacketInfo(_) => {
552                                    // We only want the destination address from
553                                    // IP_RECVORIGDSTADDR, but we'll get these
554                                    // messages because we set IP_PKTINFO on the
555                                    // socket.
556                                },
557                                ControlMessageOwned::Unknown(raw_cmsg) => {
558                                    let UnknownCmsg {
559                                        cmsg_header,
560                                        data_bytes,
561                                    } = raw_cmsg;
562
563                                    if cmsg_header.cmsg_level == SOL_SOCKET &&
564                                        cmsg_header.cmsg_type == SO_MARK
565                                    {
566                                        let Ok(arr) =
567                                            <[u8; 4]>::try_from(data_bytes)
568                                        else {
569                                            // Should be unreachable as SO_MARK is
570                                            // a u32: https://elixir.bootlin.com/linux/v6.17/source/include/net/sock.h#L487
571                                            continue;
572                                        };
573
574                                        let _ = mark_bytes.insert(arr);
575                                    }
576                                },
577                                _ => {
578                                    // Unrecognized cmsg received, just ignore
579                                    // it.
580                                },
581                            };
582                        }
583
584                        return Poll::Ready(Ok(PollRecvData {
585                            bytes,
586                            src_addr: peer_addr,
587                            dst_addr_override,
588                            rx_time,
589                            gro,
590                            so_mark_data: mark_bytes,
591                        }));
592                    },
593                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
594                        // NOTE: we manually poll the socket here to register
595                        // interest in the socket to become
596                        // writable for the given `cx`. Under the hood, tokio's
597                        // implementation just checks for
598                        // EWOULDBLOCK and if socket is busy registers provided
599                        // waker to be invoked when the
600                        // socket is free and consequently drive the event loop.
601                        ready!(udp_socket.poll_recv_ready(cx))?
602                    },
603                    Err(e) => return Poll::Ready(Err(e)),
604                }
605            }
606        }
607    }
608
609    fn handle_conn_map_commands(&mut self) {
610        while let Ok(req) = self.conn_map_cmd_rx.try_recv() {
611            match req {
612                ConnectionMapCommand::MapCid {
613                    existing_cid,
614                    new_cid,
615                } => self.conns.map_cid(&existing_cid, &new_cid),
616                ConnectionMapCommand::UnmapCid(cid) => self.conns.unmap_cid(&cid),
617            }
618        }
619    }
620}
621
622// Quickly extract the connection id of a short quic packet without allocating
623fn short_dcid(buf: &[u8]) -> Option<ConnectionId<'_>> {
624    let is_short_dcid = buf.first()? >> 7 == 0;
625
626    if is_short_dcid {
627        buf.get(1..1 + MAX_CONN_ID_LEN).map(ConnectionId::from_ref)
628    } else {
629        None
630    }
631}
632
/// Converts an [`Instant`] to a [`SystemTime`] by applying the instant's
/// offset from "now" on the monotonic clock to the current wall-clock time.
fn instant_to_system(ts: Instant) -> SystemTime {
    let mono_now = Instant::now();
    let wall_now = SystemTime::now();

    match mono_now.checked_duration_since(ts) {
        // `ts` is in the past (or exactly now): step back on the wall clock.
        Some(elapsed) => wall_now - elapsed,
        // `ts` is strictly in the future: step forward by the remaining time.
        None => {
            let ahead = ts.checked_duration_since(mono_now).expect("now < ts");
            wall_now + ahead
        },
    }
}
645
/// Decides whether a packet's original destination (parsed from a
/// [`ControlMessageOwned`](nix::sys::socket::ControlMessageOwned)) should be
/// stored as an override of the local address.
///
/// An address equal to `local` is not stored. Overriding with it would cause
/// packets to be addressed incorrectly when sending.
///
/// Returns `Some(*parsed)` when it differs from `local`, otherwise `None`.
#[cfg(target_os = "linux")]
fn resolve_dst_addr(
    local: &SocketAddr, parsed: &SocketAddr,
) -> Option<SocketAddr> {
    (parsed != local).then_some(*parsed)
}
665
666impl<Tx, Rx, M, I> Future for InboundPacketRouter<Tx, Rx, M, I>
667where
668    Tx: DatagramSocketSend + Send + Sync + 'static,
669    Rx: DatagramSocketRecv + Unpin,
670    M: Metrics,
671    I: InitialPacketHandler + Unpin,
672{
673    type Output = io::Result<()>;
674
675    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
676        let server_addr = self.local_addr;
677
678        loop {
679            if let Err(error) = self.incoming_packet_handler.update(cx) {
680                // This is so rare that it's easier to spawn a separate task
681                let sender = self.accept_sink.clone();
682                spawn_with_killswitch(async move {
683                    let _ = sender.send(Err(error)).await;
684                });
685            }
686
687            match self.poll_recv_and_rx_time(cx) {
688                Poll::Ready(Ok(PollRecvData {
689                    bytes,
690                    src_addr: peer_addr,
691                    dst_addr_override,
692                    rx_time,
693                    gro,
694                    #[cfg(target_os = "linux")]
695                    so_mark_data,
696                })) => {
697                    let mut buf = std::mem::replace(
698                        &mut self.current_buf,
699                        BufFactory::get_max_buf(),
700                    );
701                    buf.truncate(bytes);
702
703                    let send_from = if let Some(dst_addr) = dst_addr_override {
704                        log::trace!("overriding local address"; "actual_local" => dst_addr, "configured_local" => server_addr);
705                        dst_addr
706                    } else {
707                        server_addr
708                    };
709
710                    let res = self.on_incoming(Incoming {
711                        peer_addr,
712                        local_addr: send_from,
713                        buf,
714                        rx_time,
715                        gro,
716                        #[cfg(target_os = "linux")]
717                        so_mark_data,
718                    });
719
720                    if let Err(e) = res {
721                        let err_type = initial_packet_error_type(&e);
722                        self.metrics
723                            .rejected_initial_packet_count(err_type.clone())
724                            .inc();
725
726                        if self.config.enable_expensive_packet_count_metrics {
727                            if let Some(peer_ip) =
728                                quic_expensive_metrics_ip_reduce(peer_addr.ip())
729                            {
730                                self.metrics
731                                    .expensive_rejected_initial_packet_count(
732                                        err_type.clone(),
733                                        peer_ip,
734                                    )
735                                    .inc();
736                            }
737                        }
738
739                        if matches!(
740                            err_type,
741                            labels::QuicInvalidInitialPacketError::Unexpected
742                        ) {
743                            // don't block packet routing on errors
744                            let _ = self.accept_sink.try_send(Err(e));
745                        }
746                    }
747                },
748
749                Poll::Ready(Err(e)) => {
750                    log::error!("Incoming packet router encountered recvmsg error"; "error" => e);
751                    continue;
752                },
753
754                Poll::Pending => {
755                    // Check whether any connections are still active
756                    if self.shutdown_tx.is_some() && self.accept_sink.is_closed()
757                    {
758                        self.shutdown_tx = None;
759                    }
760
761                    if self.shutdown_rx.poll_recv(cx).is_ready() {
762                        return Poll::Ready(Ok(()));
763                    }
764
765                    // Process any incoming connection map signals and handle them
766                    self.handle_conn_map_commands();
767
768                    return Poll::Pending;
769                },
770            }
771        }
772    }
773}
774
775/// Categorizes errors that are returned when handling packets which are not
776/// associated with an established connection. The purpose is to suppress
777/// logging of 'expected' errors (e.g. junk data sent to the UDP socket) to
778/// prevent DoS.
779fn initial_packet_error_type(
780    e: &io::Error,
781) -> labels::QuicInvalidInitialPacketError {
782    Some(e)
783        .filter(|e| e.kind() == io::ErrorKind::Other)
784        .and_then(io::Error::get_ref)
785        .and_then(|e| e.downcast_ref())
786        .map_or(
787            labels::QuicInvalidInitialPacketError::Unexpected,
788            Clone::clone,
789        )
790}
791
792/// An [`InitialPacketHandler`] handles unknown quic initials and processes
793/// them; generally accepting new connections (acting as a server), or
794/// establishing a connection to a server (acting as a client). An
795/// [`InboundPacketRouter`] holds an instance of this trait and routes
796/// [`Incoming`] packets to it when it receives initials.
797///
798/// The handler produces [`quiche::Connection`]s which are then turned into
799/// [`QuicConnection`](super::QuicConnection), IoWorker pair.
800pub trait InitialPacketHandler {
801    fn update(&mut self, _ctx: &mut Context<'_>) -> io::Result<()> {
802        Ok(())
803    }
804
805    fn handle_initials(
806        &mut self, incoming: Incoming, hdr: Header<'static>,
807        quiche_config: &mut quiche::Config,
808    ) -> io::Result<Option<NewConnection>>;
809}
810
/// A [`NewConnection`] describes a new [`quiche::Connection`] that can be
/// driven by an io worker.
///
/// Returned by [`InitialPacketHandler::handle_initials`].
pub struct NewConnection {
    // NOTE(review): field declaration order is also the drop order; keep it
    // stable unless the resulting drop sequence has been checked.
    /// The underlying quiche connection.
    conn: QuicheConnection,
    // NOTE(review): presumably a connection ID still to be registered for
    // routing once the connection is set up; confirm at the use site.
    pending_cid: Option<ConnectionId<'static>>,
    // NOTE(review): presumably the packet that triggered this connection,
    // kept so it can be fed to the connection later; confirm at the use site.
    initial_pkt: Option<Incoming>,
    // Optional generator for this connection's connection IDs.
    cid_generator: Option<SharedConnectionIdGenerator>,
    /// When the handshake started. This timestamp should be captured before
    /// calling [`quiche::accept`] or [`quiche::connect`].
    handshake_start_time: Instant,
}
822
// TODO: the router module is private so we can't move these to /tests
// TODO: Rewrite tests to be Windows compatible
#[cfg(all(test, unix))]
mod tests {
    use super::acceptor::ConnectionAcceptor;
    use super::acceptor::ConnectionAcceptorConfig;
    use super::*;

    use crate::http3::settings::Http3Settings;
    use crate::metrics::DefaultMetrics;
    use crate::quic::connection::SimpleConnectionIdGenerator;
    use crate::settings::Config;
    use crate::settings::Hooks;
    use crate::settings::QuicSettings;
    use crate::settings::TlsCertificatePaths;
    use crate::socket::SocketCapabilities;
    use crate::ConnectionParams;
    use crate::ServerH3Driver;

    use datagram_socket::MAX_DATAGRAM_SIZE;
    use h3i::actions::h3::Action;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::net::UdpSocket;
    use tokio::time;

    const TEST_CERT_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.crt"
    );
    const TEST_KEY_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.key"
    );

    /// Runs a blocking h3i client against `host_port` that connects and then
    /// closes the connection with an application-level `NoError`. The
    /// client's result is ignored because the test only observes the server.
    fn test_connect(host_port: String) {
        let h3i_config = h3i::config::Config::new()
            .with_host_port("test.com".to_string())
            .with_idle_timeout(2000)
            .with_connect_to(host_port)
            .verify_peer(false)
            .build()
            .unwrap();

        let conn_close = h3i::quiche::ConnectionError {
            is_app: true,
            error_code: h3i::quiche::WireErrorCode::NoError as _,
            reason: Vec::new(),
        };
        let actions = vec![Action::ConnectionClose { error: conn_close }];

        let _ = h3i::client::sync_client::connect(h3i_config, actions, None);
    }

    /// Smoke test: after the client goes away, the server-side connection
    /// must be reclaimed, which is observed via
    /// `incoming_ev_sender.closed()`.
    #[tokio::test]
    async fn test_timeout() {
        // Configure a short idle timeout to speed up connection reclamation as
        // quiche doesn't support time mocking
        let quic_settings = QuicSettings {
            max_idle_timeout: Some(Duration::from_millis(1)),
            max_recv_udp_payload_size: MAX_DATAGRAM_SIZE,
            max_send_udp_payload_size: MAX_DATAGRAM_SIZE,
            ..Default::default()
        };

        let tls_cert_settings = TlsCertificatePaths {
            cert: TEST_CERT_FILE,
            private_key: TEST_KEY_FILE,
            kind: crate::settings::CertificateKind::X509,
        };

        let params = ConnectionParams::new_server(
            quic_settings,
            tls_cert_settings,
            Hooks::default(),
        );
        let config = Config::new(&params, SocketCapabilities::default()).unwrap();

        let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let local_addr = socket.local_addr().unwrap();
        let host_port = local_addr.to_string();
        let socket_tx = Arc::new(socket);
        let socket_rx = Arc::clone(&socket_tx);

        let acceptor = ConnectionAcceptor::new(
            ConnectionAcceptorConfig {
                disable_client_ip_validation: config.disable_client_ip_validation,
                qlog_dir: config.qlog_dir.clone(),
                keylog_file: config
                    .keylog_file
                    .as_ref()
                    .and_then(|f| f.try_clone().ok()),
                #[cfg(target_os = "linux")]
                with_pktinfo: false,
            },
            Arc::clone(&socket_tx),
            Default::default(),
            Arc::new(SimpleConnectionIdGenerator),
            DefaultMetrics,
        );

        let (socket_driver, mut incoming) = InboundPacketRouter::new(
            config,
            socket_tx,
            socket_rx,
            local_addr,
            acceptor,
            DefaultMetrics,
        );
        tokio::spawn(socket_driver);

        // Start a request and drop it after connection establishment
        std::thread::spawn(move || test_connect(host_port));

        // Pause tokio's clock so it can be advanced manually below.
        time::pause();

        // Wait for a new connection
        let (h3_driver, _) = ServerH3Driver::new(Http3Settings::default());
        let conn = incoming.recv().await.unwrap().unwrap();
        let drop_check = conn.incoming_ev_sender.clone();
        let _conn = conn.start(h3_driver);

        // Poll the incoming until the connection is dropped
        time::advance(Duration::new(30, 0)).await;
        time::resume();

        // NOTE: this is a smoke test - in case of issues the `closed()` future
        // never resolves. Bound the wait (on the real clock, since time was
        // resumed above) so a regression fails the test instead of hanging it.
        time::timeout(Duration::from_secs(10), drop_check.closed())
            .await
            .expect("connection was not reclaimed after its idle timeout");
    }
}