// tokio_quiche/quic/addr_validation_token.rs

use quiche::ConnectionId;
use std::io::Write;
use std::io::{
    self,
};
use std::net::IpAddr;
use std::net::SocketAddr;

use crate::QuicResultExt;

/// Size in bytes of the random secret key used to sign tokens.
const HMAC_KEY_LEN: usize = 32;
/// Size in bytes of an HMAC-SHA256 tag: the 256-bit SHA-256 output.
const HMAC_TAG_LEN: usize = 256 / 8;

40pub(crate) struct AddrValidationTokenManager {
41 sign_key: [u8; HMAC_KEY_LEN],
42}
43
44impl Default for AddrValidationTokenManager {
45 fn default() -> Self {
46 let mut key_bytes = [0; HMAC_KEY_LEN];
47 boring::rand::rand_bytes(&mut key_bytes).unwrap();
48
49 AddrValidationTokenManager {
50 sign_key: key_bytes,
51 }
52 }
53}
54
55impl AddrValidationTokenManager {
56 pub(super) fn gen(
57 &self, original_dcid: &[u8], client_addr: SocketAddr,
58 ) -> Vec<u8> {
59 let ip_bytes = match client_addr.ip() {
60 IpAddr::V4(addr) => addr.octets().to_vec(),
61 IpAddr::V6(addr) => addr.octets().to_vec(),
62 };
63
64 let token_len = HMAC_TAG_LEN + ip_bytes.len() + original_dcid.len();
65 let mut token = io::Cursor::new(vec![0u8; token_len]);
66
67 token.set_position(HMAC_TAG_LEN as u64);
68 token.write_all(&ip_bytes).unwrap();
69 token.write_all(original_dcid).unwrap();
70
71 let tag = boring::hash::hmac_sha256(
72 &self.sign_key,
73 &token.get_ref()[HMAC_TAG_LEN..],
74 )
75 .unwrap();
76
77 token.set_position(0);
78 token.write_all(tag.as_ref()).unwrap();
79
80 token.into_inner()
81 }
82
83 pub(super) fn validate_and_extract_original_dcid<'t>(
84 &self, token: &'t [u8], client_addr: SocketAddr,
85 ) -> io::Result<ConnectionId<'t>> {
86 let ip_bytes = match client_addr.ip() {
87 IpAddr::V4(addr) => addr.octets().to_vec(),
88 IpAddr::V6(addr) => addr.octets().to_vec(),
89 };
90
91 let hmac_and_ip_len = HMAC_TAG_LEN + ip_bytes.len();
92
93 if token.len() < hmac_and_ip_len {
94 return Err("token is too short").into_io();
95 }
96
97 let (tag, payload) = token.split_at(HMAC_TAG_LEN);
98
99 let expected_tag =
100 boring::hash::hmac_sha256(&self.sign_key, payload).unwrap();
101
102 if !boring::memcmp::eq(&expected_tag, tag) {
103 return Err("signature verification failed").into_io();
104 }
105
106 if payload[..ip_bytes.len()] != *ip_bytes {
107 return Err("IPs don't match").into_io();
108 }
109
110 Ok(ConnectionId::from_ref(&token[hmac_and_ip_len..]))
111 }
112}
113
114#[cfg(test)]
115mod tests {
116 use super::*;
117
118 #[test]
119 fn generate() {
120 let manager = AddrValidationTokenManager::default();
121
122 let assert_tag_generated = |token: &[u8]| {
123 let tag = &token[..HMAC_TAG_LEN];
124 let all_nulls = tag.iter().all(|b| *b == 0u8);
125
126 assert!(!all_nulls);
127 };
128
129 let token = manager.gen(b"foo", "127.0.0.1:1337".parse().unwrap());
130
131 assert_tag_generated(&token);
132 assert_eq!(token[HMAC_TAG_LEN..HMAC_TAG_LEN + 4], [127, 0, 0, 1]);
133 assert_eq!(&token[HMAC_TAG_LEN + 4..], b"foo");
134
135 let token = manager.gen(b"bar", "[::1]:1338".parse().unwrap());
136
137 assert_tag_generated(&token);
138
139 assert_eq!(token[HMAC_TAG_LEN..HMAC_TAG_LEN + 16], [
140 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1
141 ]);
142
143 assert_eq!(&token[HMAC_TAG_LEN + 16..], b"bar");
144 }
145
146 #[test]
147 fn validate() {
148 let manager = AddrValidationTokenManager::default();
149
150 let addr = "127.0.0.1:1337".parse().unwrap();
151 let token = manager.gen(b"foo", addr);
152
153 assert_eq!(
154 manager
155 .validate_and_extract_original_dcid(&token, addr)
156 .unwrap(),
157 ConnectionId::from_ref(b"foo")
158 );
159
160 let addr = "[::1]:1338".parse().unwrap();
161 let token = manager.gen(b"barbaz", addr);
162
163 assert_eq!(
164 manager
165 .validate_and_extract_original_dcid(&token, addr)
166 .unwrap(),
167 ConnectionId::from_ref(b"barbaz")
168 );
169 }
170
171 #[test]
172 fn validate_err_short_token() {
173 let manager = AddrValidationTokenManager::default();
174 let v4_addr = "127.0.0.1:1337".parse().unwrap();
175 let v6_addr = "[::1]:1338".parse().unwrap();
176
177 for addr in &[v4_addr, v6_addr] {
178 assert!(manager
179 .validate_and_extract_original_dcid(b"", *addr)
180 .is_err());
181
182 assert!(manager
183 .validate_and_extract_original_dcid(&[1u8; HMAC_TAG_LEN], *addr)
184 .is_err());
185
186 assert!(manager
187 .validate_and_extract_original_dcid(
188 &[1u8; HMAC_TAG_LEN + 1],
189 *addr
190 )
191 .is_err());
192 }
193 }
194
195 #[test]
196 fn validate_err_ips_mismatch() {
197 let manager = AddrValidationTokenManager::default();
198
199 let token = manager.gen(b"foo", "127.0.0.1:1337".parse().unwrap());
200
201 assert!(manager
202 .validate_and_extract_original_dcid(
203 &token,
204 "127.0.0.2:1337".parse().unwrap()
205 )
206 .is_err());
207
208 let token = manager.gen(b"barbaz", "[::1]:1338".parse().unwrap());
209
210 assert!(manager
211 .validate_and_extract_original_dcid(
212 &token,
213 "[::2]:1338".parse().unwrap()
214 )
215 .is_err());
216 }
217
218 #[test]
219 fn validate_err_invalid_signature() {
220 let manager = AddrValidationTokenManager::default();
221
222 let addr = "127.0.0.1:1337".parse().unwrap();
223 let mut token = manager.gen(b"foo", addr);
224
225 token[..HMAC_TAG_LEN].copy_from_slice(&[1u8; HMAC_TAG_LEN]);
226
227 assert!(manager
228 .validate_and_extract_original_dcid(&token, addr)
229 .is_err());
230 }
231}