tokio_quiche/quic/router/
mod.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27pub(crate) mod acceptor;
28pub(crate) mod connector;
29
30use super::connection::ConnectionMap;
31use super::connection::HandshakeInfo;
32use super::connection::Incoming;
33use super::connection::InitialQuicConnection;
34use super::connection::QuicConnectionParams;
35use super::io::worker::WriterConfig;
36use super::QuicheConnection;
37use crate::buf_factory::BufFactory;
38use crate::buf_factory::PooledBuf;
39use crate::metrics::labels;
40use crate::metrics::quic_expensive_metrics_ip_reduce;
41use crate::metrics::Metrics;
42use crate::settings::Config;
43use datagram_socket::DatagramSocketRecv;
44use datagram_socket::DatagramSocketSend;
45use foundations::telemetry::log;
46use quiche::ConnectionId;
47use quiche::Header;
48use quiche::MAX_CONN_ID_LEN;
49use std::default::Default;
50use std::future::Future;
51use std::io;
52use std::net::SocketAddr;
53use std::pin::Pin;
54use std::sync::Arc;
55use std::task::ready;
56use std::task::Context;
57use std::task::Poll;
58use std::time::Instant;
59use std::time::SystemTime;
60use task_killswitch::spawn_with_killswitch;
61use tokio::sync::mpsc;
62
63#[cfg(target_os = "linux")]
64use foundations::telemetry::metrics::Counter;
65#[cfg(target_os = "linux")]
66use foundations::telemetry::metrics::TimeHistogram;
67#[cfg(target_os = "linux")]
68use libc::sockaddr_in;
69#[cfg(target_os = "linux")]
70use libc::sockaddr_in6;
71
/// Receiving end of the router's accept queue. Each item is either a newly
/// created [`InitialQuicConnection`] or an [`io::Error`] forwarded by the
/// router.
type ConnStream<Tx, M> = mpsc::Receiver<io::Result<InitialQuicConnection<Tx, M>>>;
73
#[cfg(feature = "perf-quic-listener-metrics")]
mod listener_stage_timer {
    use foundations::telemetry::metrics::TimeHistogram;
    use std::time::Instant;

    /// Drop guard measuring the duration of a listener stage.
    ///
    /// When the guard is dropped, the time elapsed since `start` is recorded
    /// in `time_hist`, expressed in nanoseconds.
    pub(super) struct ListenerStageTimer {
        start: Instant,
        time_hist: TimeHistogram,
    }

    impl ListenerStageTimer {
        /// Creates a guard that measures from `start` and reports to
        /// `time_hist` on drop.
        pub(super) fn new(start: Instant, time_hist: TimeHistogram) -> Self {
            Self { start, time_hist }
        }
    }

