Skip to main content

tokio_quiche/quic/router/
mod.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27pub(crate) mod acceptor;
28pub(crate) mod connector;
29
30use super::connection::ConnectionMap;
31use super::connection::HandshakeInfo;
32use super::connection::Incoming;
33use super::connection::InitialQuicConnection;
34use super::connection::QuicConnectionParams;
35use super::io::worker::WriterConfig;
36use super::QuicheConnection;
37use crate::metrics::labels;
38use crate::metrics::quic_expensive_metrics_ip_reduce;
39use crate::metrics::Metrics;
40use crate::quic::connection::SharedConnectionIdGenerator;
41use crate::settings::Config;
42use datagram_socket::DatagramSocketRecv;
43use datagram_socket::DatagramSocketSend;
44use foundations::telemetry::log;
45use quiche::ConnectionId;
46use quiche::Header;
47use quiche::MAX_CONN_ID_LEN;
48use std::default::Default;
49use std::future::Future;
50use std::io;
51use std::net::SocketAddr;
52use std::pin::Pin;
53use std::sync::Arc;
54use std::task::ready;
55use std::task::Context;
56use std::task::Poll;
57use std::time::Instant;
58use std::time::SystemTime;
59use task_killswitch::spawn_with_killswitch;
60use tokio::sync::mpsc;
61
62#[cfg(target_os = "linux")]
63use foundations::telemetry::metrics::Counter;
64#[cfg(target_os = "linux")]
65use foundations::telemetry::metrics::TimeHistogram;
66#[cfg(target_os = "linux")]
67use libc::sockaddr_in;
68#[cfg(target_os = "linux")]
69use libc::sockaddr_in6;
70
71type ConnStream<Tx, M> = mpsc::Receiver<io::Result<InitialQuicConnection<Tx, M>>>;
72
#[cfg(feature = "perf-quic-listener-metrics")]
mod listener_stage_timer {
    //! Scope-based timing helper, only compiled with the
    //! `perf-quic-listener-metrics` feature.

    use foundations::telemetry::metrics::TimeHistogram;
    use std::time::Instant;

    /// Guard that reports the time elapsed since `start` to `time_hist` when
    /// it goes out of scope.
    pub(super) struct ListenerStageTimer {
        start: Instant,
        time_hist: TimeHistogram,
    }

    impl ListenerStageTimer {
        /// Creates a guard measuring from `start`; the sample is recorded
        /// into `time_hist` on drop.
        pub(super) fn new(start: Instant, time_hist: TimeHistogram) -> Self {
            Self { start, time_hist }
        }
    }

