tokio_quiche/quic/router/
connector.rs1use std::io;
28use std::mem;
29use std::sync::Arc;
30use std::task::Context;
31use std::task::Poll;
32use std::time::Instant;
33
34use datagram_socket::DatagramSocketSend;
35use datagram_socket::DatagramSocketSendExt;
36use datagram_socket::MaybeConnectedSocket;
37use datagram_socket::MAX_DATAGRAM_SIZE;
38use foundations::telemetry::log;
39use quiche::ConnectionId;
40use quiche::Header;
41use tokio_util::time::delay_queue::Key;
42use tokio_util::time::DelayQueue;
43
44use crate::quic::router::InitialPacketHandler;
45use crate::quic::router::NewConnection;
46use crate::quic::Incoming;
47use crate::quic::QuicheConnection;
48
49pub(crate) struct ClientConnector<Tx> {
53 socket_tx: MaybeConnectedSocket<Arc<Tx>>,
54 connection: ConnectionState,
55 timeout_queue: DelayQueue<ConnectionId<'static>>,
56}
57
58enum ConnectionState {
60 Queued(QuicheConnection),
62 Pending(PendingConnection),
64 Returned,
67}
68
69impl ConnectionState {
70 fn take_if_queued(&mut self) -> Option<QuicheConnection> {
71 match mem::replace(self, Self::Returned) {
72 Self::Queued(conn) => Some(conn),
73 state => {
74 *self = state;
75 None
76 },
77 }
78 }
79
80 fn take_if_pending_and_id_matches(
81 &mut self, scid: &ConnectionId<'static>,
82 ) -> Option<PendingConnection> {
83 match mem::replace(self, Self::Returned) {
84 Self::Pending(pending) if *scid == pending.conn.source_id() =>
85 Some(pending),
86 state => {
87 *self = state;
88 None
89 },
90 }
91 }
92}
93
94struct PendingConnection {
97 conn: QuicheConnection,
98 timeout_key: Option<Key>,
99 handshake_start_time: Instant,
100}
101
102impl<Tx> ClientConnector<Tx>
103where
104 Tx: DatagramSocketSend + Send + 'static,
105{
106 pub(crate) fn new(socket_tx: Arc<Tx>, connection: QuicheConnection) -> Self {
107 Self {
108 socket_tx: MaybeConnectedSocket::new(socket_tx),
109 connection: ConnectionState::Queued(connection),
110 timeout_queue: Default::default(),
111 }
112 }
113
114 fn set_connection_to_pending(
118 &mut self, mut conn: QuicheConnection,
119 ) -> io::Result<()> {
120 simple_conn_send(&self.socket_tx, &mut conn)?;
121
122 let timeout_key = conn.timeout_instant().map(|instant| {
123 self.timeout_queue
124 .insert_at(conn.source_id().into_owned(), instant.into())
125 });
126
127 self.connection = ConnectionState::Pending(PendingConnection {
128 conn,
129 timeout_key,
130 handshake_start_time: Instant::now(),
131 });
132
133 Ok(())
134 }
135
136 fn on_incoming(
141 &mut self, mut incoming: Incoming, hdr: Header<'static>,
142 ) -> io::Result<Option<NewConnection>> {
143 let Some(PendingConnection {
144 mut conn,
145 timeout_key,
146 handshake_start_time,
147 }) = self.connection.take_if_pending_and_id_matches(&hdr.dcid)
148 else {
149 log::debug!("Received Initial packet for unknown connection ID"; "scid" => ?hdr.dcid);
150 return Ok(None);
151 };
152
153 let recv_info = quiche::RecvInfo {
154 from: incoming.peer_addr,
155 to: incoming.local_addr,
156 };
157
158 if let Some(gro) = incoming.gro {
159 for dgram in incoming.buf.chunks_mut(gro as usize) {
160 let _ = conn.recv(dgram, recv_info);
162 }
163 } else {
164 let _ = conn.recv(&mut incoming.buf, recv_info);
166 }
167
168 if let Some(key) = timeout_key {
171 self.timeout_queue.remove(&key);
172 }
173
174 let scid = conn.source_id();
175 if conn.is_established() {
176 log::debug!("QUIC connection established"; "scid" => ?scid);
177
178 Ok(Some(NewConnection {
179 conn,
180 pending_cid: None,
181 initial_pkt: None,
182 handshake_start_time,
183 }))
184 } else if conn.is_closed() {
185 let scid = conn.source_id();
186 log::error!("QUIC connection closed on_incoming"; "scid" => ?scid);
187
188 Err(io::Error::new(
189 io::ErrorKind::TimedOut,
190 format!("connection {scid:?} timed out"),
191 ))
192 } else {
193 self.set_connection_to_pending(conn).map(|()| None)
194 }
195 }
196
197 fn on_timeout(&mut self, scid: ConnectionId<'static>) -> io::Result<()> {
201 log::debug!("connection timedout"; "scid" => ?scid);
202
203 let Some(mut pending) =
204 self.connection.take_if_pending_and_id_matches(&scid)
205 else {
206 log::debug!("timedout connection missing from pending map"; "scid" => ?scid);
207 return Ok(());
208 };
209
210 pending.conn.on_timeout();
211
212 if pending.conn.is_closed() {
213 log::error!("pending connection closed on_timeout"; "scid" => ?scid);
214
215 return Err(io::Error::new(
216 io::ErrorKind::TimedOut,
217 format!("connection {scid:?} timed out"),
218 ));
219 }
220
221 self.set_connection_to_pending(pending.conn)
222 }
223
224 fn update(&mut self, cx: &mut Context) -> io::Result<()> {
227 while let Poll::Ready(Some(expired)) = self.timeout_queue.poll_expired(cx)
228 {
229 let scid = expired.into_inner();
230 self.on_timeout(scid)?;
231 }
232
233 if let Some(conn) = self.connection.take_if_queued() {
234 self.set_connection_to_pending(conn)?;
235 }
236
237 Ok(())
238 }
239}
240
241impl<Tx> InitialPacketHandler for ClientConnector<Tx>
242where
243 Tx: DatagramSocketSend + Send + 'static,
244{
245 fn update(&mut self, ctx: &mut Context<'_>) -> io::Result<()> {
246 ClientConnector::update(self, ctx)
247 }
248
249 fn handle_initials(
250 &mut self, incoming: Incoming, hdr: Header<'static>,
251 _: &mut quiche::Config,
252 ) -> io::Result<Option<NewConnection>> {
253 self.on_incoming(incoming, hdr)
254 }
255}
256
257fn simple_conn_send<Tx: DatagramSocketSend + Send + Sync + 'static>(
263 socket_tx: &MaybeConnectedSocket<Arc<Tx>>, conn: &mut QuicheConnection,
264) -> io::Result<()> {
265 let scid = conn.source_id().into_owned();
266 log::debug!("sending client Initials to peer"; "scid" => ?scid);
267
268 loop {
269 let scid = scid.clone();
270 let mut buf = [0; MAX_DATAGRAM_SIZE];
271 let send_res = conn.send(&mut buf);
272
273 let socket_clone = socket_tx.clone();
274 match send_res {
275 Ok((n, send_info)) => {
276 tokio::spawn({
277 let buf = buf[0..n].to_vec();
278 async move {
279 socket_clone.send_to(&buf, send_info.to).await.inspect_err(|error| {
280 log::error!("error sending client Initial packets to peer"; "scid" => ?scid, "peer_addr" => send_info.to, "error" => error.to_string());
281 })
282 }
283 });
284 },
285 Err(quiche::Error::Done) => break Ok(()),
286 Err(error) => {
287 log::error!("error writing packets to quiche's internal buffer"; "scid" => ?scid, "error" => error.to_string());
288 break Err(std::io::Error::other(error));
289 },
290 }
291 }
292}