    impl Drop for ListenerStageTimer {
        fn drop(&mut self) {
            let elapsed = self.start.elapsed();
            self.time_hist.observe(elapsed.as_nanos() as u64);
        }
    }
}
99
/// Result of a single successful receive on the router's socket.
#[derive(Debug)]
struct PollRecvData {
    // Number of bytes of the received datagram.
    bytes: usize,
    // The packet's source, e.g., the peer's address
    src_addr: SocketAddr,
    // The packet's original destination. This is only `Some` when an
    // original-destination control message was received and the address it
    // carried differs from the local listening address; otherwise it is
    // `None`.
    dst_addr_override: Option<SocketAddr>,
    // Receive timestamp taken from the timestamp control message, if any.
    rx_time: Option<SystemTime>,
    // Value of the UDP GRO segments control message, if any.
    gro: Option<i32>,
    // Raw payload of the `SO_MARK` control message, if any.
    #[cfg(target_os = "linux")]
    so_mark_data: Option<[u8; 4]>,
}
113
/// A message to the listener notifying it that a mapping for a connection
/// should be removed.
pub enum ConnectionMapCommand {
    /// Asks the router to unmap the given connection ID from its connection
    /// map.
    UnmapCid(ConnectionId<'static>),
    /// Asks the router to remove the entry registered under the given source
    /// connection ID from its connection map.
    RemoveScid(ConnectionId<'static>),
}
120
/// An `InboundPacketRouter` maintains a map of quic connections and routes
/// [`Incoming`] packets from the [recv half][rh] of a datagram socket to those
/// connections or some quic initials handler.
///
/// [rh]: datagram_socket::DatagramSocketRecv
///
/// When a packet (or batch of packets) is received, the router will either
/// route those packets to an established
/// [`QuicConnection`](super::QuicConnection) or have them handled by an
/// `InitialPacketHandler` which either acts as a quic listener or
/// quic connector, a server or client respectively.
///
/// If you only have a single connection, or if you need more control over the
/// socket, use `QuicConnection` directly instead.
pub struct InboundPacketRouter<Tx, Rx, M, I>
where
    Tx: DatagramSocketSend + Send + 'static,
    M: Metrics,
{
    // Send half of the socket; a clone of the `Arc` is handed to every new
    // connection.
    socket_tx: Arc<Tx>,
    // Receive half of the socket, polled by the router for packets.
    socket_rx: Rx,
    // Local address used for packets that carry no destination override.
    local_addr: SocketAddr,
    config: Config,
    // Connection ID -> connection lookup used to route packets.
    conns: ConnectionMap,
    // Handles packets that do not belong to any known connection.
    incoming_packet_handler: I,
    // `Some` while new connections are accepted; a clone is given to every
    // new connection. Reset to `None` once the accept stream is closed.
    shutdown_tx: Option<mpsc::Sender<()>>,
    // The router future completes once polling this receiver is ready.
    shutdown_rx: mpsc::Receiver<()>,
    // Sender cloned into every new connection so it can request changes to
    // `conns`.
    conn_map_cmd_tx: mpsc::UnboundedSender<ConnectionMapCommand>,
    // Commands are drained whenever the socket has no more packets ready.
    conn_map_cmd_rx: mpsc::UnboundedReceiver<ConnectionMapCommand>,
    // Sending end of the accept queue (new connections and errors).
    accept_sink: mpsc::Sender<io::Result<InitialQuicConnection<Tx, M>>>,
    metrics: M,
    // Last value seen in the socket's receive-queue overflow control message.
    #[cfg(target_os = "linux")]
    udp_drop_count: u32,

    // Control message buffer reused for every `recvmsg` call.
    #[cfg(target_os = "linux")]
    reusable_cmsg_space: Vec<u8>,

    // Buffer the next datagram is received into; it is swapped for a fresh
    // buffer after each received packet.
    current_buf: PooledBuf,

    // We keep the metrics in here, to avoid cloning them each packet
    #[cfg(target_os = "linux")]
    metrics_handshake_time_seconds: TimeHistogram,
    #[cfg(target_os = "linux")]
    metrics_udp_drop_count: Counter,
}
166
167impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
168where
169    Tx: DatagramSocketSend + Send + 'static,
170    Rx: DatagramSocketRecv,
171    M: Metrics,
172    I: InitialPacketHandler,
173{
174    pub(crate) fn new(
175        config: Config, socket_tx: Arc<Tx>, socket_rx: Rx,
176        local_addr: SocketAddr, incoming_packet_handler: I, metrics: M,
177    ) -> (Self, ConnStream<Tx, M>) {
178        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
179        let (accept_sink, accept_stream) = mpsc::channel(config.listen_backlog);
180        let (conn_map_cmd_tx, conn_map_cmd_rx) = mpsc::unbounded_channel();
181
182        (
183            InboundPacketRouter {
184                local_addr,
185                socket_tx,
186                socket_rx,
187                conns: ConnectionMap::default(),
188                incoming_packet_handler,
189                shutdown_tx: Some(shutdown_tx),
190                shutdown_rx,
191                conn_map_cmd_tx,
192                conn_map_cmd_rx,
193                accept_sink,
194                #[cfg(target_os = "linux")]
195                udp_drop_count: 0,
196                #[cfg(target_os = "linux")]
197                // Specify CMSG space. Even if they're not all currently used, the cmsg buffer may
198                // have been configured by a previous version of Tokio-Quiche with the socket
199                // re-used on graceful restart. As such, this vector should _only grow_, and care
200                // should be taken when adding new cmsgs.
201                reusable_cmsg_space: nix::cmsg_space!(
202                    u32, // GRO
203                    nix::sys::time::TimeSpec, // timestamp
204                    u16, // drop count
205                    sockaddr_in, // IP_RECVORIGDSTADDR
206                    sockaddr_in6, // IPV6_RECVORIGDSTADDR
207                    u32 // SO_MARK
208                ),
209                config,
210
211                current_buf: BufFactory::get_max_buf(),
212
213                #[cfg(target_os = "linux")]
214                metrics_handshake_time_seconds: metrics.handshake_time_seconds(labels::QuicHandshakeStage::QueueWaiting),
215                #[cfg(target_os = "linux")]
216                metrics_udp_drop_count: metrics.udp_drop_count(),
217
218                metrics,
219
220            },
221            accept_stream,
222        )
223    }
224
225    fn on_incoming(&mut self, mut incoming: Incoming) -> io::Result<()> {
226        #[cfg(feature = "perf-quic-listener-metrics")]
227        let start = std::time::Instant::now();
228
229        if let Some(dcid) = short_dcid(&incoming.buf) {
230            if let Some(ev_sender) = self.conns.get(&dcid) {
231                let _ = ev_sender.try_send(incoming);
232                return Ok(());
233            }
234        }
235
236        let hdr = Header::from_slice(&mut incoming.buf, MAX_CONN_ID_LEN)
237            .map_err(|e| match e {
238                quiche::Error::BufferTooShort | quiche::Error::InvalidPacket =>
239                    labels::QuicInvalidInitialPacketError::FailedToParse.into(),
240                e => io::Error::other(e),
241            })?;
242
243        if let Some(ev_sender) = self.conns.get(&hdr.dcid) {
244            let _ = ev_sender.try_send(incoming);
245            return Ok(());
246        }
247
248        #[cfg(feature = "perf-quic-listener-metrics")]
249        let _timer = listener_stage_timer::ListenerStageTimer::new(
250            start,
251            self.metrics.handshake_time_seconds(
252                labels::QuicHandshakeStage::HandshakeProtocol,
253            ),
254        );
255
256        if self.shutdown_tx.is_none() {
257            return Ok(());
258        }
259
260        let local_addr = incoming.local_addr;
261        let peer_addr = incoming.peer_addr;
262
263        #[cfg(feature = "perf-quic-listener-metrics")]
264        let init_rx_time = incoming.rx_time;
265
266        let new_connection = self.incoming_packet_handler.handle_initials(
267            incoming,
268            hdr,
269            self.config.as_mut(),
270        )?;
271
272        match new_connection {
273            Some(new_connection) => self.spawn_new_connection(
274                new_connection,
275                local_addr,
276                peer_addr,
277                #[cfg(feature = "perf-quic-listener-metrics")]
278                init_rx_time,
279            ),
280            None => Ok(()),
281        }
282    }
283
284    /// Creates a new [`QuicConnection`](super::QuicConnection) and spawns an
285    /// associated io worker.
286    fn spawn_new_connection(
287        &mut self, new_connection: NewConnection, local_addr: SocketAddr,
288        peer_addr: SocketAddr,
289        #[cfg(feature = "perf-quic-listener-metrics")] init_rx_time: Option<
290            SystemTime,
291        >,
292    ) -> io::Result<()> {
293        let NewConnection {
294            conn,
295            pending_cid,
296            handshake_start_time,
297            initial_pkt,
298        } = new_connection;
299
300        let Some(ref shutdown_tx) = self.shutdown_tx else {
301            // don't create new connections if we're shutting down.
302            return Ok(());
303        };
304        let Ok(send_permit) = self.accept_sink.try_reserve() else {
305            // drop the connection if the backlog is full. the client will retry.
306            return Err(
307                labels::QuicInvalidInitialPacketError::AcceptQueueOverflow.into(),
308            );
309        };
310
311        let scid = conn.source_id().into_owned();
312        let writer_cfg = WriterConfig {
313            peer_addr,
314            local_addr,
315            pending_cid: pending_cid.clone(),
316            with_gso: self.config.has_gso,
317            pacing_offload: self.config.pacing_offload,
318            with_pktinfo: if self.local_addr.is_ipv4() {
319                self.config.has_ippktinfo
320            } else {
321                self.config.has_ipv6pktinfo
322            },
323        };
324
325        let handshake_info = HandshakeInfo::new(
326            handshake_start_time,
327            self.config.handshake_timeout,
328        );
329
330        let conn = InitialQuicConnection::new(QuicConnectionParams {
331            writer_cfg,
332            initial_pkt,
333            shutdown_tx: shutdown_tx.clone(),
334            conn_map_cmd_tx: self.conn_map_cmd_tx.clone(),
335            scid: scid.clone(),
336            metrics: self.metrics.clone(),
337            #[cfg(feature = "perf-quic-listener-metrics")]
338            init_rx_time,
339            handshake_info,
340            quiche_conn: conn,
341            socket: Arc::clone(&self.socket_tx),
342            local_addr,
343            peer_addr,
344        });
345
346        conn.audit_log_stats
347            .set_transport_handshake_start(instant_to_system(
348                handshake_start_time,
349            ));
350
351        self.conns.insert(scid, &conn);
352
353        // Add the client-generated "pending" connection ID to the map as well.
354        //
355        // This is only required when client address validation is disabled.
356        // When validation is enabled, the client is already using the
357        // server-generated connection ID by the time we get here.
358        if let Some(pending_cid) = pending_cid {
359            self.conns.map_cid(pending_cid, &conn);
360        }
361
362        self.metrics.accepted_initial_packet_count().inc();
363        if self.config.enable_expensive_packet_count_metrics {
364            if let Some(peer_ip) =
365                quic_expensive_metrics_ip_reduce(conn.peer_addr().ip())
366            {
367                self.metrics
368                    .expensive_accepted_initial_packet_count(peer_ip)
369                    .inc();
370            }
371        }
372
373        send_permit.send(Ok(conn));
374        Ok(())
375    }
376}
377
378impl<Tx, Rx, M, I> InboundPacketRouter<Tx, Rx, M, I>
379where
380    Tx: DatagramSocketSend + Send + Sync + 'static,
381    Rx: DatagramSocketRecv,
382    M: Metrics,
383    I: InitialPacketHandler,
384{
385    /// [`InboundPacketRouter::poll_recv_from`] should be used if the underlying
386    /// system or socket does not support rx_time nor GRO.
387    fn poll_recv_from(
388        &mut self, cx: &mut Context<'_>,
389    ) -> Poll<io::Result<PollRecvData>> {
390        let mut buf = tokio::io::ReadBuf::new(&mut self.current_buf);
391        let addr = ready!(self.socket_rx.poll_recv_from(cx, &mut buf))?;
392        Poll::Ready(Ok(PollRecvData {
393            bytes: buf.filled().len(),
394            src_addr: addr,
395            rx_time: None,
396            gro: None,
397            dst_addr_override: None,
398            #[cfg(target_os = "linux")]
399            so_mark_data: None,
400        }))
401    }
402
403    fn poll_recv_and_rx_time(
404        &mut self, cx: &mut Context<'_>,
405    ) -> Poll<io::Result<PollRecvData>> {
406        #[cfg(not(target_os = "linux"))]
407        {
408            self.poll_recv_from(cx)
409        }
410
411        #[cfg(target_os = "linux")]
412        {
413            use libc::SOL_SOCKET;
414            use libc::SO_MARK;
415            use nix::errno::Errno;
416            use nix::sys::socket::*;
417            use std::net::SocketAddrV4;
418            use std::net::SocketAddrV6;
419            use std::os::fd::AsRawFd;
420            use tokio::io::Interest;
421
422            let Some(udp_socket) = self.socket_rx.as_udp_socket() else {
423                // the given socket is not a UDP socket, fall back to the
424                // simple poll_recv_from.
425                return self.poll_recv_from(cx);
426            };
427
428            loop {
429                let iov_s = &mut [io::IoSliceMut::new(&mut self.current_buf)];
430                match udp_socket.try_io(Interest::READABLE, || {
431                    recvmsg::<SockaddrStorage>(
432                        udp_socket.as_raw_fd(),
433                        iov_s,
434                        Some(&mut self.reusable_cmsg_space),
435                        MsgFlags::empty(),
436                    )
437                    .map_err(|x| x.into())
438                }) {
439                    Ok(r) => {
440                        let bytes = r.bytes;
441
442                        let address = match r.address {
443                            Some(inner) => inner,
444                            _ => return Poll::Ready(Err(Errno::EINVAL.into())),
445                        };
446
447                        let peer_addr = match address.family() {
448                            Some(AddressFamily::Inet) => SocketAddrV4::from(
449                                *address.as_sockaddr_in().unwrap(),
450                            )
451                            .into(),
452                            Some(AddressFamily::Inet6) => SocketAddrV6::from(
453                                *address.as_sockaddr_in6().unwrap(),
454                            )
455                            .into(),
456                            _ => {
457                                return Poll::Ready(Err(Errno::EINVAL.into()));
458                            },
459                        };
460
461                        let mut rx_time = None;
462                        let mut gro = None;
463                        let mut dst_addr_override = None;
464                        let mut mark_bytes: Option<[u8; 4]> = None;
465
466                        let Ok(cmsgs) = r.cmsgs() else {
467                            // Best-effort if we can't read cmsgs.
468                            return Poll::Ready(Ok(PollRecvData {
469                                bytes,
470                                src_addr: peer_addr,
471                                dst_addr_override,
472                                rx_time,
473                                gro,
474                                so_mark_data: mark_bytes,
475                            }));
476                        };
477
478                        for cmsg in cmsgs {
479                            match cmsg {
480                                ControlMessageOwned::RxqOvfl(c) => {
481                                    if c != self.udp_drop_count {
482                                        self.metrics_udp_drop_count.inc_by(
483                                            (c - self.udp_drop_count) as u64,
484                                        );
485                                        self.udp_drop_count = c;
486                                    }
487                                },
488                                ControlMessageOwned::ScmTimestampns(val) => {
489                                    rx_time = SystemTime::UNIX_EPOCH
490                                        .checked_add(val.into());
491                                    if let Some(delta) =
492                                        rx_time.and_then(|rx_time| {
493                                            rx_time.elapsed().ok()
494                                        })
495                                    {
496                                        self.metrics_handshake_time_seconds
497                                            .observe(delta.as_nanos() as u64);
498                                    }
499                                },
500                                ControlMessageOwned::UdpGroSegments(val) =>
501                                    gro = Some(val),
502                                ControlMessageOwned::Ipv4OrigDstAddr(val) => {
503                                    let source_addr = std::net::Ipv4Addr::from(
504                                        u32::to_be(val.sin_addr.s_addr),
505                                    );
506                                    let source_port = u16::to_be(val.sin_port);
507
508                                    let parsed_addr =
509                                        SocketAddr::V4(SocketAddrV4::new(
510                                            source_addr,
511                                            source_port,
512                                        ));
513
514                                    dst_addr_override = resolve_dst_addr(
515                                        &self.local_addr,
516                                        &parsed_addr,
517                                    );
518                                },
519                                ControlMessageOwned::Ipv6OrigDstAddr(val) => {
520                                    // Don't have to flip IPv6 bytes since it's a
521                                    // byte array, not a
522                                    // series of bytes parsed as a u32 as in the
523                                    // IPv4 case
524                                    let source_addr = std::net::Ipv6Addr::from(
525                                        val.sin6_addr.s6_addr,
526                                    );
527                                    let source_port = u16::to_be(val.sin6_port);
528                                    let source_flowinfo =
529                                        u32::to_be(val.sin6_flowinfo);
530                                    let source_scope =
531                                        u32::to_be(val.sin6_scope_id);
532
533                                    let parsed_addr =
534                                        SocketAddr::V6(SocketAddrV6::new(
535                                            source_addr,
536                                            source_port,
537                                            source_flowinfo,
538                                            source_scope,
539                                        ));
540
541                                    dst_addr_override = resolve_dst_addr(
542                                        &self.local_addr,
543                                        &parsed_addr,
544                                    );
545                                },
546                                ControlMessageOwned::Ipv4PacketInfo(_) |
547                                ControlMessageOwned::Ipv6PacketInfo(_) => {
548                                    // We only want the destination address from
549                                    // IP_RECVORIGDSTADDR, but we'll get these
550                                    // messages because we set IP_PKTINFO on the
551                                    // socket.
552                                },
553                                ControlMessageOwned::Unknown(raw_cmsg) => {
554                                    let UnknownCmsg {
555                                        cmsg_header,
556                                        data_bytes,
557                                    } = raw_cmsg;
558
559                                    if cmsg_header.cmsg_level == SOL_SOCKET &&
560                                        cmsg_header.cmsg_type == SO_MARK
561                                    {
562                                        let Ok(arr) =
563                                            <[u8; 4]>::try_from(data_bytes)
564                                        else {
565                                            // Should be unreachable as SO_MARK is
566                                            // a u32: https://elixir.bootlin.com/linux/v6.17/source/include/net/sock.h#L487
567                                            continue;
568                                        };
569
570                                        let _ = mark_bytes.insert(arr);
571                                    }
572                                },
573                                _ => {
574                                    // Unrecognized cmsg received, just ignore
575                                    // it.
576                                },
577                            };
578                        }
579
580                        return Poll::Ready(Ok(PollRecvData {
581                            bytes,
582                            src_addr: peer_addr,
583                            dst_addr_override,
584                            rx_time,
585                            gro,
586                            so_mark_data: mark_bytes,
587                        }));
588                    },
589                    Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
590                        // NOTE: we manually poll the socket here to register
591                        // interest in the socket to become
592                        // writable for the given `cx`. Under the hood, tokio's
593                        // implementation just checks for
594                        // EWOULDBLOCK and if socket is busy registers provided
595                        // waker to be invoked when the
596                        // socket is free and consequently drive the event loop.
597                        ready!(udp_socket.poll_recv_ready(cx))?
598                    },
599                    Err(e) => return Poll::Ready(Err(e)),
600                }
601            }
602        }
603    }
604
605    fn handle_conn_map_commands(&mut self) {
606        while let Ok(req) = self.conn_map_cmd_rx.try_recv() {
607            match req {
608                ConnectionMapCommand::UnmapCid(cid) => self.conns.unmap_cid(&cid),
609                ConnectionMapCommand::RemoveScid(scid) =>
610                    self.conns.remove(&scid),
611            }
612        }
613    }
614}
615
616// Quickly extract the connection id of a short quic packet without allocating
617fn short_dcid(buf: &[u8]) -> Option<ConnectionId<'_>> {
618    let is_short_dcid = buf.first()? >> 7 == 0;
619
620    if is_short_dcid {
621        buf.get(1..1 + MAX_CONN_ID_LEN).map(ConnectionId::from_ref)
622    } else {
623        None
624    }
625}
626
/// Maps an [`Instant`] onto the [`SystemTime`] clock.
///
/// Both clocks are sampled now, and the distance between `ts` and the current
/// instant is applied to the current system time: backwards if `ts` lies in
/// the past, forwards if it lies in the future.
fn instant_to_system(ts: Instant) -> SystemTime {
    let now = Instant::now();
    let system_now = SystemTime::now();

    match now.checked_duration_since(ts) {
        // `ts` is not later than `now`.
        Some(age) => system_now - age,
        // `ts` is later than `now`.
        None => {
            let ahead = ts.checked_duration_since(now).expect("now < ts");
            system_now + ahead
        },
    }
}
639
/// Decides whether the destination address parsed from a
/// [`ControlMessageOwned`](nix::sys::socket::ControlMessageOwned) should be
/// stored for a packet.
///
/// A packet that was originally addressed to `local` must not get an
/// override, as that would cause us to incorrectly address packets when
/// sending.
///
/// Returns `Some(*parsed)` if `parsed` differs from `local`, and `None`
/// otherwise.
#[cfg(target_os = "linux")]
fn resolve_dst_addr(
    local: &SocketAddr, parsed: &SocketAddr,
) -> Option<SocketAddr> {
    (local != parsed).then_some(*parsed)
}
659
660impl<Tx, Rx, M, I> Future for InboundPacketRouter<Tx, Rx, M, I>
661where
662    Tx: DatagramSocketSend + Send + Sync + 'static,
663    Rx: DatagramSocketRecv + Unpin,
664    M: Metrics,
665    I: InitialPacketHandler + Unpin,
666{
667    type Output = io::Result<()>;
668
669    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
670        let server_addr = self.local_addr;
671
672        loop {
673            if let Err(error) = self.incoming_packet_handler.update(cx) {
674                // This is so rare that it's easier to spawn a separate task
675                let sender = self.accept_sink.clone();
676                spawn_with_killswitch(async move {
677                    let _ = sender.send(Err(error)).await;
678                });
679            }
680
681            match self.poll_recv_and_rx_time(cx) {
682                Poll::Ready(Ok(PollRecvData {
683                    bytes,
684                    src_addr: peer_addr,
685                    dst_addr_override,
686                    rx_time,
687                    gro,
688                    #[cfg(target_os = "linux")]
689                    so_mark_data,
690                })) => {
691                    let mut buf = std::mem::replace(
692                        &mut self.current_buf,
693                        BufFactory::get_max_buf(),
694                    );
695                    buf.truncate(bytes);
696
697                    let send_from = if let Some(dst_addr) = dst_addr_override {
698                        log::trace!("overriding local address"; "actual_local" => dst_addr, "configured_local" => server_addr);
699                        dst_addr
700                    } else {
701                        server_addr
702                    };
703
704                    let res = self.on_incoming(Incoming {
705                        peer_addr,
706                        local_addr: send_from,
707                        buf,
708                        rx_time,
709                        gro,
710                        #[cfg(target_os = "linux")]
711                        so_mark_data,
712                    });
713
714                    if let Err(e) = res {
715                        let err_type = initial_packet_error_type(&e);
716                        self.metrics
717                            .rejected_initial_packet_count(err_type.clone())
718                            .inc();
719
720                        if self.config.enable_expensive_packet_count_metrics {
721                            if let Some(peer_ip) =
722                                quic_expensive_metrics_ip_reduce(peer_addr.ip())
723                            {
724                                self.metrics
725                                    .expensive_rejected_initial_packet_count(
726                                        err_type.clone(),
727                                        peer_ip,
728                                    )
729                                    .inc();
730                            }
731                        }
732
733                        if matches!(
734                            err_type,
735                            labels::QuicInvalidInitialPacketError::Unexpected
736                        ) {
737                            // don't block packet routing on errors
738                            let _ = self.accept_sink.try_send(Err(e));
739                        }
740                    }
741                },
742
743                Poll::Ready(Err(e)) => {
744                    log::error!("Incoming packet router encountered recvmsg error"; "error" => e);
745                    continue;
746                },
747
748                Poll::Pending => {
749                    // Check whether any connections are still active
750                    if self.shutdown_tx.is_some() && self.accept_sink.is_closed()
751                    {
752                        self.shutdown_tx = None;
753                    }
754
755                    if self.shutdown_rx.poll_recv(cx).is_ready() {
756                        return Poll::Ready(Ok(()));
757                    }
758
759                    // Process any incoming connection map signals and handle them
760                    self.handle_conn_map_commands();
761
762                    return Poll::Pending;
763                },
764            }
765        }
766    }
767}
768
769/// Categorizes errors that are returned when handling packets which are not
770/// associated with an established connection. The purpose is to suppress
771/// logging of 'expected' errors (e.g. junk data sent to the UDP socket) to
772/// prevent DoS.
773fn initial_packet_error_type(
774    e: &io::Error,
775) -> labels::QuicInvalidInitialPacketError {
776    Some(e)
777        .filter(|e| e.kind() == io::ErrorKind::Other)
778        .and_then(io::Error::get_ref)
779        .and_then(|e| e.downcast_ref())
780        .map_or(
781            labels::QuicInvalidInitialPacketError::Unexpected,
782            Clone::clone,
783        )
784}
785
786/// An [`InitialPacketHandler`] handles unknown quic initials and processes
787/// them; generally accepting new connections (acting as a server), or
788/// establishing a connection to a server (acting as a client). An
789/// [`InboundPacketRouter`] holds an instance of this trait and routes
790/// [`Incoming`] packets to it when it receives initials.
791///
792/// The handler produces [`quiche::Connection`]s which are then turned into
793/// [`QuicConnection`](super::QuicConnection), IoWorker pair.
794pub trait InitialPacketHandler {
795    fn update(&mut self, _ctx: &mut Context<'_>) -> io::Result<()> {
796        Ok(())
797    }
798
799    fn handle_initials(
800        &mut self, incoming: Incoming, hdr: Header<'static>,
801        quiche_config: &mut quiche::Config,
802    ) -> io::Result<Option<NewConnection>>;
803}
804
805/// A [`NewConnection`] describes a new [`quiche::Connection`] that can be
806/// driven by an io worker.
807pub struct NewConnection {
808    conn: QuicheConnection,
809    pending_cid: Option<ConnectionId<'static>>,
810    initial_pkt: Option<Incoming>,
811    /// When the handshake started. Should be called before [`quiche::accept`]
812    /// or [`quiche::connect`].
813    handshake_start_time: Instant,
814}
815
// TODO: the router module is private, so these tests can't move to /tests
// TODO: rewrite the tests to be Windows compatible
#[cfg(all(test, unix))]
mod tests {
    use super::acceptor::ConnectionAcceptor;
    use super::acceptor::ConnectionAcceptorConfig;
    use super::*;