    impl Drop for ListenerStageTimer {
        /// Records the nanoseconds elapsed since `start` in the histogram.
        fn drop(&mut self) {
            let elapsed_nanos = self.start.elapsed().as_nanos() as u64;
            self.time_hist.observe(elapsed_nanos);
        }
    }
}
98
/// The result of a single successful receive on the router's socket: the
/// payload plus whatever ancillary data could be gathered alongside it.
#[derive(Debug)]
struct PollRecvData {
    /// The received bytes.
    buf: Vec<u8>,
    /// The packet's source, e.g., the peer's address.
    src_addr: SocketAddr,
    /// The packet's original destination. This is `None` when no override is
    /// needed, i.e. the packet was addressed to the local listening address
    /// (or the original destination is not known).
    dst_addr_override: Option<SocketAddr>,
    /// Receive timestamp of the packet, if the socket reported one.
    rx_time: Option<SystemTime>,
    /// Value of the UDP GRO control message, if one was received.
    gro: Option<i32>,
    /// Raw payload of an `SO_MARK` control message, if one was received.
    #[cfg(target_os = "linux")]
    so_mark_data: Option<[u8; 4]>,
}
112
113/// A message to the listener notifiying a mapping for a connection should be
114/// removed.
115pub enum ConnectionMapCommand {
116    MapCid {
117        existing_cid: ConnectionId<'static>,
118        new_cid: ConnectionId<'static>,
119    },
120    UnmapCid(ConnectionId<'static>),
121}
122
123/// An `InboundPacketRouter` maintains a map of quic connections and routes
124/// [`Incoming`] packets from the [recv half][rh] of a datagram socket to those
125/// connections or some quic initials handler.
126///
127/// [rh]: datagram_socket::DatagramSocketRecv
128///
129/// When a packet (or batch of packets) is received, the router will either
130/// route those packets to an established
131/// [`QuicConnection`](super::QuicConnection) or have a them handled by a
132/// `InitialPacketHandler` which either acts as a quic listener or
133/// quic connector, a server or client respectively.
134///
135/// If you only have a single connection, or if you need more control over the
136/// socket, use `QuicConnection` directly instead.
137pub struct InboundPacketRouter<Tx, Rx, M, I>
138where
139    Tx: DatagramSocketSend + Send + 'static,
140    M: Metrics,
141{
142    socket_tx: Arc<Tx>,
143    socket_rx: Rx,
144    local_addr: SocketAddr,
145    config: Config,
146    conns: ConnectionMap,
147    incoming_packet_handler: I,
148    shutdown_tx: Option<mpsc::Sender<()>>,
149    shutdown_rx: mpsc::Receiver<()>,
150    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
151    conn_map_cmd_rx: mpsc::UnboundedReceiver<ConnectionMapCommand>,
152    accept_sink: mpsc::Sender<io::Result<InitialQuicConnection<Tx, M>>>,
153    metrics: M,
154    #[cfg(target_os = "linux")]
155    udp_drop_count: u32,
156
157    #[cfg(target_os = "linux")]
158    reusable_cmsg_space: Vec<u8>,
159
160    #[cfg(target_os = "linux")]
161    buf: Vec<u8>,
162
163    // We keep the metrics in here, to avoid cloning them each packet
164    #[cfg(target_os = "linux")]
165    metrics_handshake_time_seconds: TimeHistogram,
166    #[cfg(target_os = "linux")]
167    metrics_udp_drop_count: Counter,
168}
169
170impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
171where
172    Tx: DatagramSocketSend + Send + 'static,
173    Rx: DatagramSocketRecv,
174    M: Metrics,
175    I: InitialPacketHandler,
176{
177    pub(crate) fn new(
178        config: Config, socket_tx: Arc<Tx>, socket_rx: Rx,
179        local_addr: SocketAddr, incoming_packet_handler: I, metrics: M,
180    ) -> (Self, ConnStream<Tx, M>) {
181        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
182        let (accept_sink, accept_stream) = mpsc::channel(config.listen_backlog);
183        let (conn_map_cmd_tx, conn_map_cmd_rx) = mpsc::unbounded_channel();
184
185        (
186            InboundPacketRouter {
187                local_addr,
188                socket_tx,
189                socket_rx,
190                conns: ConnectionMap::default(),
191                incoming_packet_handler,
192                shutdown_tx: Some(shutdown_tx),
193                shutdown_rx,
194                conn_map_cmd_tx,
195                conn_map_cmd_rx,
196                accept_sink,
197                #[cfg(target_os = "linux")]
198                udp_drop_count: 0,
199                #[cfg(target_os = "linux")]
200                // Specify CMSG space. Even if they're not all currently used, the cmsg buffer may
201                // have been configured by a previous version of Tokio-Quiche with the socket
202                // re-used on graceful restart. As such, this vector should _only grow_, and care
203                // should be taken when adding new cmsgs.
204                reusable_cmsg_space: nix::cmsg_space!(
205                    u32, // GRO
206                    nix::sys::time::TimeSpec, // timestamp
207                    u16, // drop count
208                    sockaddr_in, // IP_RECVORIGDSTADDR
209                    sockaddr_in6, // IPV6_RECVORIGDSTADDR
210                    u32 // SO_MARK
211                ),
212
213                config,
214
215                #[cfg(target_os = "linux")]
216                buf: Vec::new(),
217                #[cfg(target_os = "linux")]
218                metrics_handshake_time_seconds: metrics.handshake_time_seconds(labels::QuicHandshakeStage::QueueWaiting),
219                #[cfg(target_os = "linux")]
220                metrics_udp_drop_count: metrics.udp_drop_count(),
221
222                metrics,
223
224            },
225            accept_stream,
226        )
227    }
228
229    fn on_incoming(&mut self, mut incoming: Incoming) -> io::Result<()> {
230        #[cfg(feature = "perf-quic-listener-metrics")]
231        let start = std::time::Instant::now();
232
233        if let Some(dcid) = short_dcid(&incoming.buf) {
234            if let Some(ev_sender) = self.conns.get(&dcid) {
235                let _ = ev_sender.try_send(incoming);
236                return Ok(());
237            }
238        }
239
240        let hdr = Header::from_slice(&mut incoming.buf, MAX_CONN_ID_LEN)
241            .map_err(|e| match e {
242                quiche::Error::BufferTooShort | quiche::Error::InvalidPacket =>
243                    labels::QuicInvalidInitialPacketError::FailedToParse.into(),
244                e => io::Error::other(e),
245            })?;
246
247        if let Some(ev_sender) = self.conns.get(&hdr.dcid) {
248            let _ = ev_sender.try_send(incoming);
249            return Ok(());
250        }
251
252        #[cfg(feature = "perf-quic-listener-metrics")]
253        let _timer = listener_stage_timer::ListenerStageTimer::new(
254            start,
255            self.metrics.handshake_time_seconds(
256                labels::QuicHandshakeStage::HandshakeProtocol,
257            ),
258        );
259
260        if self.shutdown_tx.is_none() {
261            return Ok(());
262        }
263
264        let local_addr = incoming.local_addr;
265        let peer_addr = incoming.peer_addr;
266
267        #[cfg(feature = "perf-quic-listener-metrics")]
268        let init_rx_time = incoming.rx_time;
269
270        let new_connection = self.incoming_packet_handler.handle_initials(
271            incoming,
272            hdr,
273            self.config.as_mut(),
274        )?;
275
276        match new_connection {
277            Some(new_connection) => self.spawn_new_connection(
278                new_connection,
279                local_addr,
280                peer_addr,
281                #[cfg(feature = "perf-quic-listener-metrics")]
282                init_rx_time,
283            ),
284            None => Ok(()),
285        }
286    }
287
288    /// Creates a new [`QuicConnection`](super::QuicConnection) and spawns an
289    /// associated io worker.
290    fn spawn_new_connection(
291        &mut self, new_connection: NewConnection, local_addr: SocketAddr,
292        peer_addr: SocketAddr,
293        #[cfg(feature = "perf-quic-listener-metrics")] init_rx_time: Option<
294            SystemTime,
295        >,
296    ) -> io::Result<()> {
297        let NewConnection {
298            conn,
299            pending_cid,
300            cid_generator,
301            handshake_start_time,
302            initial_pkt,
303        } = new_connection;
304
305        let Some(ref shutdown_tx) = self.shutdown_tx else {
306            // don't create new connections if we're shutting down.
307            return Ok(());
308        };
309        let Ok(send_permit) = self.accept_sink.try_reserve() else {
310            // drop the connection if the backlog is full. the client will retry.
311            return Err(
312                labels::QuicInvalidInitialPacketError::AcceptQueueOverflow.into(),
313            );
314        };
315
316        let scid = conn.source_id().into_owned();
317        let writer_cfg = WriterConfig {
318            peer_addr,
319            local_addr,
320            pending_cid: pending_cid.clone(),
321            with_gso: self.config.has_gso,
322            pacing_offload: self.config.pacing_offload,
323            with_pktinfo: if self.local_addr.is_ipv4() {
324                self.config.has_ippktinfo
325            } else {
326                self.config.has_ipv6pktinfo
327            },
328        };
329
330        let handshake_info = HandshakeInfo::new(
331            handshake_start_time,
332            self.config.handshake_timeout,
333        );
334
335        let conn = InitialQuicConnection::new(QuicConnectionParams {
336            writer_cfg,
337            initial_pkt,
338            shutdown_tx: shutdown_tx.clone(),
339            conn_map_cmd_tx: self.conn_map_cmd_tx.clone(),
340            scid: scid.clone(),
341            cid_generator,
342            metrics: self.metrics.clone(),
343            #[cfg(feature = "perf-quic-listener-metrics")]
344            init_rx_time,
345            handshake_info,
346            quiche_conn: conn,
347            socket: Arc::clone(&self.socket_tx),
348            local_addr,
349            peer_addr,
350        });
351
352        conn.audit_log_stats
353            .set_transport_handshake_start(instant_to_system(
354                handshake_start_time,
355            ));
356
357        self.conns.insert(&scid, &conn);
358
359        // Add the client-generated "pending" connection ID to the map as well.
360        // This is only required for QUIC servers, because clients can send
361        // Initial packets with arbitrary DCIDs to servers.
362        if let Some(pending_cid) = pending_cid {
363            self.conns.map_cid(&scid, &pending_cid);
364        }
365
366        self.metrics.accepted_initial_packet_count().inc();
367        if self.config.enable_expensive_packet_count_metrics {
368            if let Some(peer_ip) =
369                quic_expensive_metrics_ip_reduce(conn.peer_addr().ip())
370            {
371                self.metrics
372                    .expensive_accepted_initial_packet_count(peer_ip)
373                    .inc();
374            }
375        }
376
377        send_permit.send(Ok(conn));
378        Ok(())
379    }
380}
381
382impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
383where
384    Tx: DatagramSocketSend + Send + Sync + 'static,
385    Rx: DatagramSocketRecv,
386    M: Metrics,
387    I: InitialPacketHandler,
388{
389    /// [`InboundPacketRouter::poll_recv_from`] should be used if the underlying
390    /// system or socket does not support rx_time nor GRO.
391    fn poll_recv_from(
392        &mut self, cx: &mut Context<'_>,
393    ) -> Poll<io::Result<PollRecvData>> {
394        let mut buf = Vec::with_capacity(datagram_socket::MAX_DATAGRAM_SIZE);
395        // We use ReadBuf's ability to write to uninitialized memory to avoid
396        // the cost of having to initialize the Vec.
397        let mut read_buf = tokio::io::ReadBuf::uninit(buf.spare_capacity_mut());
398        let addr = ready!(self.socket_rx.poll_recv_from(cx, &mut read_buf))?;
399        let n = read_buf.filled().len();
400        unsafe {
401            // Safety: ReadBuf has guaranteed that `n` initialized bytes have
402            // been written to the buffer, so we can set the vec's length
403            // accordingly
404            buf.set_len(n);
405        }
406        Poll::Ready(Ok(PollRecvData {
407            buf,
408            src_addr: addr,
409            rx_time: None,
410            gro: None,
411            dst_addr_override: None,
412            #[cfg(target_os = "linux")]
413            so_mark_data: None,
414        }))
415    }
416
417    fn poll_recv_and_rx_time(
418        &mut self, cx: &mut Context<'_>,
419    ) -> Poll<io::Result<PollRecvData>> {
420        #[cfg(not(target_os = "linux"))]
421        {
422            self.poll_recv_from(cx)
423        }
424
425        #[cfg(target_os = "linux")]
426        {
427            use libc::SOL_SOCKET;
428            use libc::SO_MARK;
429            use nix::errno::Errno;
430            use nix::sys::socket::*;
431            use std::net::SocketAddrV4;
432            use std::net::SocketAddrV6;
433            use std::os::fd::AsRawFd;
434            use tokio::io::Interest;
435
436            use crate::buf_factory::BufFactory;
437
438            let Some(udp_socket) = self.socket_rx.as_udp_socket() else {
439                // the given socket is not a UDP socket, fall back to the
440                // simple poll_recv_from.
441                return self.poll_recv_from(cx);
442            };
443
444            // Note, the resize will be a no-op after the first call since
445            // we never truncate the `self.buf`
446            self.buf.resize(BufFactory::MAX_BUF_SIZE, 0u8);
447            loop {
448                let iov_s = &mut [io::IoSliceMut::new(&mut self.buf)];
449                match udp_socket.try_io(Interest::READABLE, || {
450                    recvmsg::<SockaddrStorage>(
451                        udp_socket.as_raw_fd(),
452                        iov_s,
453                        Some(&mut self.reusable_cmsg_space),
454                        MsgFlags::empty(),
455                    )
456                    .map_err(|x| x.into())
457                }) {
458                    Ok(r) => {
459                        let filled_buf =
460                            r.iovs().next().map(Vec::from).unwrap_or_default();
461                        // The slices returend by `nix::socket::recvmsg`'s result
462                        // add up to `r.bytes`. This assert is just to make sure
463                        // the code handles the result correctly.
464                        debug_assert_eq!(r.bytes, filled_buf.len());
465
466                        let address = match r.address {
467                            Some(inner) => inner,
468                            _ => return Poll::Ready(Err(Errno::EINVAL.into())),
469                        };
470
471                        let peer_addr = match address.family() {
472                            Some(AddressFamily::Inet) => SocketAddrV4::from(
473                                *address.as_sockaddr_in().unwrap(),
474                            )
475                            .into(),
476                            Some(AddressFamily::Inet6) => SocketAddrV6::from(
477                                *address.as_sockaddr_in6().unwrap(),
478                            )
479                            .into(),
480                            _ => {
481                                return Poll::Ready(Err(Errno::EINVAL.into()));
482                            },
483                        };
484
485                        let mut rx_time = None;
486                        let mut gro = None;
487                        let mut dst_addr_override = None;
488                        let mut mark_bytes: Option<[u8; 4]> = None;
489
490                        let Ok(cmsgs) = r.cmsgs() else {
491                            // Best-effort if we can't read cmsgs.
492                            return Poll::Ready(Ok(PollRecvData {
493                                buf: filled_buf,
494                                src_addr: peer_addr,
495                                dst_addr_override,
496                                rx_time,
497                                gro,
498                                so_mark_data: mark_bytes,
499                            }));
500                        };
501
502                        for cmsg in cmsgs {
503                            match cmsg {
504                                ControlMessageOwned::RxqOvfl(c) => {
505                                    if c != self.udp_drop_count {
506                                        self.metrics_udp_drop_count.inc_by(
507                                            (c - self.udp_drop_count) as u64,
508                                        );
509                                        self.udp_drop_count = c;
510                                    }
511                                },
512                                ControlMessageOwned::ScmTimestampns(val) => {
513                                    rx_time = SystemTime::UNIX_EPOCH
514                                        .checked_add(val.into());
515                                    if let Some(delta) =
516                                        rx_time.and_then(|rx_time| {
517                                            rx_time.elapsed().ok()
518                                        })
519                                    {
520                                        self.metrics_handshake_time_seconds
521                                            .observe(delta.as_nanos() as u64);
522                                    }
523                                },
524                                ControlMessageOwned::UdpGroSegments(val) =>
525                                    gro = Some(val),
526                                ControlMessageOwned::Ipv4OrigDstAddr(val) => {
527                                    let source_addr = std::net::Ipv4Addr::from(
528                                        u32::to_be(val.sin_addr.s_addr),
529                                    );
530                                    let source_port = u16::to_be(val.sin_port);
531
532                                    let parsed_addr =
533                                        SocketAddr::V4(SocketAddrV4::new(
534                                            source_addr,
535                                            source_port,
536                                        ));
537
538                                    dst_addr_override = resolve_dst_addr(
539                                        &self.local_addr,
540                                        &parsed_addr,
541                                    );
542                                },
543                                ControlMessageOwned::Ipv6OrigDstAddr(val) => {
544                                    // Don't have to flip IPv6 bytes since it's a
545                                    // byte array, not a
546                                    // series of bytes parsed as a u32 as in the
547                                    // IPv4 case
548                                    let source_addr = std::net::Ipv6Addr::from(
549                                        val.sin6_addr.s6_addr,
550                                    );
551                                    let source_port = u16::to_be(val.sin6_port);
552                                    let source_flowinfo =
553                                        u32::to_be(val.sin6_flowinfo);
554                                    let source_scope =
555                                        u32::to_be(val.sin6_scope_id);
556
557                                    let parsed_addr =
558                                        SocketAddr::V6(SocketAddrV6::new(
559                                            source_addr,
560                                            source_port,
561                                            source_flowinfo,
562                                            source_scope,
563                                        ));
564
565                                    dst_addr_override = resolve_dst_addr(
566                                        &self.local_addr,
567                                        &parsed_addr,
568                                    );
569                                },
570                                ControlMessageOwned::Ipv4PacketInfo(_) |
571                                ControlMessageOwned::Ipv6PacketInfo(_) => {
572                                    // We only want the destination address from
573                                    // IP_RECVORIGDSTADDR, but we'll get these
574                                    // messages because we set IP_PKTINFO on the
575                                    // socket.
576                                },
577                                ControlMessageOwned::Unknown(raw_cmsg) => {
578                                    let UnknownCmsg {
579                                        cmsg_header,
580                                        data_bytes,
581                                    } = raw_cmsg;
582
583                                    if cmsg_header.cmsg_level == SOL_SOCKET &&
584                                        cmsg_header.cmsg_type == SO_MARK
585                                    {
586                                        let Ok(arr) =
587                                            <[u8; 4]>::try_from(data_bytes)
588                                        else {
589                                            // Should be unreachable as SO_MARK is
590                                            // a u32: https://elixir.bootlin.com/linux/v6.17/source/include/net/sock.h#L487
591                                            continue;
592                                        };
593
594                                        let _ = mark_bytes.insert(arr);
595                                    }
596                                },
597                                _ => {
598                                    // Unrecognized cmsg received, just ignore
599                                    // it.
600                                },
601                            };
602                        }
603
604                        return Poll::Ready(Ok(PollRecvData {
605                            buf: filled_buf,
606                            src_addr: peer_addr,
607                            dst_addr_override,
608                            rx_time,
609                            gro,
610                            so_mark_data: mark_bytes,
611                        }));
612                    },
613                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
614                        // NOTE: we manually poll the socket here to register
615                        // interest in the socket to become
616                        // writable for the given `cx`. Under the hood, tokio's
617                        // implementation just checks for
618                        // EWOULDBLOCK and if socket is busy registers provided
619                        // waker to be invoked when the
620                        // socket is free and consequently drive the event loop.
621                        ready!(udp_socket.poll_recv_ready(cx))?
622                    },
623                    Err(e) => return Poll::Ready(Err(e)),
624                }
625            }
626        }
627    }
628
629    fn handle_conn_map_commands(&mut self) {
630        while let Ok(req) = self.conn_map_cmd_rx.try_recv() {
631            match req {
632                ConnectionMapCommand::MapCid {
633                    existing_cid,
634                    new_cid,
635                } => self.conns.map_cid(&existing_cid, &new_cid),
636                ConnectionMapCommand::UnmapCid(cid) => self.conns.unmap_cid(&cid),
637            }
638        }
639    }
640}
641
642// Quickly extract the connection id of a short quic packet without allocating
643fn short_dcid(buf: &[u8]) -> Option<ConnectionId<'_>> {
644    let is_short_dcid = buf.first()? >> 7 == 0;
645
646    if is_short_dcid {
647        buf.get(1..1 + MAX_CONN_ID_LEN).map(ConnectionId::from_ref)
648    } else {
649        None
650    }
651}
652
/// Converts an [`Instant`] to a [`SystemTime`], based on the current delta
/// between both clocks.
///
/// Both clocks are sampled once; the distance between `ts` and the sampled
/// `Instant` is then applied to the sampled `SystemTime`, backwards for a
/// `ts` in the past and forwards for a `ts` in the future.
fn instant_to_system(ts: Instant) -> SystemTime {
    let mono_now = Instant::now();
    let wall_now = SystemTime::now();

    match mono_now.checked_duration_since(ts) {
        // `ts` is now or in the past.
        Some(age) => wall_now - age,
        // `ts` lies in the future.
        None => {
            let ahead =
                ts.checked_duration_since(mono_now).expect("now < ts");
            wall_now + ahead
        },
    }
}
665
/// Decides whether a destination address parsed from a
/// [`ControlMessageOwned`](nix::sys::socket::ControlMessageOwned) should be
/// stored for a packet.
///
/// A packet that was originally addressed to `local` must not get an
/// override, as that would cause us to incorrectly address packets when
/// sending.
///
/// Returns `Some(*parsed)` if `parsed` differs from `local`, `None` otherwise.
#[cfg(target_os = "linux")]
fn resolve_dst_addr(
    local: &SocketAddr, parsed: &SocketAddr,
) -> Option<SocketAddr> {
    let differs_from_local = local != parsed;
    differs_from_local.then_some(*parsed)
}
685
686impl<Tx, Rx, M, I> Future for InboundPacketRouter<Tx, Rx, M, I>
687where
688    Tx: DatagramSocketSend + Send + Sync + 'static,
689    Rx: DatagramSocketRecv + Unpin,
690    M: Metrics,
691    I: InitialPacketHandler + Unpin,
692{
693    type Output = io::Result<()>;
694
695    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
696        let server_addr = self.local_addr;
697
698        loop {
699            if let Err(error) = self.incoming_packet_handler.update(cx) {
700                // This is so rare that it's easier to spawn a separate task
701                let sender = self.accept_sink.clone();
702                spawn_with_killswitch(async move {
703                    let _ = sender.send(Err(error)).await;
704                });
705            }
706
707            match self.poll_recv_and_rx_time(cx) {
708                Poll::Ready(Ok(PollRecvData {
709                    buf,
710                    src_addr: peer_addr,
711                    dst_addr_override,
712                    rx_time,
713                    gro,
714                    #[cfg(target_os = "linux")]
715                    so_mark_data,
716                })) => {
717                    let send_from = if let Some(dst_addr) = dst_addr_override {
718                        log::trace!("overriding local address"; "actual_local" => dst_addr, "configured_local" => server_addr);
719                        dst_addr
720                    } else {
721                        server_addr
722                    };
723
724                    let res = self.on_incoming(Incoming {
725                        peer_addr,
726                        local_addr: send_from,
727                        buf,
728                        rx_time,
729                        gro,
730                        #[cfg(target_os = "linux")]
731                        so_mark_data,
732                    });
733
734                    if let Err(e) = res {
735                        let err_type = initial_packet_error_type(&e);
736                        self.metrics
737                            .rejected_initial_packet_count(err_type.clone())
738                            .inc();
739
740                        if self.config.enable_expensive_packet_count_metrics {
741                            if let Some(peer_ip) =
742                                quic_expensive_metrics_ip_reduce(peer_addr.ip())
743                            {
744                                self.metrics
745                                    .expensive_rejected_initial_packet_count(
746                                        err_type.clone(),
747                                        peer_ip,
748                                    )
749                                    .inc();
750                            }
751                        }
752
753                        if matches!(
754                            err_type,
755                            labels::QuicInvalidInitialPacketError::Unexpected
756                        ) {
757                            // don't block packet routing on errors
758                            let _ = self.accept_sink.try_send(Err(e));
759                        }
760                    }
761                },
762
763                Poll::Ready(Err(e)) => {
764                    log::error!("Incoming packet router encountered recvmsg error"; "error" => e);
765                    continue;
766                },
767
768                Poll::Pending => {
769                    // Check whether any connections are still active
770                    if self.shutdown_tx.is_some() && self.accept_sink.is_closed()
771                    {
772                        self.shutdown_tx = None;
773                    }
774
775                    if self.shutdown_rx.poll_recv(cx).is_ready() {
776                        return Poll::Ready(Ok(()));
777                    }
778
779                    // Process any incoming connection map signals and handle them
780                    self.handle_conn_map_commands();
781
782                    return Poll::Pending;
783                },
784            }
785        }
786    }
787}
788
789/// Categorizes errors that are returned when handling packets which are not
790/// associated with an established connection. The purpose is to suppress
791/// logging of 'expected' errors (e.g. junk data sent to the UDP socket) to
792/// prevent DoS.
793fn initial_packet_error_type(
794    e: &io::Error,
795) -> labels::QuicInvalidInitialPacketError {
796    Some(e)
797        .filter(|e| e.kind() == io::ErrorKind::Other)
798        .and_then(io::Error::get_ref)
799        .and_then(|e| e.downcast_ref())
800        .map_or(
801            labels::QuicInvalidInitialPacketError::Unexpected,
802            Clone::clone,
803        )
804}
805
806/// An [`InitialPacketHandler`] handles unknown quic initials and processes
807/// them; generally accepting new connections (acting as a server), or
808/// establishing a connection to a server (acting as a client). An
809/// [`InboundPacketRouter`] holds an instance of this trait and routes
810/// [`Incoming`] packets to it when it receives initials.
811///
812/// The handler produces [`quiche::Connection`]s which are then turned into
813/// [`QuicConnection`](super::QuicConnection), IoWorker pair.
814pub trait InitialPacketHandler {
815    fn update(&mut self, _ctx: &mut Context<'_>) -> io::Result<()> {
816        Ok(())
817    }
818
819    fn handle_initials(
820        &mut self, incoming: Incoming, hdr: Header<'static>,
821        quiche_config: &mut quiche::Config,
822    ) -> io::Result<Option<NewConnection>>;
823}
824
825/// A [`NewConnection`] describes a new [`quiche::Connection`] that can be
826/// driven by an io worker.
827pub struct NewConnection {
828    /// See [`QuicConnectionParams::quiche_conn`].
829    conn: Box<QuicheConnection>,
830    pending_cid: Option<ConnectionId<'static>>,
831    initial_pkt: Option<Incoming>,
832    cid_generator: Option<SharedConnectionIdGenerator>,
833    /// When the handshake started. Should be called before [`quiche::accept`]
834    /// or [`quiche::connect`].
835    handshake_start_time: Instant,
836}
837
// TODO: the router module is private so we can't move these to /tests
// TODO: Rewrite tests to be Windows compatible
#[cfg(all(test, unix))]
mod tests {
    use super::acceptor::ConnectionAcceptor;
    use super::acceptor::ConnectionAcceptorConfig;
    use super::*;

