datagram_socket/
mmsg.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27use std::io::IoSlice;
28use std::io::{
29    self,
30};
31use std::os::fd::AsRawFd;
32use std::os::fd::BorrowedFd;
33
34use smallvec::SmallVec;
35use tokio::io::ReadBuf;
36
37const MAX_MMSG: usize = 16;
38
39pub fn recvmmsg(fd: BorrowedFd, bufs: &mut [ReadBuf<'_>]) -> io::Result<usize> {
40    let mut msgvec: SmallVec<[libc::mmsghdr; MAX_MMSG]> = SmallVec::new();
41    let mut slices: SmallVec<[IoSlice; MAX_MMSG]> = SmallVec::new();
42
43    let mut ret = 0;
44
45    for bufs in bufs.chunks_mut(MAX_MMSG) {
46        msgvec.clear();
47        slices.clear();
48
49        for buf in bufs.iter_mut() {
50            // Safety: will not read the maybe uninitialized bytes.
51            let b = unsafe {
52                &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>]
53                    as *mut [u8])
54            };
55
56            slices.push(IoSlice::new(b));
57
58            msgvec.push(libc::mmsghdr {
59                msg_hdr: libc::msghdr {
60                    msg_name: std::ptr::null_mut(),
61                    msg_namelen: 0,
62                    msg_iov: slices.last_mut().unwrap() as *mut _ as *mut _,
63                    msg_iovlen: 1,
64                    msg_control: std::ptr::null_mut(),
65                    msg_controllen: 0,
66                    msg_flags: 0,
67                },
68                msg_len: buf.capacity().try_into().unwrap(),
69            });
70        }
71
72        let result = unsafe {
73            libc::recvmmsg(
74                fd.as_raw_fd(),
75                msgvec.as_mut_ptr(),
76                msgvec.len() as _,
77                0,
78                std::ptr::null_mut(),
79            )
80        };
81
82        if result == -1 {
83            break;
84        }
85
86        for i in 0..result as usize {
87            let filled = msgvec[i].msg_len as usize;
88            unsafe { bufs[i].assume_init(filled) };
89            bufs[i].advance(filled);
90            ret += 1;
91        }
92
93        if (result as usize) < MAX_MMSG {
94            break;
95        }
96    }
97
98    if ret == 0 {
99        return Err(io::Error::last_os_error());
100    }
101
102    Ok(ret)
103}
104
105pub fn sendmmsg(fd: BorrowedFd, bufs: &[ReadBuf<'_>]) -> io::Result<usize> {
106    let mut msgvec: SmallVec<[libc::mmsghdr; MAX_MMSG]> = SmallVec::new();
107    let mut slices: SmallVec<[IoSlice; MAX_MMSG]> = SmallVec::new();
108
109    let mut ret = 0;
110
111    for bufs in bufs.chunks(MAX_MMSG) {
112        msgvec.clear();
113        slices.clear();
114
115        for buf in bufs.iter() {
116            slices.push(IoSlice::new(buf.filled()));
117
118            msgvec.push(libc::mmsghdr {
119                msg_hdr: libc::msghdr {
120                    msg_name: std::ptr::null_mut(),
121                    msg_namelen: 0,
122                    msg_iov: slices.last_mut().unwrap() as *mut _ as *mut _,
123                    msg_iovlen: 1,
124                    msg_control: std::ptr::null_mut(),
125                    msg_controllen: 0,
126                    msg_flags: 0,
127                },
128                msg_len: buf.capacity().try_into().unwrap(),
129            });
130        }
131
132        let result = unsafe {
133            libc::sendmmsg(
134                fd.as_raw_fd(),
135                msgvec.as_mut_ptr(),
136                msgvec.len() as _,
137                0,
138            )
139        };
140
141        if result == -1 {
142            break;
143        }
144
145        ret += result as usize;
146
147        if (result as usize) < MAX_MMSG {
148            break;
149        }
150    }
151
152    if ret == 0 {
153        return Err(io::Error::last_os_error());
154    }
155
156    Ok(ret)
157}
158
/// Expands to a loop expression that drives a batched receive to completion
/// or to `Poll::Pending`.
///
/// Each iteration polls `$self.poll_recv_ready($cx)` (errors are propagated
/// with `?`). Once the socket is ready, `recvmmsg` is run through
/// `$self.try_io` with `READABLE` interest. A `WouldBlock` result restarts the
/// loop; any other result ends it as `Poll::Ready(result)`.
///
/// The names `Poll`, `io` and `tokio` are resolved at the call site and must
/// be in scope there.
#[macro_export]
macro_rules! poll_recvmmsg {
    ($self: expr, $cx: ident, $bufs: ident) => {
        loop {
            if let Poll::Pending = $self.poll_recv_ready($cx)? {
                break Poll::Pending;
            }

            let attempt = $self.try_io(tokio::io::Interest::READABLE, || {
                $crate::mmsg::recvmmsg($self.as_fd(), $bufs)
            });

            match attempt {
                // The socket was not actually readable: poll for readiness
                // again.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
                res => break Poll::Ready(res),
            }
        }
    };
}
177
/// Expands to a loop expression that drives a batched send to completion or
/// to `Poll::Pending`.
///
/// Each iteration polls `$self.poll_send_ready($cx)` (errors are propagated
/// with `?`). Once the socket is ready, `sendmmsg` is run through
/// `$self.try_io` with `WRITABLE` interest. A `WouldBlock` result restarts the
/// loop; any other result ends it as `Poll::Ready(result)`.
///
/// The names `Poll`, `io` and `tokio` are resolved at the call site and must
/// be in scope there.
#[macro_export]
macro_rules! poll_sendmmsg {
    ($self: expr, $cx: ident, $bufs: ident) => {
        loop {
            if let Poll::Pending = $self.poll_send_ready($cx)? {
                break Poll::Pending;
            }

            let attempt = $self.try_io(tokio::io::Interest::WRITABLE, || {
                $crate::mmsg::sendmmsg($self.as_fd(), $bufs)
            });

            match attempt {
                // The socket was not actually writable: poll for readiness
                // again.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
                res => break Poll::Ready(res),
            }
        }
    };
}
196
#[cfg(test)]
mod tests {
    use std::io;

    use tokio::io::ReadBuf;
    use tokio::net::UnixDatagram;

    use crate::DatagramSocketRecvExt;
    use crate::DatagramSocketSendExt;

    /// Size in bytes of every datagram exchanged by these tests.
    const DATAGRAM_LEN: usize = 128;
    /// Number of buffers offered to the batched calls.
    const SLOTS: usize = 128;

    /// Sends a burst of datagrams one at a time and checks that a single
    /// `recv_many` call returns all of them, in order, for two burst sizes.
    #[tokio::test]
    async fn recvmmsg() -> io::Result<()> {
        let (tx, mut rx) = UnixDatagram::pair()?;
        let mut storage = [[0u8; DATAGRAM_LEN]; SLOTS];

        for batch in [5u8, 92] {
            // Every datagram is filled with its own sequence number.
            for fill in 0..batch {
                tx.send(&[fill; DATAGRAM_LEN]).await?;
            }

            let mut read_bufs: Vec<_> = storage
                .iter_mut()
                .map(|slot| ReadBuf::new(&mut slot[..]))
                .collect();
            let received = rx.recv_many(&mut read_bufs).await?;
            assert_eq!(received, usize::from(batch));

            for (idx, read_buf) in
                read_bufs.iter().take(usize::from(batch)).enumerate()
            {
                assert_eq!(read_buf.filled(), &[idx as u8; DATAGRAM_LEN]);
            }
        }

        Ok(())
    }

    /// Sends five datagrams with one `send_many` call and checks that they
    /// arrive in order with the expected contents.
    #[tokio::test]
    async fn sendmmsg() -> io::Result<()> {
        let (tx, rx) = UnixDatagram::pair()?;
        let mut payloads: [[u8; DATAGRAM_LEN]; SLOTS] =
            std::array::from_fn(|idx| [idx as u8; DATAGRAM_LEN]);

        let mut write_bufs = Vec::with_capacity(SLOTS);
        for payload in payloads.iter_mut() {
            let mut write_buf = ReadBuf::new(&mut payload[..]);
            // Mark the whole payload as filled so that all of it is sent.
            write_buf.set_filled(DATAGRAM_LEN);
            write_bufs.push(write_buf);
        }

        let sent = tx.send_many(&write_bufs[..5]).await?;
        assert_eq!(sent, 5);

        let mut incoming = [0u8; DATAGRAM_LEN];

        for expected in 0..5u8 {
            let len = rx.recv(&mut incoming).await?;
            assert_eq!(len, DATAGRAM_LEN);
            assert_eq!(incoming, [expected; DATAGRAM_LEN]);
        }

        Ok(())
    }
}