    use crate::http3::settings::Http3Settings;
    use crate::metrics::DefaultMetrics;
    use crate::quic::connection::SimpleConnectionIdGenerator;
    use crate::settings::Config;
    use crate::settings::Hooks;
    use crate::settings::QuicSettings;
    use crate::settings::TlsCertificatePaths;
    use crate::socket::SocketCapabilities;
    use crate::ConnectionParams;
    use crate::ServerH3Driver;

    use datagram_socket::MAX_DATAGRAM_SIZE;
    use h3i::actions::h3::Action;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::net::UdpSocket;
    use tokio::time;

    /// Certificate used by the test server, resolved relative to this crate.
    const TEST_CERT_FILE: &str =
        concat!(env!("CARGO_MANIFEST_DIR"), "/../quiche/examples/cert.crt");

    /// Private key matching [`TEST_CERT_FILE`].
    const TEST_KEY_FILE: &str =
        concat!(env!("CARGO_MANIFEST_DIR"), "/../quiche/examples/cert.key");

    /// Runs the synchronous h3i client against `host_port` with a single
    /// action: an application-level connection close carrying `NoError`.
    ///
    /// The client's result is discarded on purpose.
    fn test_connect(host_port: String) {
        let client_config = h3i::config::Config::new()
            .with_host_port("test.com".to_string())
            .with_idle_timeout(2000)
            .with_connect_to(host_port)
            .verify_peer(false)
            .build()
            .unwrap();

        let close_without_error = h3i::quiche::ConnectionError {
            is_app: true,
            error_code: h3i::quiche::WireErrorCode::NoError as _,
            reason: Vec::new(),
        };
        let client_actions = vec![Action::ConnectionClose {
            error: close_without_error,
        }];

        let _ =
            h3i::client::sync_client::connect(client_config, client_actions, None);
    }

