1use std::io::IoSlice;
28use std::io::{
29 self,
30};
31use std::os::fd::AsRawFd;
32use std::os::fd::BorrowedFd;
33
34use smallvec::SmallVec;
35use tokio::io::ReadBuf;
36
/// Maximum number of messages handed to the kernel in a single
/// `recvmmsg(2)` / `sendmmsg(2)` call; larger inputs are processed in chunks
/// of this size. It is also the inline capacity of the `SmallVec`s holding the
/// per-call headers, so one batch never spills those vectors to the heap.
const MAX_MMSG: usize = 16;
38
39pub fn recvmmsg(fd: BorrowedFd, bufs: &mut [ReadBuf<'_>]) -> io::Result<usize> {
40 let mut msgvec: SmallVec<[libc::mmsghdr; MAX_MMSG]> = SmallVec::new();
41 let mut slices: SmallVec<[IoSlice; MAX_MMSG]> = SmallVec::new();
42
43 let mut ret = 0;
44
45 for bufs in bufs.chunks_mut(MAX_MMSG) {
46 msgvec.clear();
47 slices.clear();
48
49 for buf in bufs.iter_mut() {
50 let b = unsafe {
52 &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
53 as *mut [u8])
54 };
55
56 slices.push(IoSlice::new(b));
57
58 msgvec.push(libc::mmsghdr {
59 msg_hdr: libc::msghdr {
60 msg_name: std::ptr::null_mut(),
61 msg_namelen: 0,
62 msg_iov: slices.last_mut().unwrap() as *mut _ as *mut _,
63 msg_iovlen: 1,
64 msg_control: std::ptr::null_mut(),
65 msg_controllen: 0,
66 msg_flags: 0,
67 },
68 msg_len: buf.capacity().try_into().unwrap(),
69 });
70 }
71
72 let result = unsafe {
73 libc::recvmmsg(
74 fd.as_raw_fd(),
75 msgvec.as_mut_ptr(),
76 msgvec.len() as _,
77 0,
78 std::ptr::null_mut(),
79 )
80 };
81
82 if result == -1 {
83 break;
84 }
85
86 for i in 0..result as usize {
87 let filled = msgvec[i].msg_len as usize;
88 unsafe { bufs[i].assume_init(filled) };
89 bufs[i].advance(filled);
90 ret += 1;
91 }
92
93 if (result as usize) < MAX_MMSG {
94 break;
95 }
96 }
97
98 if ret == 0 {
99 return Err(io::Error::last_os_error());
100 }
101
102 Ok(ret)
103}
104
105pub fn sendmmsg(fd: BorrowedFd, bufs: &[ReadBuf<'_>]) -> io::Result<usize> {
106 let mut msgvec: SmallVec<[libc::mmsghdr; MAX_MMSG]> = SmallVec::new();
107 let mut slices: SmallVec<[IoSlice; MAX_MMSG]> = SmallVec::new();
108
109 let mut ret = 0;
110
111 for bufs in bufs.chunks(MAX_MMSG) {
112 msgvec.clear();
113 slices.clear();
114
115 for buf in bufs.iter() {
116 slices.push(IoSlice::new(buf.filled()));
117
118 msgvec.push(libc::mmsghdr {
119 msg_hdr: libc::msghdr {
120 msg_name: std::ptr::null_mut(),
121 msg_namelen: 0,
122 msg_iov: slices.last_mut().unwrap() as *mut _ as *mut _,
123 msg_iovlen: 1,
124 msg_control: std::ptr::null_mut(),
125 msg_controllen: 0,
126 msg_flags: 0,
127 },
128 msg_len: buf.capacity().try_into().unwrap(),
129 });
130 }
131
132 let result = unsafe {
133 libc::sendmmsg(
134 fd.as_raw_fd(),
135 msgvec.as_mut_ptr(),
136 msgvec.len() as _,
137 0,
138 )
139 };
140
141 if result == -1 {
142 break;
143 }
144
145 ret += result as usize;
146
147 if (result as usize) < MAX_MMSG {
148 break;
149 }
150 }
151
152 if ret == 0 {
153 return Err(io::Error::last_os_error());
154 }
155
156 Ok(ret)
157}
158
/// Polls `$self` for read readiness and, once it is ready, receives a batch of
/// datagrams into `$bufs` with [`recvmmsg`](crate::mmsg::recvmmsg).
///
/// The macro expands to an expression of type `Poll<io::Result<usize>>`.
/// `$self` must provide `poll_recv_ready`, `try_io` and `AsFd` (as tokio's
/// datagram sockets do).
///
/// Readiness errors are propagated with `?`. The enclosing function must
/// therefore return a `Poll<Result<_, _>>` whose error type accepts
/// `io::Error`.
///
/// A `WouldBlock` result makes the macro poll readiness again. `try_io`
/// clears the readiness flag in that case, so the next poll normally returns
/// `Pending`.
///
/// Every path is fully qualified, so call sites do not need to import
/// `Poll`, `io`, `AsFd` or `tokio` themselves.
#[macro_export]
macro_rules! poll_recvmmsg {
    ($self: expr, $cx: ident, $bufs: ident) => {
        loop {
            match $self.poll_recv_ready($cx)? {
                ::std::task::Poll::Ready(()) => {
                    match $self.try_io(::tokio::io::Interest::READABLE, || {
                        // Brings `as_fd` into scope without naming the trait
                        // at the call site.
                        use ::std::os::fd::AsFd as _;
                        $crate::mmsg::recvmmsg($self.as_fd(), $bufs)
                    }) {
                        Err(err) if err.kind() == ::std::io::ErrorKind::WouldBlock => {}
                        res => break ::std::task::Poll::Ready(res),
                    }
                }
                ::std::task::Poll::Pending => break ::std::task::Poll::Pending,
            }
        }
    };
}
177
/// Polls `$self` for write readiness and, once it is ready, sends the
/// datagrams in `$bufs` with [`sendmmsg`](crate::mmsg::sendmmsg).
///
/// The macro expands to an expression of type `Poll<io::Result<usize>>`.
/// `$self` must provide `poll_send_ready`, `try_io` and `AsFd` (as tokio's
/// datagram sockets do).
///
/// Readiness errors are propagated with `?`. The enclosing function must
/// therefore return a `Poll<Result<_, _>>` whose error type accepts
/// `io::Error`.
///
/// A `WouldBlock` result makes the macro poll readiness again. `try_io`
/// clears the readiness flag in that case, so the next poll normally returns
/// `Pending`.
///
/// Every path is fully qualified, so call sites do not need to import
/// `Poll`, `io`, `AsFd` or `tokio` themselves.
#[macro_export]
macro_rules! poll_sendmmsg {
    ($self: expr, $cx: ident, $bufs: ident) => {
        loop {
            match $self.poll_send_ready($cx)? {
                ::std::task::Poll::Ready(()) => {
                    match $self.try_io(::tokio::io::Interest::WRITABLE, || {
                        // Brings `as_fd` into scope without naming the trait
                        // at the call site.
                        use ::std::os::fd::AsFd as _;
                        $crate::mmsg::sendmmsg($self.as_fd(), $bufs)
                    }) {
                        Err(err) if err.kind() == ::std::io::ErrorKind::WouldBlock => {}
                        res => break ::std::task::Poll::Ready(res),
                    }
                }
                ::std::task::Poll::Pending => break ::std::task::Poll::Pending,
            }
        }
    };
}
196
#[cfg(test)]
mod tests {
    use std::io;

    use tokio::io::ReadBuf;
    use tokio::net::UnixDatagram;

    use crate::DatagramSocketRecvExt;
    use crate::DatagramSocketSendExt;

    /// Length of every datagram exchanged by these tests.
    const DGRAM_LEN: usize = 128;

    /// Batched receive: first a round smaller than one batch, then a round
    /// spanning several batches with a partial final batch. Each datagram is
    /// tagged with its index so ordering can be checked.
    #[tokio::test]
    async fn recvmmsg() -> io::Result<()> {
        let (tx, mut rx) = UnixDatagram::pair()?;
        let mut storage = [[0u8; DGRAM_LEN]; 128];

        for count in [5u8, 92] {
            for tag in 0..count {
                tx.send(&[tag; DGRAM_LEN]).await?;
            }

            let mut bufs: Vec<ReadBuf<'_>> = storage
                .iter_mut()
                .map(|slot| ReadBuf::new(&mut slot[..]))
                .collect();
            assert_eq!(rx.recv_many(&mut bufs).await?, usize::from(count));

            for (tag, buf) in bufs.iter().take(usize::from(count)).enumerate() {
                assert_eq!(buf.filled(), &[tag as u8; DGRAM_LEN]);
            }
        }

        Ok(())
    }

    /// Batched send: the first five pre-filled buffers go out in one call and
    /// arrive in order.
    #[tokio::test]
    async fn sendmmsg() -> io::Result<()> {
        let (tx, rx) = UnixDatagram::pair()?;
        let mut storage: [[u8; DGRAM_LEN]; 128] =
            std::array::from_fn(|i| [i as u8; DGRAM_LEN]);

        let bufs: Vec<ReadBuf<'_>> = storage
            .iter_mut()
            .map(|slot| {
                let mut buf = ReadBuf::new(&mut slot[..]);
                buf.set_filled(DGRAM_LEN);
                buf
            })
            .collect();

        assert_eq!(tx.send_many(&bufs[..5]).await?, 5);

        let mut incoming = [0u8; DGRAM_LEN];
        for tag in 0..5u8 {
            assert_eq!(rx.recv(&mut incoming).await?, DGRAM_LEN);
            assert_eq!(incoming, [tag; DGRAM_LEN]);
        }

        Ok(())
    }
}