    use crate::http3::settings::Http3Settings;
    use crate::metrics::DefaultMetrics;
    use crate::quic::connection::SimpleConnectionIdGenerator;
    use crate::settings::Config;
    use crate::settings::Hooks;
    use crate::settings::QuicSettings;
    use crate::settings::TlsCertificatePaths;
    use crate::socket::SocketCapabilities;
    use crate::ConnectionParams;
    use crate::ServerH3Driver;

    use datagram_socket::MAX_DATAGRAM_SIZE;
    use h3i::actions::h3::Action;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::net::UdpSocket;
    use tokio::time;

    /// Path to the test certificate, resolved relative to this crate's
    /// manifest directory.
    const TEST_CERT_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.crt"
    );
    /// Path to the private key matching [`TEST_CERT_FILE`].
    const TEST_KEY_FILE: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/",
        "../quiche/examples/cert.key"
    );

    /// Connects an h3i client to `host_port` and has it send a single
    /// application-level connection close with a `NoError` code.
    ///
    /// This call is synchronous, so the test runs it on a dedicated thread.
    /// The outcome of the client run is deliberately ignored: the test only
    /// cares about what happens on the server side.
    fn test_connect(host_port: String) {
        let h3i_config = h3i::config::Config::new()
            .with_host_port("test.com".to_string())
            .with_idle_timeout(2000)
            .with_connect_to(host_port)
            .verify_peer(false)
            .build()
            .unwrap();

        let conn_close = h3i::quiche::ConnectionError {
            is_app: true,
            error_code: h3i::quiche::WireErrorCode::NoError as _,
            reason: Vec::new(),
        };
        let actions = vec![Action::ConnectionClose { error: conn_close }];

        let _ = h3i::client::sync_client::connect(h3i_config, actions, None);
    }

    /// Smoke test: a connection accepted through the router must eventually
    /// be dropped, which is observed through the closing of its
    /// `incoming_ev_sender` channel.
    #[tokio::test]
    async fn test_timeout() {
        // Real-time upper bound for the connection to be torn down. Without
        // it a regression would hang the whole test run instead of failing.
        const DROP_DEADLINE: Duration = Duration::from_secs(30);

        // Configure a short idle timeout to speed up connection reclamation as
        // quiche doesn't support time mocking
        let quic_settings = QuicSettings {
            max_idle_timeout: Some(Duration::from_millis(1)),
            max_recv_udp_payload_size: MAX_DATAGRAM_SIZE,
            max_send_udp_payload_size: MAX_DATAGRAM_SIZE,
            ..Default::default()
        };

        let tls_cert_settings = TlsCertificatePaths {
            cert: TEST_CERT_FILE,
            private_key: TEST_KEY_FILE,
            kind: crate::settings::CertificateKind::X509,
        };

        let params = ConnectionParams::new_server(
            quic_settings,
            tls_cert_settings,
            Hooks::default(),
        );
        let config = Config::new(&params, SocketCapabilities::default()).unwrap();

        // Bind to an ephemeral port; the same socket is shared between the
        // sending and the receiving side.
        let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let local_addr = socket.local_addr().unwrap();
        let host_port = local_addr.to_string();
        let socket_tx = Arc::new(socket);
        let socket_rx = Arc::clone(&socket_tx);

        let acceptor = ConnectionAcceptor::new(
            ConnectionAcceptorConfig {
                disable_client_ip_validation: config.disable_client_ip_validation,
                qlog_dir: config.qlog_dir.clone(),
                qlog_compression: config.qlog_compression,
                keylog_file: config
                    .keylog_file
                    .as_ref()
                    .and_then(|f| f.try_clone().ok()),
                #[cfg(target_os = "linux")]
                with_pktinfo: false,
            },
            Arc::clone(&socket_tx),
            Default::default(),
            Arc::new(SimpleConnectionIdGenerator),
            DefaultMetrics,
        );

        let (socket_driver, mut incoming) = InboundPacketRouter::new(
            config,
            socket_tx,
            socket_rx,
            local_addr,
            acceptor,
            DefaultMetrics,
        );
        tokio::spawn(socket_driver);

        // Start a request and drop it after connection establishment
        std::thread::spawn(move || test_connect(host_port));

        // Wait for a new connection
        time::pause();

        let (h3_driver, _) = ServerH3Driver::new(Http3Settings::default());
        let conn = incoming.recv().await.unwrap().unwrap();
        let drop_check = conn.incoming_ev_sender.clone();
        let _conn = conn.start(h3_driver);

        // Poll the incoming until the connection is dropped
        time::advance(Duration::new(30, 0)).await;
        time::resume();

        // NOTE: this is a smoke test - `closed()` only resolves once the
        // connection has been dropped. The wait is bounded so that a
        // regression fails the test instead of hanging it.
        time::timeout(DROP_DEADLINE, drop_check.closed())
            .await
            .expect("connection was not dropped within the deadline");
    }
}