    /// Smoke test: a connection accepted through the router is eventually
    /// torn down, which is observed through its event sender closing.
    #[tokio::test]
    async fn test_timeout() {
        // quiche doesn't support time mocking, so use a very short idle
        // timeout to make connection reclamation fast.
        let server_quic_settings = QuicSettings {
            max_idle_timeout: Some(Duration::from_millis(1)),
            max_recv_udp_payload_size: MAX_DATAGRAM_SIZE,
            max_send_udp_payload_size: MAX_DATAGRAM_SIZE,
            ..Default::default()
        };

        let cert_paths = TlsCertificatePaths {
            cert: TEST_CERT_FILE,
            private_key: TEST_KEY_FILE,
            kind: crate::settings::CertificateKind::X509,
        };

        let server_params = ConnectionParams::new_server(
            server_quic_settings,
            cert_paths,
            Hooks::default(),
        );
        let server_config =
            Config::new(&server_params, SocketCapabilities::default()).unwrap();

        // Bind to an ephemeral loopback port and share the socket between
        // the sending and receiving halves.
        let udp = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server_addr = udp.local_addr().unwrap();
        let server_host_port = server_addr.to_string();
        let send_half = Arc::new(udp);
        let recv_half = Arc::clone(&send_half);

        let acceptor_config = ConnectionAcceptorConfig {
            disable_client_ip_validation: server_config
                .disable_client_ip_validation,
            qlog_dir: server_config.qlog_dir.clone(),
            keylog_file: server_config
                .keylog_file
                .as_ref()
                .and_then(|f| f.try_clone().ok()),
            #[cfg(target_os = "linux")]
            with_pktinfo: false,
        };
        let acceptor = ConnectionAcceptor::new(
            acceptor_config,
            Arc::clone(&send_half),
            0,
            Default::default(),
            Box::new(SimpleConnectionIdGenerator),
            DefaultMetrics,
        );

        let (router, mut accepted) = InboundPacketRouter::new(
            server_config,
            send_half,
            recv_half,
            server_addr,
            acceptor,
            DefaultMetrics,
        );
        tokio::spawn(router);

        // Start a request on a separate OS thread; the client closes the
        // connection right after it has been established.
        std::thread::spawn(move || test_connect(server_host_port));

        // Wait for a new connection
        time::pause();

        let (h3_driver, _) = ServerH3Driver::new(Http3Settings::default());
        let new_conn = accepted.recv().await.unwrap().unwrap();
        let drop_check = new_conn.incoming_ev_sender.clone();
        let _running = new_conn.start(h3_driver);

        // Jump the paused tokio clock forward, then go back to real time
        // while the connection is being dropped.
        time::advance(Duration::new(30, 0)).await;
        time::resume();

        // NOTE: this is a smoke test - if the connection is never torn down,
        // the `closed()` future never resolves and the test hangs.
        drop_check.closed().await;
    }
}