tokio_quiche/quic/router/
connector.rs1use std::io;
28use std::mem;
29use std::sync::Arc;
30use std::task::Context;
31use std::task::Poll;
32use std::time::Instant;
33
34use datagram_socket::DatagramSocketSend;
35use datagram_socket::DatagramSocketSendExt;
36use datagram_socket::MaybeConnectedSocket;
37use datagram_socket::MAX_DATAGRAM_SIZE;
38use foundations::telemetry::log;
39use quiche::ConnectionId;
40use quiche::Header;
41use tokio_util::time::delay_queue::Key;
42use tokio_util::time::DelayQueue;
43
44use crate::quic::router::InitialPacketHandler;
45use crate::quic::router::NewConnection;
46use crate::quic::Incoming;
47use crate::quic::QuicheConnection;
48
/// Drives the handshake of a single outgoing (client-side) QUIC connection.
///
/// The connection's first flight is sent on the first `update`. The
/// connection is then kept pending: it is flushed again when its timer fires
/// and fed incoming packets addressed to it, until it is either established
/// (and handed back as a `NewConnection`) or fails.
pub(crate) struct ClientConnector<Tx> {
    /// Socket the connection's outgoing packets are written to.
    socket_tx: MaybeConnectedSocket<Arc<Tx>>,
    /// Lifecycle state of the one connection this connector manages.
    connection: ConnectionState,
    /// Timers for the pending connection, keyed by its source connection ID;
    /// expired entries are drained in `update`.
    timeout_queue: DelayQueue<ConnectionId<'static>>,
}
57
58enum ConnectionState {
60 Queued(QuicheConnection),
62 Pending(PendingConnection),
64 Returned,
67}
68
69impl ConnectionState {
70 fn take_if_queued(&mut self) -> Option<QuicheConnection> {
71 match mem::replace(self, Self::Returned) {
72 Self::Queued(conn) => Some(conn),
73 state => {
74 *self = state;
75 None
76 },
77 }
78 }
79
80 fn take_if_pending_and_id_matches(
81 &mut self, scid: &ConnectionId<'static>,
82 ) -> Option<PendingConnection> {
83 match mem::replace(self, Self::Returned) {
84 Self::Pending(pending) if *scid == pending.conn.source_id() =>
85 Some(pending),
86 state => {
87 *self = state;
88 None
89 },
90 }
91 }
92}
93
94struct PendingConnection {
97 conn: QuicheConnection,
98 timeout_key: Option<Key>,
99 handshake_start_time: Instant,
100}
101
102impl<Tx> ClientConnector<Tx>
103where
104 Tx: DatagramSocketSend + Send + 'static,
105{
106 pub(crate) fn new(socket_tx: Arc<Tx>, connection: QuicheConnection) -> Self {
107 Self {
108 socket_tx: MaybeConnectedSocket::new(socket_tx),
109 connection: ConnectionState::Queued(connection),
110 timeout_queue: Default::default(),
111 }
112 }
113
114 fn set_connection_to_pending(
118 &mut self, mut conn: QuicheConnection,
119 ) -> io::Result<()> {
120 simple_conn_send(&self.socket_tx, &mut conn)?;
121
122 let timeout_key = conn.timeout_instant().map(|instant| {
123 self.timeout_queue
124 .insert_at(conn.source_id().into_owned(), instant.into())
125 });
126
127 self.connection = ConnectionState::Pending(PendingConnection {
128 conn,
129 timeout_key,
130 handshake_start_time: Instant::now(),
131 });
132
133 Ok(())
134 }
135
136 fn on_incoming(
141 &mut self, mut incoming: Incoming, hdr: Header<'static>,
142 ) -> io::Result<Option<NewConnection>> {
143 let Some(PendingConnection {
144 mut conn,
145 timeout_key,
146 handshake_start_time,
147 }) = self.connection.take_if_pending_and_id_matches(&hdr.dcid)
148 else {
149 log::debug!("Received Initial packet for unknown connection ID"; "scid" => ?hdr.dcid);
150 return Ok(None);
151 };
152
153 let recv_info = quiche::RecvInfo {
154 from: incoming.peer_addr,
155 to: incoming.local_addr,
156 };
157
158 if let Some(gro) = incoming.gro {
159 for dgram in incoming.buf.chunks_mut(gro as usize) {
160 let _ = conn.recv(dgram, recv_info);
162 }
163 } else {
164 let _ = conn.recv(&mut incoming.buf, recv_info);
166 }
167
168 if let Some(key) = timeout_key {
171 self.timeout_queue.remove(&key);
172 }
173
174 let scid = conn.source_id();
175 if conn.is_established() {
176 log::debug!("QUIC connection established"; "scid" => ?scid);
177
178 Ok(Some(NewConnection {
179 conn,
180 pending_cid: None,
181 initial_pkt: None,
182 cid_generator: None,
183 handshake_start_time,
184 }))
185 } else if conn.is_closed() {
186 let scid = conn.source_id();
187 log::error!("QUIC connection closed on_incoming"; "scid" => ?scid);
188
189 Err(io::Error::new(
190 io::ErrorKind::TimedOut,
191 format!("connection {scid:?} timed out"),
192 ))
193 } else {
194 self.set_connection_to_pending(conn).map(|()| None)
195 }
196 }
197
198 fn on_timeout(&mut self, scid: ConnectionId<'static>) -> io::Result<()> {
202 log::debug!("connection timedout"; "scid" => ?scid);
203
204 let Some(mut pending) =
205 self.connection.take_if_pending_and_id_matches(&scid)
206 else {
207 log::debug!("timedout connection missing from pending map"; "scid" => ?scid);
208 return Ok(());
209 };
210
211 pending.conn.on_timeout();
212
213 if pending.conn.is_closed() {
214 log::error!("pending connection closed on_timeout"; "scid" => ?scid);
215
216 return Err(io::Error::new(
217 io::ErrorKind::TimedOut,
218 format!("connection {scid:?} timed out"),
219 ));
220 }
221
222 self.set_connection_to_pending(pending.conn)
223 }
224
225 fn update(&mut self, cx: &mut Context) -> io::Result<()> {
228 while let Poll::Ready(Some(expired)) = self.timeout_queue.poll_expired(cx)
229 {
230 let scid = expired.into_inner();
231 self.on_timeout(scid)?;
232 }
233
234 if let Some(conn) = self.connection.take_if_queued() {
235 self.set_connection_to_pending(conn)?;
236 }
237
238 Ok(())
239 }
240}
241
242impl<Tx> InitialPacketHandler for ClientConnector<Tx>
243where
244 Tx: DatagramSocketSend + Send + 'static,
245{
246 fn update(&mut self, ctx: &mut Context<'_>) -> io::Result<()> {
247 ClientConnector::update(self, ctx)
248 }
249
250 fn handle_initials(
251 &mut self, incoming: Incoming, hdr: Header<'static>,
252 _: &mut quiche::Config,
253 ) -> io::Result<Option<NewConnection>> {
254 self.on_incoming(incoming, hdr)
255 }
256}
257
258fn simple_conn_send<Tx: DatagramSocketSend + Send + Sync + 'static>(
264 socket_tx: &MaybeConnectedSocket<Arc<Tx>>, conn: &mut QuicheConnection,
265) -> io::Result<()> {
266 let scid = conn.source_id().into_owned();
267 log::debug!("sending client Initials to peer"; "scid" => ?scid);
268
269 loop {
270 let scid = scid.clone();
271 let mut buf = [0; MAX_DATAGRAM_SIZE];
272 let send_res = conn.send(&mut buf);
273
274 let socket_clone = socket_tx.clone();
275 match send_res {
276 Ok((n, send_info)) => {
277 tokio::spawn({
278 let buf = buf[0..n].to_vec();
279 async move {
280 socket_clone.send_to(&buf, send_info.to).await.inspect_err(|error| {
281 log::error!("error sending client Initial packets to peer"; "scid" => ?scid, "peer_addr" => send_info.to, "error" => error.to_string());
282 })
283 }
284 });
285 },
286 Err(quiche::Error::Done) => break Ok(()),
287 Err(error) => {
288 log::error!("error writing packets to quiche's internal buffer"; "scid" => ?scid, "error" => error.to_string());
289 break Err(std::io::Error::other(error));
290 },
291 }
292 }
293}