From 7f69b97643ed9627f65a79dd6843eea192206f45 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 15 Aug 2022 06:17:27 -0700 Subject: [PATCH 01/34] add first pass implementation of windows uds --- src/sys/windows/mod.rs | 14 + src/sys/windows/uds/listener.rs | 62 +++ src/sys/windows/uds/mod.rs | 91 ++++ src/sys/windows/uds/stdnet/ext.rs | 684 +++++++++++++++++++++++++++ src/sys/windows/uds/stdnet/mod.rs | 134 ++++++ src/sys/windows/uds/stdnet/net.rs | 480 +++++++++++++++++++ src/sys/windows/uds/stdnet/socket.rs | 287 +++++++++++ src/sys/windows/uds/stream.rs | 46 ++ 8 files changed, 1798 insertions(+) create mode 100644 src/sys/windows/uds/listener.rs create mode 100644 src/sys/windows/uds/mod.rs create mode 100644 src/sys/windows/uds/stdnet/ext.rs create mode 100644 src/sys/windows/uds/stdnet/mod.rs create mode 100644 src/sys/windows/uds/stdnet/net.rs create mode 100644 src/sys/windows/uds/stdnet/socket.rs create mode 100644 src/sys/windows/uds/stream.rs diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index f8b72fc49..a243c7198 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -31,6 +31,20 @@ cfg_net! { }}; } + macro_rules! wsa_syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { $fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + use windows_sys::Win32::Networking::WinSock::WSAGetLastError; + Err(io::Error::from_raw_os_error(unsafe { + WSAGetLastError() + })) + } else { + Ok(res) + } + }}; + } + mod net; pub(crate) mod tcp; diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs new file mode 100644 index 000000000..f93f584a8 --- /dev/null +++ b/src/sys/windows/uds/listener.rs @@ -0,0 +1,62 @@ +use std::{io, mem}; +use std::os::windows::io::{AsRawSocket, FromRawSocket}; +use std::path::Path; +use windows_sys::Win32::Networking::WinSock; + +use super::{stdnet as net, socket_addr}; +use crate::net::{SocketAddr, UnixStream}; +use crate::sys::windows::net::{init, new_socket}; + +pub(crate) fn bind(path: &Path) -> io::Result { + init(); + let socket = new_socket(WinSock::AF_UNIX, WinSock::SOCK_STREAM)?; + let (sockaddr, socklen) = socket_addr(path)?; + let sockaddr = &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR; + + wsa_syscall!(bind(socket, sockaddr, socklen as _), PartialEq::eq, SOCKET_ERROR) + .and_then(|_| wsa_syscall!(listen(socket, 128), PartialEq::eq, SOCKET_ERROR)) + .map_err(|err| { + // Close the socket if we hit an error, ignoring the error from + // closing since we can't pass back two errors. + let _ = unsafe { WinSock::closesocket(socket) }; + err + }) + .map(|_| unsafe { net::UnixListener::from_raw_socket(socket) }) +} + +pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + + sockaddr.sun_family = WinSock::AF_UNIX; + let mut socklen = mem::size_of_val(&sockaddr) as c_int; + + let socket = self.0.accept(&mut storage as *mut _ as *mut _, &mut len)?; + + let socket = wsa_syscall!( + accept( + listener.as_raw_socket(), + &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR, + socklen as _ + ), + PartialEq::eq, + INVALID_SOCKET + )?; + + socket + .map(UnixStream::from_std) + .map(|stream| (stream, SocketAddr::from_parts(sockaddr, socklen))) +} + +pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { + super::local_addr(listener.as_raw_socket()) +} diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs new file mode 100644 index 000000000..95d5f4c67 --- /dev/null +++ b/src/sys/windows/uds/mod.rs @@ -0,0 +1,91 @@ +mod stdnet; +pub use self::stdnet::SocketAddr; + +fn path_offset(addr: &WinSock::sockaddr_un) -> usize { + // Work with an actual instance of the type since using a null pointer is UB + let base = addr as *const _ as usize; + let path = &addr.sun_path as *const _ as usize; + path - base +} + +cfg_os_poll! { + use windows_sys::Win32::Networking::WinSock; + use std::os::windows::io::RawSocket; + use std::path::Path; + use std::{io, mem}; + + pub(crate) mod listener; + pub(crate) mod stream; + + pub unsafe fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + sockaddr.sun_family = WinSock::AF_UNIX; + + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or(io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ))?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } + + if bytes.len() >= sockaddr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { + *dst = *src as c_char; + } + // null byte for pathname addresses is already there because we zeroed the + // struct + + let offset = path_offset(&sockaddr); + let mut socklen = offset + bytes.len(); + + match bytes.get(0) { + // The struct has already been zeroes so the null byte for pathname + // addresses is already there. + Some(&0) | None => {} + Some(_) => socklen += 1, + } + + Ok((sockaddr, socklen as c_int)) + } + + pub(crate) fn local_addr(socket: RawSocket) -> io::Result { + SocketAddr::new(|sockaddr, socklen| { + wsa_syscall!( + WinSock::getsockname(socket, sockaddr, socklen), + PartialEq::eq, + SOCKET_ERROR + ) + }) + } + + pub(crate) fn peer_addr(socket: RawSocket) -> io::Result { + SocketAddr::new(|sockaddr, socklen| { + wsa_syscall!( + WinSock::getpeername(socket, sockaddr, socklen), + PartialEq::eq, + SOCKET_ERROR + ) + }) + } +} diff --git a/src/sys/windows/uds/stdnet/ext.rs b/src/sys/windows/uds/stdnet/ext.rs new file mode 100644 index 000000000..d2e44e1d6 --- /dev/null +++ b/src/sys/windows/uds/stdnet/ext.rs @@ -0,0 +1,684 @@ +//! Extensions and types for Unix domain socket networking primitives. +//! +//! This module contains a number of extension traits for Windows-specific +//! functionality. + +use std::cmp; +use std::fmt; +use std::io; +use std::mem; +use std::os::windows::prelude::*; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use windows_sys::Win32::Networking::WinSock::{ + self, + SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE, SOL_SOCKET, WSABUF, + WSAGetLastError, WSAGetOverlappedResult, WSAIoctl, WSARecv, WSASend, + SOCKET, SOCKET_ERROR, WSA_IO_PENDING, setsockopt, bind +}; +use windows_sys::Win32::Foundation::BOOL; +use windows_sys::core::GUID; +use windows_sys::Win32::System::IO::OVERLAPPED; + +use super::net::{UnixListener, UnixStream}; +use super::{path_offset, SocketAddr}; + +// TODO +type DWORD = u32; +type INT = i32; +type u_long = u32; +type c_int = i32; +type PVOID = *mut c_void; +type LPINT = *mut INT; +type LPDWORD = *mut DWORD; +type LPOVERLAPPED = *mut OVERLAPPED; +type LPSOCKADDR = *mut SOCKADDR; + +/// A buffer in which an accepted socket's address will be stored +/// +/// This type is used with the `accept_overlapped` method on the +/// `UnixListenerExt` trait to provide space for the overlapped I/O operation to +/// fill in the socket addresses upon completion. +#[repr(C)] +pub struct AcceptAddrsBuf { + // For AcceptEx we've got the restriction that the addresses passed in that + // buffer need to be at least 16 bytes more than the maximum address length + // for the protocol in question, so add some extra here and there + local: SOCKADDR_STORAGE, + _pad1: [u8; 16], + remote: SOCKADDR_STORAGE, + _pad2: [u8; 16], +} + +impl fmt::Debug for AcceptAddrsBuf { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let local = unsafe { &*(&self.local as *const _ as *const WinSock::sockaddr_un) }; + let remote = unsafe { &*(&self.remote as *const _ as *const WinSock::sockaddr_un) }; + f.debug_struct("AcceptAddrsBuf") + .field("local", local) + .field("remote", remote) + .finish() + } +} + +/// The parsed return value of `AcceptAddrsBuf` +pub struct AcceptAddrs<'a> { + local: LPSOCKADDR, + local_len: c_int, + remote: LPSOCKADDR, + remote_len: c_int, + _data: &'a AcceptAddrsBuf, +} + +impl<'a> fmt::Debug for AcceptAddrs<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(&self._data, f) + } +} + +struct WsaExtension { + guid: GUID, + val: AtomicUsize, +} + +/// Additional methods for the `UnixStream` type +pub trait UnixStreamExt { + /// Execute an overlapped read I/O operation on this Unix domain socket + /// stream. + /// + /// This function will issue an overlapped I/O read (via `WSARecv`) on this + /// socket. The provided buffer will be filled in when the operation + /// completes and the given `OVERLAPPED` instance is used to track the + /// overlapped operation. + /// + /// If the operation succeeds, `Ok(Some(n))` is returned indicating how + /// many bytes were read. If the operation returns an error indicating that + /// the I/O is currently pending, `Ok(None)` is returned. Otherwise, the + /// error associated with the operation is returned and no overlapped + /// operation is enqueued. + /// + /// The number of bytes read will be returned as part of the completion + /// notification when the I/O finishes. + /// + /// # Unsafety + /// + /// This function is unsafe because the kernel requires that the `buf` and + /// `overlapped` pointers are valid until the end of the I/O operation. The + /// kernel also requires that `overlapped` is unique for this I/O operation + /// and is not in use for any other I/O. + /// + /// To safely use this function callers must ensure that these two input + /// pointers are valid until the I/O operation is completed, typically via + /// completion ports and waiting to receive the completion notification on + /// the port. + unsafe fn read_overlapped( + &self, + buf: &mut [u8], + overlapped: *mut OVERLAPPED, + ) -> io::Result>; + + /// Execute an overlapped write I/O operation on this Unix domain socket + /// stream. + /// + /// This function will issue an overlapped I/O write (via `WSASend`) on this + /// socket. The provided buffer will be written when the operation completes + /// and the given `OVERLAPPED` instance is used to track the overlapped + /// operation. + /// + /// If the operation succeeds, `Ok(Some(n))` is returned where `n` is the + /// number of bytes that were written. If the operation returns an error + /// indicating that the I/O is currently pending, `Ok(None)` is returned. + /// Otherwise, the error associated with the operation is returned and no + /// overlapped operation is enqueued. + /// + /// The number of bytes written will be returned as part of the completion + /// notification when the I/O finishes. + /// + /// # Unsafety + /// + /// This function is unsafe because the kernel requires that the `buf` and + /// `overlapped` pointers are valid until the end of the I/O operation. The + /// kernel also requires that `overlapped` is unique for this I/O operation + /// and is not in use for any other I/O. + /// + /// To safely use this function callers must ensure that these two input + /// pointers are valid until the I/O operation is completed, typically via + /// completion ports and waiting to receive the completion notification on + /// the port. + unsafe fn write_overlapped( + &self, + buf: &[u8], + overlapped: *mut OVERLAPPED, + ) -> io::Result>; + + /// Attempt to consume the internal socket in this builder by executing an + /// overlapped connect operation. + /// + /// This function will issue a connect operation to the address specified on + /// the underlying socket, flagging it as an overlapped operation which will + /// complete asynchronously. If successful this function will return the + /// corresponding Unix domain socket stream. + /// + /// The `buf` argument provided is an initial buffer of data that should be + /// sent after the connection is initiated. It's acceptable to + /// pass an empty slice here. + /// + /// This function will also return whether the connect immediately + /// succeeded or not. If `Ok(None)` is returned then the I/O operation is + /// still pending and will complete later. If `Ok(Some(bytes))` is returned + /// then that many bytes were transferred. + /// + /// Note that to succeed this requires that the underlying socket has + /// previously been bound via a call to `bind` to a local path. + /// + /// # Unsafety + /// + /// This function is unsafe because the kernel requires that the + /// `overlapped` and `buf` pointers to be valid until the end of the I/O + /// operation. The kernel also requires that `overlapped` is unique for + /// this I/O operation and is not in use for any other I/O. + /// + /// To safely use this function callers must ensure that this pointer is + /// valid until the I/O operation is completed, typically via completion + /// ports and waiting to receive the completion notification on the port. + unsafe fn connect_overlapped( + &self, + addr: &SocketAddr, + buf: &[u8], + overlapped: *mut OVERLAPPED, + ) -> io::Result>; + + /// Once a `connect_overlapped` has finished, this function needs to be + /// called to finish the connect operation. + /// + /// Currently this just calls `setsockopt` with `SO_UPDATE_CONNECT_CONTEXT` + /// to ensure that further functions like `getpeername` and `getsockname` + /// work correctly. + fn connect_complete(&self) -> io::Result<()>; + + /// Calls the `GetOverlappedResult` function to get the result of an + /// overlapped operation for this handle. + /// + /// This function takes the `OVERLAPPED` argument which must have been used + /// to initiate an overlapped I/O operation, and returns either the + /// successful number of bytes transferred during the operation or an error + /// if one occurred, along with the results of the `lpFlags` parameter of + /// the relevant operation, if applicable. + /// + /// # Unsafety + /// + /// This function is unsafe as `overlapped` must have previously been used + /// to execute an operation for this handle, and it must also be a valid + /// pointer to an `OVERLAPPED` instance. + /// + /// # Panics + /// + /// This function will panic + unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)>; +} + +/// Additional methods for the `UnixListener` type +pub trait UnixListenerExt { + /// Perform an accept operation on this listener, accepting a connection in + /// an overlapped fashion. + /// + /// This function will issue an I/O request to accept an incoming connection + /// with the specified overlapped instance. The `socket` provided must be + /// configured but not bound or connected. If successful this method will + /// consume the socket to return a Unix stream. + /// + /// The `addrs` buffer provided will be filled in with the local and remote + /// addresses of the connection upon completion. + /// + /// If the accept succeeds immediately, `Ok(true)` is returned. If the + /// connect indicates that the I/O is currently pending, `Ok(false)` is + /// returned. Otherwise, the error associated with the operation is returned + /// and no overlapped operation is enqueued. + /// + /// # Unsafety + /// + /// This function is unsafe because the kernel requires that the + /// `addrs` and `overlapped` pointers are valid until the end of the I/O + /// operation. The kernel also requires that `overlapped` is unique for this + /// I/O operation and is not in use for any other I/O. + /// + /// To safely use this function callers must ensure that the pointers are + /// valid until the I/O operation is completed, typically via completion + /// ports and waiting to receive the completion notification on the port. + unsafe fn accept_overlapped( + &self, + socket: &UnixStream, + addrs: &mut AcceptAddrsBuf, + overlapped: *mut OVERLAPPED, + ) -> io::Result; + + /// Once an `accept_overlapped` has finished, this function needs to be + /// called to finish the accept operation. + /// + /// Currently this just calls `setsockopt` with `SO_UPDATE_ACCEPT_CONTEXT` + /// to ensure that further functions like `getpeername` and `getsockname` + /// work correctly. + fn accept_complete(&self, socket: &UnixStream) -> io::Result<()>; + + /// Calls the `GetOverlappedResult` function to get the result of an + /// overlapped operation for this handle. + /// + /// This function takes the `OVERLAPPED` argument which must have been used + /// to initiate an overlapped I/O operation, and returns either the + /// successful number of bytes transferred during the operation or an error + /// if one occurred, along with the results of the `lpFlags` parameter of + /// the relevant operation, if applicable. + /// + /// # Unsafety + /// + /// This function is unsafe as `overlapped` must have previously been used + /// to execute an operation for this handle, and it must also be a valid + /// pointer to an `OVERLAPPED` instance. + /// + /// # Panics + /// + /// This function will panic + unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)>; +} + +fn last_err() -> io::Result> { + let err = unsafe { WSAGetLastError() }; + if err == WSA_IO_PENDING as i32 { + Ok(None) + } else { + Err(io::Error::from_raw_os_error(err)) + } +} + +fn cvt(i: c_int, size: DWORD) -> io::Result> { + if i == SOCKET_ERROR { + last_err() + } else { + Ok(Some(size as usize)) + } +} + +fn socket_addr_to_ptrs(addr: &SocketAddr) -> (*const SOCKADDR, c_int) { + ( + &addr.addr as *const _ as *const _, + mem::size_of::() as c_int, + ) +} + +unsafe fn ptrs_to_socket_addr(ptr: *const SOCKADDR, len: c_int) -> Option { + if (len as usize) < mem::size_of::() { + return None; + } + match (*ptr).sa_family { + WinSock::AF_UNIX if len as usize >= mem::size_of::() => { + let b = &*(ptr as *const WinSock::sockaddr_un); + match b.sun_path.iter().position(|c| *c == 0) { + Some(0) => Some(SocketAddr::from_parts(b.clone(), len)), + Some(i) => { + let mut l = path_offset(b) + i; + match b.sun_path.get(0) { + Some(&0) | None => {} + Some(_) => l += 1, + } + Some(SocketAddr::from_parts(b.clone(), l as c_int)) + } + _ => None, // Invalid socket path, no terminating null byte + } + } + _ => None, // Invalid socket family, should be AF_UNIX + } +} + +unsafe fn slice2buf(slice: &[u8]) -> WSABUF { + WSABUF { + len: cmp::min(slice.len(), ::max_value() as usize) as u_long, + buf: slice.as_ptr() as *mut _, + } +} + +unsafe fn result(socket: SOCKET, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> { + let mut transferred = 0; + let mut flags = 0; + let r = WSAGetOverlappedResult(socket, overlapped, &mut transferred, 0, &mut flags); + if r == 0 { + Err(io::Error::last_os_error()) + } else { + Ok((transferred as usize, flags)) + } +} + +impl UnixStreamExt for UnixStream { + unsafe fn read_overlapped( + &self, + buf: &mut [u8], + overlapped: *mut OVERLAPPED, + ) -> io::Result> { + let mut buf = slice2buf(buf); + let mut flags = 0; + let mut bytes_read: DWORD = 0; + let r = WSARecv( + self.as_raw_socket() as SOCKET, + &mut buf, + 1, + &mut bytes_read, + &mut flags, + overlapped, + None, + ); + cvt(r, bytes_read) + } + + unsafe fn write_overlapped( + &self, + buf: &[u8], + overlapped: *mut OVERLAPPED, + ) -> io::Result> { + let mut buf = slice2buf(buf); + let mut bytes_written = 0; + + // Note here that we capture the number of bytes written. The + // documentation on MSDN, however, states: + // + // > Use NULL for this parameter if the lpOverlapped parameter is not + // > NULL to avoid potentially erroneous results. This parameter can be + // > NULL only if the lpOverlapped parameter is not NULL. + // + // If we're not passing a null overlapped pointer here, then why are we + // then capturing the number of bytes! Well so it turns out that this is + // clearly faster to learn the bytes here rather than later calling + // `WSAGetOverlappedResult`, and in practice almost all implementations + // use this anyway [1]. + // + // As a result we use this to and report back the result. + // + // [1]: https://github.com/carllerche/mio/pull/520#issuecomment-273983823 + let r = WSASend( + self.as_raw_socket() as SOCKET, + &mut buf, + 1, + &mut bytes_written, + 0, + overlapped, + None, + ); + cvt(r, bytes_written) + } + + unsafe fn connect_overlapped( + &self, + addr: &SocketAddr, + buf: &[u8], + overlapped: *mut OVERLAPPED, + ) -> io::Result> { + connect_overlapped(self.as_raw_socket() as SOCKET, addr, buf, overlapped) + } + + fn connect_complete(&self) -> io::Result<()> { + const SO_UPDATE_CONNECT_CONTEXT: c_int = 0x7010; + let result = unsafe { + setsockopt( + self.as_raw_socket() as SOCKET, + SOL_SOCKET, + SO_UPDATE_CONNECT_CONTEXT, + 0 as *const _, + 0, + ) + }; + if result == 0 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } + } + + unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> { + result(self.as_raw_socket() as SOCKET, overlapped) + } +} + +unsafe fn connect_overlapped( + socket: SOCKET, + addr: &SocketAddr, + buf: &[u8], + overlapped: *mut OVERLAPPED, +) -> io::Result> { + let anonaddr = WinSock::sockaddr_un { + sun_family: WinSock::AF_UNIX, + sun_path: [0; 108], + }; + let len = mem::size_of::() as c_int; + wsa_syscall!( + bind(socket as _, &anonaddr as *const _ as *const _, len as _), + PartialEq::eq, + SOCKET_ERROR + )?; + + static CONNECTEX: WsaExtension = WsaExtension { + guid: GUID { + Data1: 0x25a207b9, + Data2: 0xddf3, + Data3: 0x4660, + Data4: [0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e], + }, + val: AtomicUsize::new(0), + }; + type ConnectEx = unsafe extern "system" fn( + SOCKET, + *const SOCKADDR, + c_int, + PVOID, + DWORD, + LPDWORD, + LPOVERLAPPED, + ) -> BOOL; + + let ptr = CONNECTEX.get(socket)?; + assert!(ptr != 0); + let connect_ex = mem::transmute::<_, ConnectEx>(ptr); + + let (addr_buf, addr_len) = socket_addr_to_ptrs(addr); + let mut bytes_sent: DWORD = 0; + let r = connect_ex( + socket, + addr_buf, + addr_len, + buf.as_ptr() as *mut _, + buf.len() as u32, + &mut bytes_sent, + overlapped, + ); + if r == 1 { + Ok(Some(bytes_sent as usize)) + } else { + last_err() + } +} + +impl UnixListenerExt for UnixListener { + unsafe fn accept_overlapped( + &self, + socket: &UnixStream, + addrs: &mut AcceptAddrsBuf, + overlapped: *mut OVERLAPPED, + ) -> io::Result { + static ACCEPTEX: WsaExtension = WsaExtension { + guid: GUID { + Data1: 0xb5367df1, + Data2: 0xcbac, + Data3: 0x11cf, + Data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], + }, + val: AtomicUsize::new(0), + }; + type AcceptEx = unsafe extern "system" fn( + SOCKET, + SOCKET, + PVOID, + DWORD, + DWORD, + DWORD, + LPDWORD, + LPOVERLAPPED, + ) -> BOOL; + + let ptr = ACCEPTEX.get(self.as_raw_socket() as SOCKET)?; + assert!(ptr != 0); + let accept_ex = mem::transmute::<_, AcceptEx>(ptr); + + let mut bytes = 0; + let (a, b, c, d) = (*addrs).args(); + let r = accept_ex( + self.as_raw_socket() as SOCKET, + socket.as_raw_socket() as SOCKET, + a, + b, + c, + d, + &mut bytes, + overlapped, + ); + let succeeded = if r == 1 { + true + } else { + last_err()?; + false + }; + Ok(succeeded) + } + + fn accept_complete(&self, socket: &UnixStream) -> io::Result<()> { + const SO_UPDATE_ACCEPT_CONTEXT: c_int = 0x700B; + let me = self.as_raw_socket(); + let result = unsafe { + setsockopt( + socket.as_raw_socket() as SOCKET, + SOL_SOCKET, + SO_UPDATE_ACCEPT_CONTEXT, + &me as *const _ as *const _, + mem::size_of_val(&me) as c_int, + ) + }; + if result == 0 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } + } + + unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> { + result(self.as_raw_socket() as SOCKET, overlapped) + } +} + +static GETACCEPTEXSOCKADDRS: WsaExtension = WsaExtension { + guid: GUID { + Data1: 0xb5367df2, + Data2: 0xcbac, + Data3: 0x11cf, + Data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], + }, + val: AtomicUsize::new(0), +}; +type GetAcceptExSockaddrs = unsafe extern "system" fn( + PVOID, + DWORD, + DWORD, + DWORD, + *mut LPSOCKADDR, + LPINT, + *mut LPSOCKADDR, + LPINT, +); + +impl AcceptAddrsBuf { + /// Creates a new blank buffer ready to be passed to a call to + /// `accept_overlapped`. + pub fn new() -> AcceptAddrsBuf { + unsafe { mem::zeroed() } + } + + /// Parses the data contained in this address buffer, returning the parsed + /// result if successful. + /// + /// This function can be called after a call to `accept_overlapped` has + /// succeeded to parse out the data that was written in. + pub fn parse(&self, socket: &UnixListener) -> io::Result { + let mut ret = AcceptAddrs { + local: 0 as *mut _, + local_len: 0, + remote: 0 as *mut _, + remote_len: 0, + _data: self, + }; + let ptr = GETACCEPTEXSOCKADDRS.get(socket.as_raw_socket() as SOCKET)?; + assert!(ptr != 0); + unsafe { + let get_sockaddrs = mem::transmute::<_, GetAcceptExSockaddrs>(ptr); + let (a, b, c, d) = self.args(); + get_sockaddrs( + a, + b, + c, + d, + &mut ret.local, + &mut ret.local_len, + &mut ret.remote, + &mut ret.remote_len, + ); + Ok(ret) + } + } + + fn args(&self) -> (PVOID, DWORD, DWORD, DWORD) { + let remote_offset = unsafe { &(*(0 as *const AcceptAddrsBuf)).remote as *const _ as usize }; + ( + self as *const _ as *mut _, + 0, + remote_offset as DWORD, + (mem::size_of_val(self) - remote_offset) as DWORD, + ) + } +} + +impl<'a> AcceptAddrs<'a> { + /// Returns the local socket address contained in this buffer. + #[allow(dead_code)] + pub fn local(&self) -> Option { + unsafe { ptrs_to_socket_addr(self.local, self.local_len) } + } + + /// Returns the remote socket address contained in this buffer. + pub fn remote(&self) -> Option { + unsafe { ptrs_to_socket_addr(self.remote, self.remote_len) } + } +} + +impl WsaExtension { + fn get(&self, socket: SOCKET) -> io::Result { + let prev = self.val.load(Ordering::SeqCst); + if prev != 0 && !cfg!(debug_assertions) { + return Ok(prev); + } + let mut ret = 0 as usize; + let mut bytes = 0; + let r = unsafe { + WSAIoctl( + socket, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &self.guid as *const _ as *mut _, + mem::size_of_val(&self.guid) as DWORD, + &mut ret as *mut _ as *mut _, + mem::size_of_val(&ret) as DWORD, + &mut bytes, + 0 as *mut _, + None, + ) + }; + cvt(r, 0).map(|_| { + debug_assert_eq!(bytes as usize, mem::size_of_val(&ret)); + debug_assert!(prev == 0 || prev == ret); + self.val.store(ret, Ordering::SeqCst); + ret + }) + } +} diff --git a/src/sys/windows/uds/stdnet/mod.rs b/src/sys/windows/uds/stdnet/mod.rs new file mode 100644 index 000000000..5a0a33047 --- /dev/null +++ b/src/sys/windows/uds/stdnet/mod.rs @@ -0,0 +1,134 @@ +use std::ascii; +use std::fmt; +use std::io; +use std::mem; +use std::os::raw::{c_char, c_int}; +use std::path::Path; + +use windows_sys::Win32::Networking::WinSock::{ + self, + SOCKADDR, + SOCKET_ERROR, + WSAGetLastError +}; + +mod ext; +mod net; +mod socket; + +enum AddressKind<'a> { + Unnamed, + Pathname(&'a Path), + Abstract(&'a [u8]), +} + +/// An address associated with a Unix socket +#[derive(Copy, Clone)] +pub struct SocketAddr { + addr: WinSock::sockaddr_un, + len: c_int, +} + +impl SocketAddr { + fn new(f: F) -> io::Result + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> c_int, + { + let mut sockaddr = { + let sockaddr = mem::MaybeUninit::::zeroed(); + unsafe { sockaddr.assume_init() } + }; + + let mut len = mem::size_of::() as c_int; + wsa_syscall!( + f(&mut sockaddr as *mut _ as *mut _, &mut len), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(SocketAddr::from_parts(sockaddr, len)) + } + + fn from_parts(addr: WinSock::sockaddr_un, mut len: c_int) -> SocketAddr { + if len == 0 { + // When there is a datagram from unnamed unix socket + // linux returns zero bytes of address + len = path_offset(&addr) as c_int; // i.e. zero-length address + } + SocketAddr { addr, len } + } + + /// Returns true if and only if the address is unnamed. + pub fn is_unnamed(&self) -> bool { + if let AddressKind::Unnamed = self.address() { + true + } else { + false + } + } + + /// Returns the contents of this address if it is a `pathname` address. + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } + } + + fn address<'a>(&'a self) -> AddressKind<'a> { + let len = self.len as usize - path_offset(&self.addr); + // WinSock::sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path + let path = unsafe { mem::transmute::<&[c_char], &[u8]>(&self.addr.sun_path) }; + + // macOS seems to return a len of 16 and a zeroed sun_path for unnamed addresses + if len == 0 + || (cfg!(not(any(target_os = "linux", target_os = "android"))) + && self.addr.sun_path[0] == 0) + { + AddressKind::Unnamed + } else if self.addr.sun_path[0] == 0 { + AddressKind::Abstract(&path[1..len]) + } else { + use std::ffi::CStr; + let pathname = unsafe { CStr::from_bytes_with_nul_unchecked(&path[..len]) }; + AddressKind::Pathname(Path::new(pathname.to_str().unwrap())) + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self.address() { + AddressKind::Unnamed => write!(fmt, "(unnamed)"), + AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), + AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), + } + } +} + +impl PartialEq for SocketAddr { + fn eq(&self, other: &SocketAddr) -> bool { + let ita = self.addr.sun_path.iter(); + let itb = other.addr.sun_path.iter(); + + self.len == other.len + && self.addr.sun_family == other.addr.sun_family + && ita.zip(itb).all(|(a, b)| a == b) + } +} + +struct AsciiEscaped<'a>(&'a [u8]); + +impl<'a> fmt::Display for AsciiEscaped<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "\"")?; + for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { + write!(fmt, "{}", byte as char)?; + } + write!(fmt, "\"") + } +} + +pub use self::ext::{AcceptAddrs, AcceptAddrsBuf, UnixListenerExt, UnixStreamExt}; +pub use self::net::{UnixListener, UnixStream}; +pub use self::socket::Socket; diff --git a/src/sys/windows/uds/stdnet/net.rs b/src/sys/windows/uds/stdnet/net.rs new file mode 100644 index 000000000..ca75c46e2 --- /dev/null +++ b/src/sys/windows/uds/stdnet/net.rs @@ -0,0 +1,480 @@ +use std::fmt; +use std::io; +use std::mem; +use std::net::Shutdown; +use std::os::raw::c_int; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::path::Path; +use std::time::Duration; + +use windows_sys::Win32::Networking::WinSock::{ + self, + bind, connect, getpeername, getsockname, listen, SO_RCVTIMEO, SO_SNDTIMEO, +}; + +use crate::sys::windows::net::init; +use super::socket::Socket; +use super::{socket_addr, SocketAddr}; + +/// A Unix stream socket +pub struct UnixStream(Socket); + +impl fmt::Debug for UnixStream { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixStream"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} + +impl UnixStream { + /// Connects to the socket named by `path`. + pub fn connect>(path: P) -> io::Result { + init(); + fn inner(path: &Path) -> io::Result { + unsafe { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path)?; + + wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(UnixStream(inner)) + } + } + inner(path.as_ref()) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixStream` is a reference to the same stream that this + /// object references. Both handles will read and write the same stream of + /// data, and options set on one stream will be propagated to the other + /// stream. + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixStream) + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw_socket() as _, addr, len) }) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getpeername(self.0.as_raw_socket() as _, addr, len) }) + } + + /// Moves the socket into or out of nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation for `Shutdown`). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } + + pub fn pair() -> io::Result<(Self, Self)> { + use std::sync::{Arc, RwLock}; + use std::thread::spawn; + + let dir = tempfile::tempdir()?; + let file_path = dir.path().join("socket"); + let a: Arc>>> = Arc::new(RwLock::new(None)); + let ul = UnixListener::bind(&file_path).unwrap(); + let server = { + let a = a.clone(); + spawn(move || { + let mut store = a.write().unwrap(); + let stream0 = ul.accept().map(|s| s.0); + *store = Some(stream0); + }) + }; + let stream1 = UnixStream::connect(&file_path)?; + server + .join() + .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; + let stream0 = (*(a.write().unwrap())).take().unwrap()?; + return Ok((stream0, stream1)); + } + + /// Sets the read timeout to the timeout specified. + /// + /// If the value specified is `None`, then `read` calls will block + /// indefinitely. An `Err` is returned if the zero `Duration` is + /// passed to this method. + pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_RCVTIMEO) + } + + /// Sets the write timeout to the timeout specified. + /// + /// If the value specified is `None`, then `write` calls will block + /// indefinitely. An `Err` is returned if the zero `Duration` is + /// passed to this method. + pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_SNDTIMEO) + } + + /// Returns the read timeout of this socket. + pub fn read_timeout(&self) -> io::Result> { + self.0.timeout(SO_RCVTIMEO) + } + + /// Returns the write timeout of this socket. + pub fn write_timeout(&self) -> io::Result> { + self.0.timeout(SO_SNDTIMEO) + } +} + +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } +} + +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} + +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} + +/// A Unix domain socket server +pub struct UnixListener(Socket); + +impl fmt::Debug for UnixListener { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixListener"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + builder.finish() + } +} + +impl UnixListener { + /// Creates a new `UnixListener` bound to the specified socket. + pub fn bind>(path: P) -> io::Result { + init(); + fn inner(path: &Path) -> io::Result { + unsafe { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path)?; + + wsa_syscall!( + bind( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as _, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + wsa_syscall!( + listen(inner.as_raw_socket() as _, 128), + PartialEq::eq, + SOCKET_ERROR + )?; + + Ok(UnixListener(inner)) + } + } + inner(path.as_ref()) + } + + /// Accepts a new incoming connection to this listener. + /// + /// This function will block the calling thread until a new Unix connection + /// is established. When established, the corresponding [`UnixStream`] and + /// the remote peer's address will be returned. + /// + /// [`UnixStream`]: struct.UnixStream.html + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let mut storage: WinSock::sockaddr_un = unsafe { mem::zeroed() }; + let mut len = mem::size_of_val(&storage) as c_int; + let sock = self.0.accept(&mut storage as *mut _ as *mut _, &mut len)?; + let addr = SocketAddr::from_parts(storage, len); + Ok((UnixStream(sock), addr)) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixListener` is a reference to the same socket that this + /// object references. Both handles can be used to accept incoming + /// connections and options set on one listener will affect the other. + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixListener) + } + + /// Returns the local socket address of this listener. + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw_socket() as _, addr, len) }) + } + + /// Moves the socket into or out of nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Returns an iterator over incoming connections. + /// + /// The iterator will never return `None` and will also not yield the + /// peer's [`SocketAddr`] structure. + /// + /// [`SocketAddr`]: struct.SocketAddr.html + pub fn incoming<'a>(&'a self) -> Incoming<'a> { + Incoming { listener: self } + } +} + +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} + +impl<'a> IntoIterator for &'a UnixListener { + type Item = io::Result; + type IntoIter = Incoming<'a>; + + fn into_iter(self) -> Incoming<'a> { + self.incoming() + } +} + +/// An iterator over incoming connections to a [`UnixListener`]. +/// +/// It will never return `None`. +/// +/// [`UnixListener`]: struct.UnixListener.html +#[derive(Debug)] +pub struct Incoming<'a> { + listener: &'a UnixListener, +} + +impl<'a> Iterator for Incoming<'a> { + type Item = io::Result; + + fn next(&mut self) -> Option> { + Some(self.listener.accept().map(|s| s.0)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), None) + } +} + +#[cfg(test)] +mod test { + extern crate tempfile; + + use std::io::{self, Read, Write}; + use std::path::PathBuf; + use std::thread; + + use self::tempfile::TempDir; + + use super::*; + + macro_rules! or_panic { + ($e:expr) => { + match $e { + Ok(e) => e, + Err(e) => panic!("{}", e), + } + }; + } + + fn tmpdir() -> Result<(TempDir, PathBuf), io::Error> { + let dir = tempfile::tempdir()?; + let path = dir.path().join("sock"); + Ok((dir, path)) + } + + #[test] + fn basic() { + let (_dir, socket_path) = or_panic!(tmpdir()); + let msg1 = b"hello"; + let msg2 = b"world!"; + + let listener = or_panic!(UnixListener::bind(&socket_path)); + let thread = thread::spawn(move || { + let mut stream = or_panic!(listener.accept()).0; + let mut buf = [0; 5]; + or_panic!(stream.read(&mut buf)); + assert_eq!(&msg1[..], &buf[..]); + or_panic!(stream.write_all(msg2)); + }); + + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + assert_eq!( + Some(&*socket_path), + stream.peer_addr().unwrap().as_pathname() + ); + or_panic!(stream.write_all(msg1)); + let mut buf = vec![]; + or_panic!(stream.read_to_end(&mut buf)); + assert_eq!(&msg2[..], &buf[..]); + drop(stream); + + thread.join().unwrap(); + } + + #[test] + fn try_clone() { + let (_dir, socket_path) = or_panic!(tmpdir()); + let msg1 = b"hello"; + let msg2 = b"world"; + + let listener = or_panic!(UnixListener::bind(&socket_path)); + let thread = thread::spawn(move || { + #[allow(unused_mut)] + let mut stream = or_panic!(listener.accept()).0; + or_panic!(stream.write_all(msg1)); + or_panic!(stream.write_all(msg2)); + }); + + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + let mut stream2 = or_panic!(stream.try_clone()); + assert_eq!( + Some(&*socket_path), + stream2.peer_addr().unwrap().as_pathname() + ); + + let mut buf = [0; 5]; + or_panic!(stream.read(&mut buf)); + assert_eq!(&msg1[..], &buf[..]); + or_panic!(stream2.read(&mut buf)); + assert_eq!(&msg2[..], &buf[..]); + + thread.join().unwrap(); + } + + #[test] + fn iter() { + let (_dir, socket_path) = or_panic!(tmpdir()); + + let listener = or_panic!(UnixListener::bind(&socket_path)); + let thread = thread::spawn(move || { + for stream in listener.incoming().take(2) { + let mut stream = or_panic!(stream); + let mut buf = [0]; + or_panic!(stream.read(&mut buf)); + } + }); + + for _ in 0..2 { + let mut stream = or_panic!(UnixStream::connect(&socket_path)); + or_panic!(stream.write_all(&[0])); + } + + thread.join().unwrap(); + } + + #[test] + fn long_path() { + let dir = or_panic!(tempfile::tempdir()); + let socket_path = dir.path().join( + "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfa\ + sasdfasdfasdasdfasdfasdfadfasdfasdfasdfasdfasdf", + ); + match UnixStream::connect(&socket_path) { + Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} + Err(e) => panic!("unexpected error {}", e), + Ok(_) => panic!("unexpected success"), + } + + match UnixListener::bind(&socket_path) { + Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} + Err(e) => panic!("unexpected error {}", e), + Ok(_) => panic!("unexpected success"), + } + } + + #[test] + fn abstract_namespace_not_allowed() { + assert!(UnixStream::connect("\0asdf").is_err()); + } +} diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/uds/stdnet/socket.rs new file mode 100644 index 000000000..8d809d01f --- /dev/null +++ b/src/sys/windows/uds/stdnet/socket.rs @@ -0,0 +1,287 @@ +#![allow(non_camel_case_types)] + +use std::io; +use std::mem; +use std::net::Shutdown; +use std::os::raw::{c_int, c_ulong}; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::ptr; +use std::sync::Once; +use std::time::Duration; + +use windows_sys::Win32::Foundation::{ + HANDLE, + SetHandleInformation, + HANDLE_FLAG_INHERIT +}; +use windows_sys::Win32::System::Threading::GetCurrentProcessId; +use windows_sys::Win32::System::WindowsProgramming::INFINITE; +use windows_sys::Win32::Networking::WinSock::{ + self, + SOCKET_ERROR, + AF_UNIX, + SOCKADDR, + SOCK_STREAM, + SOL_SOCKET, + SO_ERROR, + accept, closesocket, ioctlsocket, recv, send, + setsockopt, shutdown, WSADuplicateSocketW, WSASocketW, FIONBIO, + INVALID_SOCKET, SOCKET, WSADATA, WSAPROTOCOL_INFOW, + WSA_FLAG_OVERLAPPED, + SD_RECEIVE, + SD_SEND, + SD_BOTH +}; + +// TODO +type socklen_t = i32; +type DWORD = u32; + +#[derive(Debug)] +pub struct Socket(SOCKET); + +impl Socket { + pub fn new() -> io::Result { + let socket = wsa_syscall!( + WSASocketW( + AF_UNIX, + SOCK_STREAM, + 0, + ptr::null_mut(), + 0, + WSA_FLAG_OVERLAPPED, + ) + PartialEq::eq, + INVALID_SOCKET + )?; + socket.set_no_inherit()?; + Ok(socket) + } + + pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { + let socket = wsa_syscall!( + accept(self.0, storage, len), + PartialEq::eq, + INVALID_SOCKET + )?; + socket.set_no_inherit()?; + Ok(socket) + } + + pub fn duplicate(&self) -> io::Result { + let socket = unsafe { + let mut info: WSAPROTOCOL_INFOW = mem::zeroed(); + wsa_syscall!( + WSADuplicateSocketW( + self.0, + GetCurrentProcessId(), + &mut info, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + let n = wsa_syscall!( + WSASocketW( + info.iAddressFamily, + info.iSocketType, + info.iProtocol, + &mut info, + 0, + WSA_FLAG_OVERLAPPED, + ) + PartialEq::eq, + INVALID_SOCKET + )?; + Socket(n) + }; + socket.set_no_inherit()?; + Ok(socket) + } + + fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result { + let ret = wsa_syscall!( + recv( + self.0, + buf.as_mut_ptr() as *mut _, + buf.len() as c_int, + flags, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(ret as usize) + } + + pub fn read(&self, buf: &mut [u8]) -> io::Result { + self.recv_with_flags(buf, 0) + } + + pub fn write(&self, buf: &[u8]) -> io::Result { + let ret = wsa_syscall!( + send(self.0, buf as *const _ as *const _, buf.len() as c_int, 0), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(ret as usize) + } + + fn set_no_inherit(&self) -> io::Result<()> { + syscall!( + SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), + PartialEq::eq, + 0 + ) + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = nonblocking as c_ulong; + wsa_syscall!( + ioctlsocket(self.0, FIONBIO as c_int, &mut nonblocking), + PartialEq::eq, + SOCKET_ERROR + ) + } + + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + let how = match how { + Shutdown::Write => SD_SEND, + Shutdown::Read => SD_RECEIVE, + Shutdown::Both => SD_BOTH, + }; + wsa_syscall!( + shutdown(self.0, how), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(()) + } + + pub fn take_error(&self) -> io::Result> { + let raw: c_int = getsockopt(self, SOL_SOCKET, SO_ERROR)?; + if raw == 0 { + Ok(None) + } else { + Ok(Some(io::Error::from_raw_os_error(raw as i32))) + } + } + + pub fn set_timeout(&self, dur: Option, kind: c_int) -> io::Result<()> { + let timeout = match dur { + Some(dur) => { + let timeout = dur2timeout(dur); + if timeout == 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "cannot set a 0 duration timeout", + )); + } + timeout + } + None => 0, + }; + setsockopt(self, SOL_SOCKET, kind, timeout) + } + + pub fn timeout(&self, kind: c_int) -> io::Result> { + let raw: DWORD = getsockopt(self, SOL_SOCKET, kind)?; + if raw == 0 { + Ok(None) + } else { + let secs = raw / 1000; + let nsec = (raw % 1000) * 1000000; + Ok(Some(Duration::new(secs as u64, nsec as u32))) + } + } +} + +pub fn setsockopt(sock: &Socket, opt: c_int, val: c_int, payload: T) -> io::Result<()> { + unsafe { + let payload = &payload as *const T as *const _; + wsa_syscall!( + WinSock::setsockopt( + sock.as_raw_socket() as usize, + opt, + val, + payload, + mem::size_of::() as socklen_t, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(()) + } +} + +pub fn getsockopt(sock: &Socket, opt: c_int, val: c_int) -> io::Result { + unsafe { + let mut slot: T = mem::zeroed(); + let mut len = mem::size_of::() as socklen_t; + wsa_syscall!( + WinSock::getsockopt( + sock.as_raw_socket() as _, + opt, + val, + &mut slot as *mut _ as *mut _, + &mut len, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + assert_eq!(len as usize, mem::size_of::()); + Ok(slot) + } +} + +fn dur2timeout(dur: Duration) -> DWORD { + // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the + // timeouts in windows APIs are typically u32 milliseconds. To translate, we + // have two pieces to take care of: + // + // * Nanosecond precision is rounded up + // * Greater than u32::MAX milliseconds (50 days) is rounded up to INFINITE + // (never time out). + dur.as_secs() + .checked_mul(1000) + .and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000)) + .and_then(|ms| { + ms.checked_add(if dur.subsec_nanos() % 1_000_000 > 0 { + 1 + } else { + 0 + }) + }) + .map(|ms| { + if ms > ::max_value() as u64 { + INFINITE + } else { + ms as DWORD + } + }) + .unwrap_or(INFINITE) +} + +impl Drop for Socket { + fn drop(&mut self) { + let _ = unsafe { closesocket(self.0) }; + } +} + +impl AsRawSocket for Socket { + fn as_raw_socket(&self) -> RawSocket { + self.0 as RawSocket + } +} + +impl FromRawSocket for Socket { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + Socket(sock as SOCKET) + } +} + +impl IntoRawSocket for Socket { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0 as RawSocket; + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs new file mode 100644 index 000000000..92b241b23 --- /dev/null +++ b/src/sys/windows/uds/stream.rs @@ -0,0 +1,46 @@ +use std::io; +use std::os::windows::io::{AsRawSocket, FromRawSocket}; +use std::path::Path; +use windows_sys::Win32::Networking::WinSock; + +use super::{stdnet as net, socket_addr}; +use crate::net::SocketAddr; +use crate::sys::windows::net::{init, new_socket}; + +pub(crate) fn connect(path: &Path) -> io::Result { + init(); + let socket = new_socket(WinSock::AF_UNIX, WinSock::SOCK_STREAM)?; + let (sockaddr, socklen) = socket_addr(path)?; + let sockaddr = &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR; + + wsa_syscall!( + connect(socket, sockaddr, socklen as _), + PartialEq::eq, + SOCKET_ERROR + )?; + match syscall!(connect(socket, sockaddr, socklen)) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(WinSock::WSAEINPROGRESS) => {} + Err(e) => { + // Close the socket if we hit an error, ignoring the error + // from closing since we can't pass back two errors. + let _ = unsafe { WinSock::closesocket(socket) }; + + return Err(e); + } + } + + Ok(unsafe { net::UnixStream::from_raw_socket(socket) }) +} + +pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { + net::UnixStream::pair() +} + +pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { + super::local_addr(socket.as_raw_socket()) +} + +pub(crate) fn peer_addr(socket: &net::UnixStream) -> io::Result { + super::peer_addr(socket.as_raw_socket()) +} From bec570b711a1c2b1d2dfa0f3ba1dcb9ac5640ea9 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 15 Aug 2022 06:18:18 -0700 Subject: [PATCH 02/34] modify src/net for windows compatibility --- src/net/mod.rs | 6 ++++-- src/net/uds/listener.rs | 30 ++++++++++++++++++++++++++++++ src/net/uds/mod.rs | 2 ++ src/net/uds/stream.rs | 30 ++++++++++++++++++++++++++++++ 4 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/net/mod.rs b/src/net/mod.rs index 7d714ca00..caadc0c8b 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -33,7 +33,9 @@ mod udp; #[cfg(not(target_os = "wasi"))] pub use self::udp::UdpSocket; -#[cfg(unix)] +#[cfg(any(unix, windows))] mod uds; +#[cfg(any(unix, windows))] +pub use self::uds::{SocketAddr, UnixListener, UnixStream}; #[cfg(unix)] -pub use self::uds::{SocketAddr, UnixDatagram, UnixListener, UnixStream}; +pub use self::uds::UnixDatagram; diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 37e8106d8..20ea3323e 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -2,8 +2,14 @@ use crate::io_source::IoSource; use crate::net::{SocketAddr, UnixStream}; use crate::{event, sys, Interest, Registry, Token}; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(unix)] use std::os::unix::net; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +#[cfg(windows)] +use crate::sys::windows::uds::{stdnet as net}; use std::path::Path; use std::{fmt, io}; @@ -79,18 +85,21 @@ impl fmt::Debug for UnixListener { } } +#[cfg(unix)] impl IntoRawFd for UnixListener { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } +#[cfg(unix)] impl AsRawFd for UnixListener { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } +#[cfg(unix)] impl FromRawFd for UnixListener { /// Converts a `RawFd` to a `UnixListener`. /// @@ -102,3 +111,24 @@ impl FromRawFd for UnixListener { UnixListener::from_std(FromRawFd::from_raw_fd(fd)) } } + +#[cfg(windows)] +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} + +#[cfg(windows)] +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener::from_std(FromRawSocket::from_raw_socket(sock)) + } +} diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index 6b4ffdc43..c0a77bbf2 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -1,4 +1,6 @@ +#[cfg(unix)] mod datagram; +#[cfg(unix)] pub use self::datagram::UnixDatagram; mod listener; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index b41ef9da3..7e7bd302f 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -4,8 +4,14 @@ use crate::{event, sys, Interest, Registry, Token}; use std::fmt; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; +#[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(unix)] use std::os::unix::net; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +#[cfg(windows)] +use crate::sys::windows::uds::{stdnet as net}; use std::path::Path; /// A non-blocking Unix stream socket. @@ -220,18 +226,21 @@ impl fmt::Debug for UnixStream { } } +#[cfg(unix)] impl IntoRawFd for UnixStream { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() } } +#[cfg(unix)] impl AsRawFd for UnixStream { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() } } +#[cfg(unix)] impl FromRawFd for UnixStream { /// Converts a `RawFd` to a `UnixStream`. /// @@ -243,3 +252,24 @@ impl FromRawFd for UnixStream { UnixStream::from_std(FromRawFd::from_raw_fd(fd)) } } + +#[cfg(windows)] +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + self.inner.into_inner().into_raw_socket() + } +} + +#[cfg(windows)] +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.inner.as_raw_socket() + } +} + +#[cfg(windows)] +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream::from_std(FromRawSocket::from_raw_socket(sock)) + } +} From 63e50c35eae303d80aa74a61ec4cfd44f01a8d85 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 15 Aug 2022 19:54:48 -0700 Subject: [PATCH 03/34] fix tests --- Cargo.toml | 4 + src/net/mod.rs | 2 + src/net/uds/listener.rs | 2 +- src/net/uds/mod.rs | 3 + src/net/uds/stream.rs | 2 +- src/sys/mod.rs | 7 +- src/sys/windows/mod.rs | 4 + src/sys/windows/selector.rs | 14 +- src/sys/windows/uds/listener.rs | 30 +- src/sys/windows/uds/mod.rs | 77 +-- src/sys/windows/uds/stdnet/ext.rs | 684 --------------------------- src/sys/windows/uds/stdnet/mod.rs | 91 +++- src/sys/windows/uds/stdnet/net.rs | 137 ++++-- src/sys/windows/uds/stdnet/socket.rs | 211 +++++---- src/sys/windows/uds/stream.rs | 24 +- tests/unix_listener.rs | 5 +- tests/unix_stream.rs | 20 +- 17 files changed, 387 insertions(+), 930 deletions(-) delete mode 100644 src/sys/windows/uds/stdnet/ext.rs diff --git a/Cargo.toml b/Cargo.toml index 8433f91ca..635e9a3d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,9 @@ log = "0.4.8" [target.'cfg(unix)'.dependencies] libc = "0.2.121" +[target.'cfg(windows)'.dependencies] +tempfile = "3" + [target.'cfg(windows)'.dependencies.windows-sys] version = "0.36" features = [ @@ -55,6 +58,7 @@ features = [ "Win32_Foundation", # Basic types eg HANDLE "Win32_Networking_WinSock", # winsock2 types/functions "Win32_System_IO", # IO types like OVERLAPPED etc + "Win32_System_Threading", # Process utilities "Win32_System_WindowsProgramming", # General future used for various types/funcs ] diff --git a/src/net/mod.rs b/src/net/mod.rs index caadc0c8b..51b47b9d2 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -39,3 +39,5 @@ mod uds; pub use self::uds::{SocketAddr, UnixListener, UnixStream}; #[cfg(unix)] pub use self::uds::UnixDatagram; +#[cfg(windows)] +pub use self::uds::stdnet; diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 20ea3323e..87cdab73c 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -9,7 +9,7 @@ use std::os::unix::net; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; #[cfg(windows)] -use crate::sys::windows::uds::{stdnet as net}; +use crate::sys::uds::{stdnet as net}; use std::path::Path; use std::{fmt, io}; diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index c0a77bbf2..332d389d4 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -10,3 +10,6 @@ mod stream; pub use self::stream::UnixStream; pub use crate::sys::SocketAddr; + +#[cfg(windows)] +pub use crate::sys::uds::stdnet; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 7e7bd302f..4463ad7e8 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -11,7 +11,7 @@ use std::os::unix::net; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; #[cfg(windows)] -use crate::sys::windows::uds::{stdnet as net}; +use crate::sys::uds::{stdnet as net}; use std::path::Path; /// A non-blocking Unix stream socket. diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 2a968b265..1c5e3ae84 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -59,7 +59,7 @@ cfg_os_poll! { #[cfg(windows)] cfg_os_poll! { - mod windows; + pub mod windows; pub use self::windows::*; } @@ -83,4 +83,9 @@ cfg_not_os_poll! { cfg_net! { pub use self::unix::SocketAddr; } + + #[cfg(windows)] + cfg_net! { + pub use self::windows::SocketAddr; + } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index a243c7198..0817ca4fa 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -49,6 +49,10 @@ cfg_net! { pub(crate) mod tcp; pub(crate) mod udp; + pub mod uds; + pub use self::uds::SocketAddr; + #[cfg(all(windows, test))] + pub use self::uds::stdnet; } cfg_os_ext! { diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs index 9f3cf68dd..777d12413 100644 --- a/src/sys/windows/selector.rs +++ b/src/sys/windows/selector.rs @@ -197,6 +197,7 @@ impl SockState { // This is the function called from the overlapped using as Arc>. Watch out for reference counting. fn feed_event(&mut self) -> Option { + println!("Feed event..."); self.poll_status = SockPollStatus::Idle; self.pending_evts = 0; @@ -260,11 +261,14 @@ impl SockState { cfg_io_source! { impl SockState { fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { + println!("init state: {raw_socket:?}"); + let base = get_base_socket(raw_socket)?; + println!("init state:bas {base:?}"); Ok(SockState { iosb: IoStatusBlock::zeroed(), poll_info: AfdPollInfo::zeroed(), afd, - base_socket: get_base_socket(raw_socket)?, + base_socket: base, user_evts: 0, pending_evts: 0, user_data: 0, @@ -614,7 +618,9 @@ cfg_io_source! { /// GetQueuedCompletionStatusEx() we tell the kernel about the registered /// socket event(s) immediately. unsafe fn update_sockets_events_if_polling(&self) -> io::Result<()> { + println!("POLLING"); if self.is_polling.load(Ordering::Acquire) { + println!("POLLING IMMEDIATELY"); self.update_sockets_events() } else { Ok(()) @@ -658,8 +664,10 @@ cfg_io_source! { } } + #[allow(dead_code)] fn get_base_socket(raw_socket: RawSocket) -> io::Result { let res = try_get_base_socket(raw_socket, SIO_BASE_HANDLE); + println!("FIRST {res:?}"); if let Ok(base_socket) = res { return Ok(base_socket); } @@ -674,7 +682,9 @@ cfg_io_source! { SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE, ] { - if let Ok(base_socket) = try_get_base_socket(raw_socket, ioctl) { + let r = try_get_base_socket(raw_socket, ioctl); + println!("OTHER {r:?}"); + if let Ok(base_socket) = r { // Since we know now that we're dealing with an LSP (otherwise // SIO_BASE_HANDLE would't have failed), only return any result // when it is different from the original `raw_socket`. diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index f93f584a8..e6759f4e0 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -1,7 +1,16 @@ use std::{io, mem}; +use std::convert::TryInto; use std::os::windows::io::{AsRawSocket, FromRawSocket}; use std::path::Path; -use windows_sys::Win32::Networking::WinSock; +use std::os::raw::c_int; +use windows_sys::Win32::Networking::WinSock::{ + self, + SOCKET_ERROR, + INVALID_SOCKET, + bind as sys_bind, + listen, + accept as sys_accept +}; use super::{stdnet as net, socket_addr}; use crate::net::{SocketAddr, UnixStream}; @@ -9,11 +18,11 @@ use crate::sys::windows::net::{init, new_socket}; pub(crate) fn bind(path: &Path) -> io::Result { init(); - let socket = new_socket(WinSock::AF_UNIX, WinSock::SOCK_STREAM)?; + let socket = new_socket(WinSock::AF_UNIX.into(), WinSock::SOCK_STREAM)?; let (sockaddr, socklen) = socket_addr(path)?; let sockaddr = &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR; - wsa_syscall!(bind(socket, sockaddr, socklen as _), PartialEq::eq, SOCKET_ERROR) + wsa_syscall!(sys_bind(socket, sockaddr, socklen as _), PartialEq::eq, SOCKET_ERROR) .and_then(|_| wsa_syscall!(listen(socket, 128), PartialEq::eq, SOCKET_ERROR)) .map_err(|err| { // Close the socket if we hit an error, ignoring the error from @@ -21,7 +30,7 @@ pub(crate) fn bind(path: &Path) -> io::Result { let _ = unsafe { WinSock::closesocket(socket) }; err }) - .map(|_| unsafe { net::UnixListener::from_raw_socket(socket) }) + .map(|_| unsafe { net::UnixListener::from_raw_socket(socket.try_into().unwrap()) }) } pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { @@ -40,19 +49,18 @@ pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, So sockaddr.sun_family = WinSock::AF_UNIX; let mut socklen = mem::size_of_val(&sockaddr) as c_int; - let socket = self.0.accept(&mut storage as *mut _ as *mut _, &mut len)?; - let socket = wsa_syscall!( - accept( - listener.as_raw_socket(), - &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR, - socklen as _ + sys_accept( + listener.as_raw_socket().try_into().unwrap(), + &mut sockaddr as *mut WinSock::sockaddr_un as *mut WinSock::SOCKADDR, + &mut socklen as _ ), PartialEq::eq, INVALID_SOCKET - )?; + ); socket + .map(|socket| unsafe { net::UnixStream::from_raw_socket(socket.try_into().unwrap()) }) .map(UnixStream::from_std) .map(|stream| (stream, SocketAddr::from_parts(sockaddr, socklen))) } diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs index 95d5f4c67..8e5ed0704 100644 --- a/src/sys/windows/uds/mod.rs +++ b/src/sys/windows/uds/mod.rs @@ -1,78 +1,25 @@ -mod stdnet; -pub use self::stdnet::SocketAddr; - -fn path_offset(addr: &WinSock::sockaddr_un) -> usize { - // Work with an actual instance of the type since using a null pointer is UB - let base = addr as *const _ as usize; - let path = &addr.sun_path as *const _ as usize; - path - base -} +pub mod stdnet; +pub use self::stdnet::{path_offset, SocketAddr}; cfg_os_poll! { - use windows_sys::Win32::Networking::WinSock; + use std::convert::TryInto; + use windows_sys::Win32::Networking::WinSock::{ + getsockname, + getpeername, + SOCKET_ERROR + }; use std::os::windows::io::RawSocket; - use std::path::Path; - use std::{io, mem}; + use std::io; pub(crate) mod listener; pub(crate) mod stream; - pub unsafe fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { - let sockaddr = mem::MaybeUninit::::zeroed(); - - // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` - // bytes is properly initialized. - // - // `0` is a valid value for `sockaddr_un::sun_family`; it is - // `WinSock::AF_UNSPEC`. - // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an - // abstract path. - let mut sockaddr = unsafe { sockaddr.assume_init() }; - sockaddr.sun_family = WinSock::AF_UNIX; - - // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path - let bytes = path.to_str().map(|s| s.as_bytes()).ok_or(io::Error::new( - io::ErrorKind::InvalidInput, - "path contains invalid characters", - ))?; - - if bytes.contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "paths may not contain interior null bytes", - )); - } - - if bytes.len() >= sockaddr.sun_path.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "path must be shorter than SUN_LEN", - )); - } - for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { - *dst = *src as c_char; - } - // null byte for pathname addresses is already there because we zeroed the - // struct - - let offset = path_offset(&sockaddr); - let mut socklen = offset + bytes.len(); - - match bytes.get(0) { - // The struct has already been zeroes so the null byte for pathname - // addresses is already there. - Some(&0) | None => {} - Some(_) => socklen += 1, - } - - Ok((sockaddr, socklen as c_int)) - } + pub use self::stdnet::socket_addr; pub(crate) fn local_addr(socket: RawSocket) -> io::Result { SocketAddr::new(|sockaddr, socklen| { wsa_syscall!( - WinSock::getsockname(socket, sockaddr, socklen), + getsockname(socket.try_into().unwrap(), sockaddr, socklen), PartialEq::eq, SOCKET_ERROR ) @@ -82,7 +29,7 @@ cfg_os_poll! { pub(crate) fn peer_addr(socket: RawSocket) -> io::Result { SocketAddr::new(|sockaddr, socklen| { wsa_syscall!( - WinSock::getpeername(socket, sockaddr, socklen), + getpeername(socket.try_into().unwrap(), sockaddr, socklen), PartialEq::eq, SOCKET_ERROR ) diff --git a/src/sys/windows/uds/stdnet/ext.rs b/src/sys/windows/uds/stdnet/ext.rs deleted file mode 100644 index d2e44e1d6..000000000 --- a/src/sys/windows/uds/stdnet/ext.rs +++ /dev/null @@ -1,684 +0,0 @@ -//! Extensions and types for Unix domain socket networking primitives. -//! -//! This module contains a number of extension traits for Windows-specific -//! functionality. - -use std::cmp; -use std::fmt; -use std::io; -use std::mem; -use std::os::windows::prelude::*; -use std::sync::atomic::{AtomicUsize, Ordering}; - -use windows_sys::Win32::Networking::WinSock::{ - self, - SIO_GET_EXTENSION_FUNCTION_POINTER, SOCKADDR, SOCKADDR_STORAGE, SOL_SOCKET, WSABUF, - WSAGetLastError, WSAGetOverlappedResult, WSAIoctl, WSARecv, WSASend, - SOCKET, SOCKET_ERROR, WSA_IO_PENDING, setsockopt, bind -}; -use windows_sys::Win32::Foundation::BOOL; -use windows_sys::core::GUID; -use windows_sys::Win32::System::IO::OVERLAPPED; - -use super::net::{UnixListener, UnixStream}; -use super::{path_offset, SocketAddr}; - -// TODO -type DWORD = u32; -type INT = i32; -type u_long = u32; -type c_int = i32; -type PVOID = *mut c_void; -type LPINT = *mut INT; -type LPDWORD = *mut DWORD; -type LPOVERLAPPED = *mut OVERLAPPED; -type LPSOCKADDR = *mut SOCKADDR; - -/// A buffer in which an accepted socket's address will be stored -/// -/// This type is used with the `accept_overlapped` method on the -/// `UnixListenerExt` trait to provide space for the overlapped I/O operation to -/// fill in the socket addresses upon completion. -#[repr(C)] -pub struct AcceptAddrsBuf { - // For AcceptEx we've got the restriction that the addresses passed in that - // buffer need to be at least 16 bytes more than the maximum address length - // for the protocol in question, so add some extra here and there - local: SOCKADDR_STORAGE, - _pad1: [u8; 16], - remote: SOCKADDR_STORAGE, - _pad2: [u8; 16], -} - -impl fmt::Debug for AcceptAddrsBuf { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let local = unsafe { &*(&self.local as *const _ as *const WinSock::sockaddr_un) }; - let remote = unsafe { &*(&self.remote as *const _ as *const WinSock::sockaddr_un) }; - f.debug_struct("AcceptAddrsBuf") - .field("local", local) - .field("remote", remote) - .finish() - } -} - -/// The parsed return value of `AcceptAddrsBuf` -pub struct AcceptAddrs<'a> { - local: LPSOCKADDR, - local_len: c_int, - remote: LPSOCKADDR, - remote_len: c_int, - _data: &'a AcceptAddrsBuf, -} - -impl<'a> fmt::Debug for AcceptAddrs<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(&self._data, f) - } -} - -struct WsaExtension { - guid: GUID, - val: AtomicUsize, -} - -/// Additional methods for the `UnixStream` type -pub trait UnixStreamExt { - /// Execute an overlapped read I/O operation on this Unix domain socket - /// stream. - /// - /// This function will issue an overlapped I/O read (via `WSARecv`) on this - /// socket. The provided buffer will be filled in when the operation - /// completes and the given `OVERLAPPED` instance is used to track the - /// overlapped operation. - /// - /// If the operation succeeds, `Ok(Some(n))` is returned indicating how - /// many bytes were read. If the operation returns an error indicating that - /// the I/O is currently pending, `Ok(None)` is returned. Otherwise, the - /// error associated with the operation is returned and no overlapped - /// operation is enqueued. - /// - /// The number of bytes read will be returned as part of the completion - /// notification when the I/O finishes. - /// - /// # Unsafety - /// - /// This function is unsafe because the kernel requires that the `buf` and - /// `overlapped` pointers are valid until the end of the I/O operation. The - /// kernel also requires that `overlapped` is unique for this I/O operation - /// and is not in use for any other I/O. - /// - /// To safely use this function callers must ensure that these two input - /// pointers are valid until the I/O operation is completed, typically via - /// completion ports and waiting to receive the completion notification on - /// the port. - unsafe fn read_overlapped( - &self, - buf: &mut [u8], - overlapped: *mut OVERLAPPED, - ) -> io::Result>; - - /// Execute an overlapped write I/O operation on this Unix domain socket - /// stream. - /// - /// This function will issue an overlapped I/O write (via `WSASend`) on this - /// socket. The provided buffer will be written when the operation completes - /// and the given `OVERLAPPED` instance is used to track the overlapped - /// operation. - /// - /// If the operation succeeds, `Ok(Some(n))` is returned where `n` is the - /// number of bytes that were written. If the operation returns an error - /// indicating that the I/O is currently pending, `Ok(None)` is returned. - /// Otherwise, the error associated with the operation is returned and no - /// overlapped operation is enqueued. - /// - /// The number of bytes written will be returned as part of the completion - /// notification when the I/O finishes. - /// - /// # Unsafety - /// - /// This function is unsafe because the kernel requires that the `buf` and - /// `overlapped` pointers are valid until the end of the I/O operation. The - /// kernel also requires that `overlapped` is unique for this I/O operation - /// and is not in use for any other I/O. - /// - /// To safely use this function callers must ensure that these two input - /// pointers are valid until the I/O operation is completed, typically via - /// completion ports and waiting to receive the completion notification on - /// the port. - unsafe fn write_overlapped( - &self, - buf: &[u8], - overlapped: *mut OVERLAPPED, - ) -> io::Result>; - - /// Attempt to consume the internal socket in this builder by executing an - /// overlapped connect operation. - /// - /// This function will issue a connect operation to the address specified on - /// the underlying socket, flagging it as an overlapped operation which will - /// complete asynchronously. If successful this function will return the - /// corresponding Unix domain socket stream. - /// - /// The `buf` argument provided is an initial buffer of data that should be - /// sent after the connection is initiated. It's acceptable to - /// pass an empty slice here. - /// - /// This function will also return whether the connect immediately - /// succeeded or not. If `Ok(None)` is returned then the I/O operation is - /// still pending and will complete later. If `Ok(Some(bytes))` is returned - /// then that many bytes were transferred. - /// - /// Note that to succeed this requires that the underlying socket has - /// previously been bound via a call to `bind` to a local path. - /// - /// # Unsafety - /// - /// This function is unsafe because the kernel requires that the - /// `overlapped` and `buf` pointers to be valid until the end of the I/O - /// operation. The kernel also requires that `overlapped` is unique for - /// this I/O operation and is not in use for any other I/O. - /// - /// To safely use this function callers must ensure that this pointer is - /// valid until the I/O operation is completed, typically via completion - /// ports and waiting to receive the completion notification on the port. - unsafe fn connect_overlapped( - &self, - addr: &SocketAddr, - buf: &[u8], - overlapped: *mut OVERLAPPED, - ) -> io::Result>; - - /// Once a `connect_overlapped` has finished, this function needs to be - /// called to finish the connect operation. - /// - /// Currently this just calls `setsockopt` with `SO_UPDATE_CONNECT_CONTEXT` - /// to ensure that further functions like `getpeername` and `getsockname` - /// work correctly. - fn connect_complete(&self) -> io::Result<()>; - - /// Calls the `GetOverlappedResult` function to get the result of an - /// overlapped operation for this handle. - /// - /// This function takes the `OVERLAPPED` argument which must have been used - /// to initiate an overlapped I/O operation, and returns either the - /// successful number of bytes transferred during the operation or an error - /// if one occurred, along with the results of the `lpFlags` parameter of - /// the relevant operation, if applicable. - /// - /// # Unsafety - /// - /// This function is unsafe as `overlapped` must have previously been used - /// to execute an operation for this handle, and it must also be a valid - /// pointer to an `OVERLAPPED` instance. - /// - /// # Panics - /// - /// This function will panic - unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)>; -} - -/// Additional methods for the `UnixListener` type -pub trait UnixListenerExt { - /// Perform an accept operation on this listener, accepting a connection in - /// an overlapped fashion. - /// - /// This function will issue an I/O request to accept an incoming connection - /// with the specified overlapped instance. The `socket` provided must be - /// configured but not bound or connected. If successful this method will - /// consume the socket to return a Unix stream. - /// - /// The `addrs` buffer provided will be filled in with the local and remote - /// addresses of the connection upon completion. - /// - /// If the accept succeeds immediately, `Ok(true)` is returned. If the - /// connect indicates that the I/O is currently pending, `Ok(false)` is - /// returned. Otherwise, the error associated with the operation is returned - /// and no overlapped operation is enqueued. - /// - /// # Unsafety - /// - /// This function is unsafe because the kernel requires that the - /// `addrs` and `overlapped` pointers are valid until the end of the I/O - /// operation. The kernel also requires that `overlapped` is unique for this - /// I/O operation and is not in use for any other I/O. - /// - /// To safely use this function callers must ensure that the pointers are - /// valid until the I/O operation is completed, typically via completion - /// ports and waiting to receive the completion notification on the port. - unsafe fn accept_overlapped( - &self, - socket: &UnixStream, - addrs: &mut AcceptAddrsBuf, - overlapped: *mut OVERLAPPED, - ) -> io::Result; - - /// Once an `accept_overlapped` has finished, this function needs to be - /// called to finish the accept operation. - /// - /// Currently this just calls `setsockopt` with `SO_UPDATE_ACCEPT_CONTEXT` - /// to ensure that further functions like `getpeername` and `getsockname` - /// work correctly. - fn accept_complete(&self, socket: &UnixStream) -> io::Result<()>; - - /// Calls the `GetOverlappedResult` function to get the result of an - /// overlapped operation for this handle. - /// - /// This function takes the `OVERLAPPED` argument which must have been used - /// to initiate an overlapped I/O operation, and returns either the - /// successful number of bytes transferred during the operation or an error - /// if one occurred, along with the results of the `lpFlags` parameter of - /// the relevant operation, if applicable. - /// - /// # Unsafety - /// - /// This function is unsafe as `overlapped` must have previously been used - /// to execute an operation for this handle, and it must also be a valid - /// pointer to an `OVERLAPPED` instance. - /// - /// # Panics - /// - /// This function will panic - unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)>; -} - -fn last_err() -> io::Result> { - let err = unsafe { WSAGetLastError() }; - if err == WSA_IO_PENDING as i32 { - Ok(None) - } else { - Err(io::Error::from_raw_os_error(err)) - } -} - -fn cvt(i: c_int, size: DWORD) -> io::Result> { - if i == SOCKET_ERROR { - last_err() - } else { - Ok(Some(size as usize)) - } -} - -fn socket_addr_to_ptrs(addr: &SocketAddr) -> (*const SOCKADDR, c_int) { - ( - &addr.addr as *const _ as *const _, - mem::size_of::() as c_int, - ) -} - -unsafe fn ptrs_to_socket_addr(ptr: *const SOCKADDR, len: c_int) -> Option { - if (len as usize) < mem::size_of::() { - return None; - } - match (*ptr).sa_family { - WinSock::AF_UNIX if len as usize >= mem::size_of::() => { - let b = &*(ptr as *const WinSock::sockaddr_un); - match b.sun_path.iter().position(|c| *c == 0) { - Some(0) => Some(SocketAddr::from_parts(b.clone(), len)), - Some(i) => { - let mut l = path_offset(b) + i; - match b.sun_path.get(0) { - Some(&0) | None => {} - Some(_) => l += 1, - } - Some(SocketAddr::from_parts(b.clone(), l as c_int)) - } - _ => None, // Invalid socket path, no terminating null byte - } - } - _ => None, // Invalid socket family, should be AF_UNIX - } -} - -unsafe fn slice2buf(slice: &[u8]) -> WSABUF { - WSABUF { - len: cmp::min(slice.len(), ::max_value() as usize) as u_long, - buf: slice.as_ptr() as *mut _, - } -} - -unsafe fn result(socket: SOCKET, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> { - let mut transferred = 0; - let mut flags = 0; - let r = WSAGetOverlappedResult(socket, overlapped, &mut transferred, 0, &mut flags); - if r == 0 { - Err(io::Error::last_os_error()) - } else { - Ok((transferred as usize, flags)) - } -} - -impl UnixStreamExt for UnixStream { - unsafe fn read_overlapped( - &self, - buf: &mut [u8], - overlapped: *mut OVERLAPPED, - ) -> io::Result> { - let mut buf = slice2buf(buf); - let mut flags = 0; - let mut bytes_read: DWORD = 0; - let r = WSARecv( - self.as_raw_socket() as SOCKET, - &mut buf, - 1, - &mut bytes_read, - &mut flags, - overlapped, - None, - ); - cvt(r, bytes_read) - } - - unsafe fn write_overlapped( - &self, - buf: &[u8], - overlapped: *mut OVERLAPPED, - ) -> io::Result> { - let mut buf = slice2buf(buf); - let mut bytes_written = 0; - - // Note here that we capture the number of bytes written. The - // documentation on MSDN, however, states: - // - // > Use NULL for this parameter if the lpOverlapped parameter is not - // > NULL to avoid potentially erroneous results. This parameter can be - // > NULL only if the lpOverlapped parameter is not NULL. - // - // If we're not passing a null overlapped pointer here, then why are we - // then capturing the number of bytes! Well so it turns out that this is - // clearly faster to learn the bytes here rather than later calling - // `WSAGetOverlappedResult`, and in practice almost all implementations - // use this anyway [1]. - // - // As a result we use this to and report back the result. - // - // [1]: https://github.com/carllerche/mio/pull/520#issuecomment-273983823 - let r = WSASend( - self.as_raw_socket() as SOCKET, - &mut buf, - 1, - &mut bytes_written, - 0, - overlapped, - None, - ); - cvt(r, bytes_written) - } - - unsafe fn connect_overlapped( - &self, - addr: &SocketAddr, - buf: &[u8], - overlapped: *mut OVERLAPPED, - ) -> io::Result> { - connect_overlapped(self.as_raw_socket() as SOCKET, addr, buf, overlapped) - } - - fn connect_complete(&self) -> io::Result<()> { - const SO_UPDATE_CONNECT_CONTEXT: c_int = 0x7010; - let result = unsafe { - setsockopt( - self.as_raw_socket() as SOCKET, - SOL_SOCKET, - SO_UPDATE_CONNECT_CONTEXT, - 0 as *const _, - 0, - ) - }; - if result == 0 { - Ok(()) - } else { - Err(io::Error::last_os_error()) - } - } - - unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> { - result(self.as_raw_socket() as SOCKET, overlapped) - } -} - -unsafe fn connect_overlapped( - socket: SOCKET, - addr: &SocketAddr, - buf: &[u8], - overlapped: *mut OVERLAPPED, -) -> io::Result> { - let anonaddr = WinSock::sockaddr_un { - sun_family: WinSock::AF_UNIX, - sun_path: [0; 108], - }; - let len = mem::size_of::() as c_int; - wsa_syscall!( - bind(socket as _, &anonaddr as *const _ as *const _, len as _), - PartialEq::eq, - SOCKET_ERROR - )?; - - static CONNECTEX: WsaExtension = WsaExtension { - guid: GUID { - Data1: 0x25a207b9, - Data2: 0xddf3, - Data3: 0x4660, - Data4: [0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e], - }, - val: AtomicUsize::new(0), - }; - type ConnectEx = unsafe extern "system" fn( - SOCKET, - *const SOCKADDR, - c_int, - PVOID, - DWORD, - LPDWORD, - LPOVERLAPPED, - ) -> BOOL; - - let ptr = CONNECTEX.get(socket)?; - assert!(ptr != 0); - let connect_ex = mem::transmute::<_, ConnectEx>(ptr); - - let (addr_buf, addr_len) = socket_addr_to_ptrs(addr); - let mut bytes_sent: DWORD = 0; - let r = connect_ex( - socket, - addr_buf, - addr_len, - buf.as_ptr() as *mut _, - buf.len() as u32, - &mut bytes_sent, - overlapped, - ); - if r == 1 { - Ok(Some(bytes_sent as usize)) - } else { - last_err() - } -} - -impl UnixListenerExt for UnixListener { - unsafe fn accept_overlapped( - &self, - socket: &UnixStream, - addrs: &mut AcceptAddrsBuf, - overlapped: *mut OVERLAPPED, - ) -> io::Result { - static ACCEPTEX: WsaExtension = WsaExtension { - guid: GUID { - Data1: 0xb5367df1, - Data2: 0xcbac, - Data3: 0x11cf, - Data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], - }, - val: AtomicUsize::new(0), - }; - type AcceptEx = unsafe extern "system" fn( - SOCKET, - SOCKET, - PVOID, - DWORD, - DWORD, - DWORD, - LPDWORD, - LPOVERLAPPED, - ) -> BOOL; - - let ptr = ACCEPTEX.get(self.as_raw_socket() as SOCKET)?; - assert!(ptr != 0); - let accept_ex = mem::transmute::<_, AcceptEx>(ptr); - - let mut bytes = 0; - let (a, b, c, d) = (*addrs).args(); - let r = accept_ex( - self.as_raw_socket() as SOCKET, - socket.as_raw_socket() as SOCKET, - a, - b, - c, - d, - &mut bytes, - overlapped, - ); - let succeeded = if r == 1 { - true - } else { - last_err()?; - false - }; - Ok(succeeded) - } - - fn accept_complete(&self, socket: &UnixStream) -> io::Result<()> { - const SO_UPDATE_ACCEPT_CONTEXT: c_int = 0x700B; - let me = self.as_raw_socket(); - let result = unsafe { - setsockopt( - socket.as_raw_socket() as SOCKET, - SOL_SOCKET, - SO_UPDATE_ACCEPT_CONTEXT, - &me as *const _ as *const _, - mem::size_of_val(&me) as c_int, - ) - }; - if result == 0 { - Ok(()) - } else { - Err(io::Error::last_os_error()) - } - } - - unsafe fn result(&self, overlapped: *mut OVERLAPPED) -> io::Result<(usize, u32)> { - result(self.as_raw_socket() as SOCKET, overlapped) - } -} - -static GETACCEPTEXSOCKADDRS: WsaExtension = WsaExtension { - guid: GUID { - Data1: 0xb5367df2, - Data2: 0xcbac, - Data3: 0x11cf, - Data4: [0x95, 0xca, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92], - }, - val: AtomicUsize::new(0), -}; -type GetAcceptExSockaddrs = unsafe extern "system" fn( - PVOID, - DWORD, - DWORD, - DWORD, - *mut LPSOCKADDR, - LPINT, - *mut LPSOCKADDR, - LPINT, -); - -impl AcceptAddrsBuf { - /// Creates a new blank buffer ready to be passed to a call to - /// `accept_overlapped`. - pub fn new() -> AcceptAddrsBuf { - unsafe { mem::zeroed() } - } - - /// Parses the data contained in this address buffer, returning the parsed - /// result if successful. - /// - /// This function can be called after a call to `accept_overlapped` has - /// succeeded to parse out the data that was written in. - pub fn parse(&self, socket: &UnixListener) -> io::Result { - let mut ret = AcceptAddrs { - local: 0 as *mut _, - local_len: 0, - remote: 0 as *mut _, - remote_len: 0, - _data: self, - }; - let ptr = GETACCEPTEXSOCKADDRS.get(socket.as_raw_socket() as SOCKET)?; - assert!(ptr != 0); - unsafe { - let get_sockaddrs = mem::transmute::<_, GetAcceptExSockaddrs>(ptr); - let (a, b, c, d) = self.args(); - get_sockaddrs( - a, - b, - c, - d, - &mut ret.local, - &mut ret.local_len, - &mut ret.remote, - &mut ret.remote_len, - ); - Ok(ret) - } - } - - fn args(&self) -> (PVOID, DWORD, DWORD, DWORD) { - let remote_offset = unsafe { &(*(0 as *const AcceptAddrsBuf)).remote as *const _ as usize }; - ( - self as *const _ as *mut _, - 0, - remote_offset as DWORD, - (mem::size_of_val(self) - remote_offset) as DWORD, - ) - } -} - -impl<'a> AcceptAddrs<'a> { - /// Returns the local socket address contained in this buffer. - #[allow(dead_code)] - pub fn local(&self) -> Option { - unsafe { ptrs_to_socket_addr(self.local, self.local_len) } - } - - /// Returns the remote socket address contained in this buffer. - pub fn remote(&self) -> Option { - unsafe { ptrs_to_socket_addr(self.remote, self.remote_len) } - } -} - -impl WsaExtension { - fn get(&self, socket: SOCKET) -> io::Result { - let prev = self.val.load(Ordering::SeqCst); - if prev != 0 && !cfg!(debug_assertions) { - return Ok(prev); - } - let mut ret = 0 as usize; - let mut bytes = 0; - let r = unsafe { - WSAIoctl( - socket, - SIO_GET_EXTENSION_FUNCTION_POINTER, - &self.guid as *const _ as *mut _, - mem::size_of_val(&self.guid) as DWORD, - &mut ret as *mut _ as *mut _, - mem::size_of_val(&ret) as DWORD, - &mut bytes, - 0 as *mut _, - None, - ) - }; - cvt(r, 0).map(|_| { - debug_assert_eq!(bytes as usize, mem::size_of_val(&ret)); - debug_assert!(prev == 0 || prev == ret); - self.val.store(ret, Ordering::SeqCst); - ret - }) - } -} diff --git a/src/sys/windows/uds/stdnet/mod.rs b/src/sys/windows/uds/stdnet/mod.rs index 5a0a33047..c4164c572 100644 --- a/src/sys/windows/uds/stdnet/mod.rs +++ b/src/sys/windows/uds/stdnet/mod.rs @@ -2,20 +2,73 @@ use std::ascii; use std::fmt; use std::io; use std::mem; -use std::os::raw::{c_char, c_int}; +use std::os::raw::c_int; use std::path::Path; -use windows_sys::Win32::Networking::WinSock::{ - self, - SOCKADDR, - SOCKET_ERROR, - WSAGetLastError -}; +use windows_sys::Win32::Networking::WinSock::{self, SOCKADDR}; -mod ext; mod net; mod socket; +pub fn path_offset(addr: &WinSock::sockaddr_un) -> usize { + // Work with an actual instance of the type since using a null pointer is UB + let base = addr as *const _ as usize; + let path = &addr.sun_path as *const _ as usize; + path - base +} + +pub fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + sockaddr.sun_family = WinSock::AF_UNIX; + + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or(io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ))?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } + + if bytes.len() >= sockaddr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { + *dst = *src as u8; + } + // null byte for pathname addresses is already there because we zeroed the + // struct + + let offset = path_offset(&sockaddr); + let mut socklen = offset + bytes.len(); + + match bytes.get(0) { + // The struct has already been zeroes so the null byte for pathname + // addresses is already there. + Some(&0) | None => {} + Some(_) => socklen += 1, + } + + Ok((sockaddr, socklen as c_int)) +} + enum AddressKind<'a> { Unnamed, Pathname(&'a Path), @@ -30,9 +83,9 @@ pub struct SocketAddr { } impl SocketAddr { - fn new(f: F) -> io::Result + pub(crate) fn new(f: F) -> io::Result where - F: FnOnce(*mut SOCKADDR, *mut c_int) -> c_int, + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, { let mut sockaddr = { let sockaddr = mem::MaybeUninit::::zeroed(); @@ -40,15 +93,11 @@ impl SocketAddr { }; let mut len = mem::size_of::() as c_int; - wsa_syscall!( - f(&mut sockaddr as *mut _ as *mut _, &mut len), - PartialEq::eq, - SOCKET_ERROR - )?; + f(&mut sockaddr as *mut _ as *mut _, &mut len)?; Ok(SocketAddr::from_parts(sockaddr, len)) } - fn from_parts(addr: WinSock::sockaddr_un, mut len: c_int) -> SocketAddr { + pub(crate) fn from_parts(addr: WinSock::sockaddr_un, mut len: c_int) -> SocketAddr { if len == 0 { // When there is a datagram from unnamed unix socket // linux returns zero bytes of address @@ -78,7 +127,6 @@ impl SocketAddr { fn address<'a>(&'a self) -> AddressKind<'a> { let len = self.len as usize - path_offset(&self.addr); // WinSock::sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path - let path = unsafe { mem::transmute::<&[c_char], &[u8]>(&self.addr.sun_path) }; // macOS seems to return a len of 16 and a zeroed sun_path for unnamed addresses if len == 0 @@ -87,17 +135,17 @@ impl SocketAddr { { AddressKind::Unnamed } else if self.addr.sun_path[0] == 0 { - AddressKind::Abstract(&path[1..len]) + AddressKind::Abstract(&self.addr.sun_path[1..len]) } else { use std::ffi::CStr; - let pathname = unsafe { CStr::from_bytes_with_nul_unchecked(&path[..len]) }; + let pathname = unsafe { CStr::from_bytes_with_nul_unchecked(&self.addr.sun_path[..len]) }; AddressKind::Pathname(Path::new(pathname.to_str().unwrap())) } } } impl fmt::Debug for SocketAddr { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self.address() { AddressKind::Unnamed => write!(fmt, "(unnamed)"), AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), @@ -120,7 +168,7 @@ impl PartialEq for SocketAddr { struct AsciiEscaped<'a>(&'a [u8]); impl<'a> fmt::Display for AsciiEscaped<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "\"")?; for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { write!(fmt, "{}", byte as char)?; @@ -129,6 +177,5 @@ impl<'a> fmt::Display for AsciiEscaped<'a> { } } -pub use self::ext::{AcceptAddrs, AcceptAddrsBuf, UnixListenerExt, UnixStreamExt}; pub use self::net::{UnixListener, UnixStream}; pub use self::socket::Socket; diff --git a/src/sys/windows/uds/stdnet/net.rs b/src/sys/windows/uds/stdnet/net.rs index ca75c46e2..39aade7d7 100644 --- a/src/sys/windows/uds/stdnet/net.rs +++ b/src/sys/windows/uds/stdnet/net.rs @@ -1,5 +1,6 @@ use std::fmt; -use std::io; +use std::io::{self, IoSlice, IoSliceMut}; +use std::convert::TryInto; use std::mem; use std::net::Shutdown; use std::os::raw::c_int; @@ -9,7 +10,14 @@ use std::time::Duration; use windows_sys::Win32::Networking::WinSock::{ self, - bind, connect, getpeername, getsockname, listen, SO_RCVTIMEO, SO_SNDTIMEO, + bind, + connect, + getpeername, + getsockname, + listen, + SO_RCVTIMEO, + SOCKET_ERROR, + SO_SNDTIMEO }; use crate::sys::windows::net::init; @@ -20,7 +28,7 @@ use super::{socket_addr, SocketAddr}; pub struct UnixStream(Socket); impl fmt::Debug for UnixStream { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { let mut builder = fmt.debug_struct("UnixStream"); builder.field("socket", &self.0.as_raw_socket()); if let Ok(addr) = self.local_addr() { @@ -38,21 +46,19 @@ impl UnixStream { pub fn connect>(path: P) -> io::Result { init(); fn inner(path: &Path) -> io::Result { - unsafe { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path)?; - - wsa_syscall!( - connect( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as i32, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - Ok(UnixStream(inner)) - } + let inner = Socket::new()?; + let (addr, len) = socket_addr(path)?; + + wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(UnixStream(inner)) } inner(path.as_ref()) } @@ -69,12 +75,24 @@ impl UnixStream { /// Returns the socket address of the local half of this connection. pub fn local_addr(&self) -> io::Result { - SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw_socket() as _, addr, len) }) + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + PartialEq::eq, + SOCKET_ERROR + ) + }) } /// Returns the socket address of the remote half of this connection. pub fn peer_addr(&self) -> io::Result { - SocketAddr::new(|addr, len| unsafe { getpeername(self.0.as_raw_socket() as _, addr, len) }) + SocketAddr::new(|addr, len| { + wsa_syscall!( + getpeername(self.0.as_raw_socket() as _, addr, len), + PartialEq::eq, + SOCKET_ERROR + ) + }) } /// Moves the socket into or out of nonblocking mode. @@ -104,6 +122,7 @@ impl UnixStream { let file_path = dir.path().join("socket"); let a: Arc>>> = Arc::new(RwLock::new(None)); let ul = UnixListener::bind(&file_path).unwrap(); + ul.set_nonblocking(true)?; let server = { let a = a.clone(); spawn(move || { @@ -113,6 +132,7 @@ impl UnixStream { }) }; let stream1 = UnixStream::connect(&file_path)?; + stream1.set_nonblocking(true)?; server .join() .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; @@ -126,7 +146,7 @@ impl UnixStream { /// indefinitely. An `Err` is returned if the zero `Duration` is /// passed to this method. pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_timeout(dur, SO_RCVTIMEO) + self.0.set_timeout(dur, SO_RCVTIMEO.try_into().unwrap()) } /// Sets the write timeout to the timeout specified. @@ -135,17 +155,17 @@ impl UnixStream { /// indefinitely. An `Err` is returned if the zero `Duration` is /// passed to this method. pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_timeout(dur, SO_SNDTIMEO) + self.0.set_timeout(dur, SO_SNDTIMEO.try_into().unwrap()) } /// Returns the read timeout of this socket. pub fn read_timeout(&self) -> io::Result> { - self.0.timeout(SO_RCVTIMEO) + self.0.timeout(SO_RCVTIMEO.try_into().unwrap()) } /// Returns the write timeout of this socket. pub fn write_timeout(&self) -> io::Result> { - self.0.timeout(SO_SNDTIMEO) + self.0.timeout(SO_SNDTIMEO.try_into().unwrap()) } } @@ -153,12 +173,20 @@ impl io::Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { io::Read::read(&mut &*self, buf) } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + io::Read::read_vectored(&mut &*self, bufs) + } } impl<'a> io::Read for &'a UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.0.read(buf) } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + self.0.read_vectored(bufs) + } } impl io::Write for UnixStream { @@ -166,6 +194,10 @@ impl io::Write for UnixStream { io::Write::write(&mut &*self, buf) } + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + io::Write::write_vectored(&mut &*self, bufs) + } + fn flush(&mut self) -> io::Result<()> { io::Write::flush(&mut &*self) } @@ -176,6 +208,11 @@ impl<'a> io::Write for &'a UnixStream { self.0.write(buf) } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } + fn flush(&mut self) -> io::Result<()> { Ok(()) } @@ -205,7 +242,7 @@ impl IntoRawSocket for UnixStream { pub struct UnixListener(Socket); impl fmt::Debug for UnixListener { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { let mut builder = fmt.debug_struct("UnixListener"); builder.field("socket", &self.0.as_raw_socket()); if let Ok(addr) = self.local_addr() { @@ -220,27 +257,25 @@ impl UnixListener { pub fn bind>(path: P) -> io::Result { init(); fn inner(path: &Path) -> io::Result { - unsafe { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path)?; - - wsa_syscall!( - bind( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as _, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - wsa_syscall!( - listen(inner.as_raw_socket() as _, 128), - PartialEq::eq, - SOCKET_ERROR - )?; - - Ok(UnixListener(inner)) - } + let inner = Socket::new()?; + let (addr, len) = socket_addr(path)?; + + wsa_syscall!( + bind( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as _, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + wsa_syscall!( + listen(inner.as_raw_socket() as _, 128), + PartialEq::eq, + SOCKET_ERROR + )?; + + Ok(UnixListener(inner)) } inner(path.as_ref()) } @@ -271,7 +306,13 @@ impl UnixListener { /// Returns the local socket address of this listener. pub fn local_addr(&self) -> io::Result { - SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw_socket() as _, addr, len) }) + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + PartialEq::eq, + SOCKET_ERROR + ) + }) } /// Moves the socket into or out of nonblocking mode. @@ -348,7 +389,7 @@ impl<'a> Iterator for Incoming<'a> { #[cfg(test)] mod test { - extern crate tempfile; + use tempfile; use std::io::{self, Read, Write}; use std::path::PathBuf; diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/uds/stdnet/socket.rs index 8d809d01f..5a1fccabb 100644 --- a/src/sys/windows/uds/stdnet/socket.rs +++ b/src/sys/windows/uds/stdnet/socket.rs @@ -1,12 +1,10 @@ -#![allow(non_camel_case_types)] - -use std::io; +use std::io::{self, IoSlice, IoSliceMut}; +use std::convert::TryInto; use std::mem; use std::net::Shutdown; use std::os::raw::{c_int, c_ulong}; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::ptr; -use std::sync::Once; use std::time::Duration; use windows_sys::Win32::Foundation::{ @@ -17,26 +15,33 @@ use windows_sys::Win32::Foundation::{ use windows_sys::Win32::System::Threading::GetCurrentProcessId; use windows_sys::Win32::System::WindowsProgramming::INFINITE; use windows_sys::Win32::Networking::WinSock::{ - self, - SOCKET_ERROR, + WSABUF, AF_UNIX, + FIONBIO, + INVALID_SOCKET, + SD_BOTH, + SD_RECEIVE, + SD_SEND, SOCKADDR, + SOCKET, + SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_ERROR, - accept, closesocket, ioctlsocket, recv, send, - setsockopt, shutdown, WSADuplicateSocketW, WSASocketW, FIONBIO, - INVALID_SOCKET, SOCKET, WSADATA, WSAPROTOCOL_INFOW, + WSADuplicateSocketW, + WSAPROTOCOL_INFOW, + WSASocketW, WSA_FLAG_OVERLAPPED, - SD_RECEIVE, - SD_SEND, - SD_BOTH + accept, + closesocket, + getsockopt as c_getsockopt, + ioctlsocket, + recv, + send, + setsockopt as c_setsockopt, + shutdown, }; -// TODO -type socklen_t = i32; -type DWORD = u32; - #[derive(Debug)] pub struct Socket(SOCKET); @@ -44,16 +49,17 @@ impl Socket { pub fn new() -> io::Result { let socket = wsa_syscall!( WSASocketW( - AF_UNIX, - SOCK_STREAM, + AF_UNIX.into(), + SOCK_STREAM.into(), 0, ptr::null_mut(), 0, WSA_FLAG_OVERLAPPED, - ) + ), PartialEq::eq, INVALID_SOCKET )?; + let socket = Socket(socket); socket.set_no_inherit()?; Ok(socket) } @@ -64,36 +70,35 @@ impl Socket { PartialEq::eq, INVALID_SOCKET )?; + let socket = Socket(socket); socket.set_no_inherit()?; Ok(socket) } pub fn duplicate(&self) -> io::Result { - let socket = unsafe { - let mut info: WSAPROTOCOL_INFOW = mem::zeroed(); - wsa_syscall!( - WSADuplicateSocketW( - self.0, - GetCurrentProcessId(), - &mut info, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - let n = wsa_syscall!( - WSASocketW( - info.iAddressFamily, - info.iSocketType, - info.iProtocol, - &mut info, - 0, - WSA_FLAG_OVERLAPPED, - ) - PartialEq::eq, - INVALID_SOCKET - )?; - Socket(n) - }; + let mut info: WSAPROTOCOL_INFOW = unsafe { mem::zeroed() }; + wsa_syscall!( + WSADuplicateSocketW( + self.0, + GetCurrentProcessId(), + &mut info, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + let n = wsa_syscall!( + WSASocketW( + info.iAddressFamily, + info.iSocketType, + info.iProtocol, + &mut info, + 0, + WSA_FLAG_OVERLAPPED, + ), + PartialEq::eq, + INVALID_SOCKET + )?; + let socket = Socket(n); socket.set_no_inherit()?; Ok(socket) } @@ -116,6 +121,18 @@ impl Socket { self.recv_with_flags(buf, 0) } + pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + let mut total = 0; + for slice in &mut *bufs { + let wsa_buf = unsafe { *(slice as *const _ as *const WSABUF) }; + let len = wsa_buf.len; + let buf = unsafe { std::slice::from_raw_parts_mut(wsa_buf.buf, len.try_into().unwrap()) }; + total += self.recv_with_flags(buf, 0)?; + } + println!("Wrote vectored: {total:?}, {bufs:?}"); + Ok(total as usize) + } + pub fn write(&self, buf: &[u8]) -> io::Result { let ret = wsa_syscall!( send(self.0, buf as *const _ as *const _, buf.len() as c_int, 0), @@ -125,21 +142,41 @@ impl Socket { Ok(ret as usize) } + pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + let mut total = 0; + for slice in bufs { + let wsa_buf = unsafe { *(slice as *const _ as *const WSABUF) }; + let len = wsa_buf.len; + let buf = unsafe { std::slice::from_raw_parts(wsa_buf.buf, len.try_into().unwrap()) }; + dbg!(buf); + let ret = wsa_syscall!( + send(self.0, buf as *const _ as *const _, len as c_int, 0), + PartialEq::eq, + SOCKET_ERROR + )?; + total += ret; + } + println!("Wrote vectored: {total:?}, {bufs:?}"); + Ok(total as usize) + } + fn set_no_inherit(&self) -> io::Result<()> { syscall!( SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), PartialEq::eq, 0 - ) + )?; + Ok(()) } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - let mut nonblocking = nonblocking as c_ulong; + let mut nonblocking: c_ulong = if nonblocking { 1 } else { 0 }; wsa_syscall!( - ioctlsocket(self.0, FIONBIO as c_int, &mut nonblocking), + ioctlsocket(self.0, FIONBIO, &mut nonblocking), PartialEq::eq, SOCKET_ERROR - ) + )?; + Ok(()) } pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { @@ -149,7 +186,7 @@ impl Socket { Shutdown::Both => SD_BOTH, }; wsa_syscall!( - shutdown(self.0, how), + shutdown(self.0, how.try_into().unwrap()), PartialEq::eq, SOCKET_ERROR )?; @@ -157,7 +194,11 @@ impl Socket { } pub fn take_error(&self) -> io::Result> { - let raw: c_int = getsockopt(self, SOL_SOCKET, SO_ERROR)?; + let raw = getsockopt::( + self, + SOL_SOCKET.try_into().unwrap(), + SO_ERROR.try_into().unwrap() + )?; if raw == 0 { Ok(None) } else { @@ -179,11 +220,11 @@ impl Socket { } None => 0, }; - setsockopt(self, SOL_SOCKET, kind, timeout) + setsockopt(self, SOL_SOCKET.try_into().unwrap(), kind, timeout) } pub fn timeout(&self, kind: c_int) -> io::Result> { - let raw: DWORD = getsockopt(self, SOL_SOCKET, kind)?; + let raw: u32 = getsockopt(self, SOL_SOCKET.try_into().unwrap(), kind)?; if raw == 0 { Ok(None) } else { @@ -195,44 +236,40 @@ impl Socket { } pub fn setsockopt(sock: &Socket, opt: c_int, val: c_int, payload: T) -> io::Result<()> { - unsafe { - let payload = &payload as *const T as *const _; - wsa_syscall!( - WinSock::setsockopt( - sock.as_raw_socket() as usize, - opt, - val, - payload, - mem::size_of::() as socklen_t, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - Ok(()) - } + let payload = &payload as *const T as *const _; + wsa_syscall!( + c_setsockopt( + sock.as_raw_socket() as usize, + opt, + val, + payload, + mem::size_of::() as i32, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(()) } pub fn getsockopt(sock: &Socket, opt: c_int, val: c_int) -> io::Result { - unsafe { - let mut slot: T = mem::zeroed(); - let mut len = mem::size_of::() as socklen_t; - wsa_syscall!( - WinSock::getsockopt( - sock.as_raw_socket() as _, - opt, - val, - &mut slot as *mut _ as *mut _, - &mut len, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - assert_eq!(len as usize, mem::size_of::()); - Ok(slot) - } + let mut slot: T = unsafe { mem::zeroed() }; + let mut len = mem::size_of::() as i32; + wsa_syscall!( + c_getsockopt( + sock.as_raw_socket() as _, + opt, + val, + &mut slot as *mut _ as *mut _, + &mut len, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + assert_eq!(len as usize, mem::size_of::()); + Ok(slot) } -fn dur2timeout(dur: Duration) -> DWORD { +fn dur2timeout(dur: Duration) -> u32 { // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the // timeouts in windows APIs are typically u32 milliseconds. To translate, we // have two pieces to take care of: @@ -251,10 +288,10 @@ fn dur2timeout(dur: Duration) -> DWORD { }) }) .map(|ms| { - if ms > ::max_value() as u64 { + if ms > ::max_value() as u64 { INFINITE } else { - ms as DWORD + ms as u32 } }) .unwrap_or(INFINITE) diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index 92b241b23..7702e9ca3 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -1,24 +1,31 @@ use std::io; use std::os::windows::io::{AsRawSocket, FromRawSocket}; +use std::convert::TryInto; use std::path::Path; -use windows_sys::Win32::Networking::WinSock; +use windows_sys::Win32::Networking::WinSock::{self, SOCKET_ERROR, connect as sys_connect, ioctlsocket, FIONBIO}; -use super::{stdnet as net, socket_addr}; +use super::{stdnet::{self as net}, socket_addr}; use crate::net::SocketAddr; use crate::sys::windows::net::{init, new_socket}; pub(crate) fn connect(path: &Path) -> io::Result { init(); - let socket = new_socket(WinSock::AF_UNIX, WinSock::SOCK_STREAM)?; + let socket = new_socket(WinSock::AF_UNIX.into(), WinSock::SOCK_STREAM)?; let (sockaddr, socklen) = socket_addr(path)?; let sockaddr = &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR; + // Put into blocking mode to connect. wsa_syscall!( - connect(socket, sockaddr, socklen as _), + ioctlsocket(socket, FIONBIO, &mut 0), PartialEq::eq, SOCKET_ERROR )?; - match syscall!(connect(socket, sockaddr, socklen)) { + + match wsa_syscall!( + sys_connect(socket, sockaddr, socklen as _), + PartialEq::eq, + SOCKET_ERROR + ) { Ok(_) => {} Err(ref err) if err.raw_os_error() == Some(WinSock::WSAEINPROGRESS) => {} Err(e) => { @@ -29,8 +36,13 @@ pub(crate) fn connect(path: &Path) -> io::Result { return Err(e); } } + wsa_syscall!( + ioctlsocket(socket, FIONBIO, &mut 1), + PartialEq::eq, + SOCKET_ERROR + )?; - Ok(unsafe { net::UnixStream::from_raw_socket(socket) }) + Ok(unsafe { net::UnixStream::from_raw_socket(socket.try_into().unwrap()) }) } pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 0aeda8153..186afbd6e 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,8 +1,11 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; +#[cfg(windows)] +use mio::net::{stdnet as net}; +#[cfg(unix)] use std::os::unix::net; use std::path::{Path, PathBuf}; use std::sync::{Arc, Barrier}; diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 79b7c3d4b..e5128ae3f 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,9 +1,12 @@ -#![cfg(all(unix, feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net"))] use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; +#[cfg(windows)] +use mio::net::{stdnet as net}; +#[cfg(unix)] use std::os::unix::net; use std::path::Path; use std::sync::mpsc::channel; @@ -217,6 +220,11 @@ fn unix_stream_shutdown_write() { vec![ExpectEvent::new(TOKEN_1, Interest::WRITABLE)], ); + // TODO: have to re-register here to reset user_events + poll.registry() + .reregister(&mut stream, TOKEN_1, Interest::WRITABLE.add(Interest::READABLE)) + .unwrap(); + checked_write!(stream.write(DATA1)); expect_events( &mut poll, @@ -241,7 +249,13 @@ fn unix_stream_shutdown_write() { ); let err = stream.write(DATA2).unwrap_err(); + #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); + #[cfg(windows)] + { + use windows_sys::Win32::Networking::WinSock::WSAESHUTDOWN; + assert_eq!(err.raw_os_error(), Some(WSAESHUTDOWN)); + } // Read should be ok let mut buf = [0; DEFAULT_BUF_SIZE]; @@ -444,6 +458,10 @@ where expect_read!(stream.read(&mut buf), DATA1); assert!(stream.take_error().unwrap().is_none()); + // TODO: have to re-register here to reset user_events + poll.registry() + .reregister(&mut stream, TOKEN_1, Interest::WRITABLE.add(Interest::READABLE)) + .unwrap(); let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; let wrote = stream.write_vectored(&bufs).unwrap(); From 35912383e31f4bc3e1f11bd629d7657e954a5309 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Tue, 16 Aug 2022 00:23:58 -0700 Subject: [PATCH 04/34] add docs back in --- src/sys/windows/uds/stdnet/mod.rs | 2 +- src/sys/windows/uds/stdnet/net.rs | 276 +++++++++++++++++++++++++-- src/sys/windows/uds/stdnet/socket.rs | 46 +---- 3 files changed, 276 insertions(+), 48 deletions(-) diff --git a/src/sys/windows/uds/stdnet/mod.rs b/src/sys/windows/uds/stdnet/mod.rs index c4164c572..e3f15db01 100644 --- a/src/sys/windows/uds/stdnet/mod.rs +++ b/src/sys/windows/uds/stdnet/mod.rs @@ -17,7 +17,7 @@ pub fn path_offset(addr: &WinSock::sockaddr_un) -> usize { path - base } -pub fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { +fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { let sockaddr = mem::MaybeUninit::::zeroed(); // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` diff --git a/src/sys/windows/uds/stdnet/net.rs b/src/sys/windows/uds/stdnet/net.rs index 39aade7d7..784ffccc5 100644 --- a/src/sys/windows/uds/stdnet/net.rs +++ b/src/sys/windows/uds/stdnet/net.rs @@ -8,23 +8,26 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket} use std::path::Path; use std::time::Duration; -use windows_sys::Win32::Networking::WinSock::{ - self, - bind, - connect, - getpeername, - getsockname, - listen, - SO_RCVTIMEO, - SOCKET_ERROR, - SO_SNDTIMEO -}; +use windows_sys::Win32::Networking::WinSock::{self, bind, connect, getpeername, getsockname, listen, SO_RCVTIMEO, SOCKET_ERROR, SO_SNDTIMEO}; use crate::sys::windows::net::init; use super::socket::Socket; use super::{socket_addr, SocketAddr}; /// A Unix stream socket +/// +/// # Examples +/// +/// ```no_run +/// use mio::net::stdnet::UnixStream; +/// use std::io::prelude::*; +/// +/// let mut stream = UnixStream::connect("/path/to/my/socket").unwrap(); +/// stream.write_all(b"hello world").unwrap(); +/// let mut response = String::new(); +/// stream.read_to_string(&mut response).unwrap(); +/// println!("{}", response); +/// ``` pub struct UnixStream(Socket); impl fmt::Debug for UnixStream { @@ -43,6 +46,20 @@ impl fmt::Debug for UnixStream { impl UnixStream { /// Connects to the socket named by `path`. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = match UnixStream::connect("/tmp/sock") { + /// Ok(sock) => sock, + /// Err(e) => { + /// println!("Couldn't connect: {:?}", e); + /// return + /// } + /// }; + /// ``` pub fn connect>(path: P) -> io::Result { init(); fn inner(path: &Path) -> io::Result { @@ -69,11 +86,28 @@ impl UnixStream { /// object references. Both handles will read and write the same stream of /// data, and options set on one stream will be propagated to the other /// stream. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// let sock_copy = socket.try_clone().expect("Couldn't clone socket"); + /// ``` pub fn try_clone(&self) -> io::Result { self.0.duplicate().map(UnixStream) } /// Returns the socket address of the local half of this connection. + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// let addr = socket.local_addr().expect("Couldn't get local address"); + /// ``` pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( @@ -85,6 +119,15 @@ impl UnixStream { } /// Returns the socket address of the remote half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// let addr = socket.peer_addr().expect("Couldn't get peer address"); + /// ``` pub fn peer_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( @@ -96,11 +139,31 @@ impl UnixStream { } /// Moves the socket into or out of nonblocking mode. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_nonblocking(true).expect("Couldn't set nonblocking"); + /// ``` pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { self.0.set_nonblocking(nonblocking) } /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// if let Ok(Some(err)) = socket.take_error() { + /// println!("Got error: {:?}", err); + /// } + /// ``` pub fn take_error(&self) -> io::Result> { self.0.take_error() } @@ -110,6 +173,16 @@ impl UnixStream { /// This function will cause all pending and future I/O calls on the /// specified portions to immediately return with an appropriate value /// (see the documentation for `Shutdown`). + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// use std::net::Shutdown; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.shutdown(Shutdown::Both).expect("shutdown function failed"); + /// ``` pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.0.shutdown(how) } @@ -145,6 +218,15 @@ impl UnixStream { /// If the value specified is `None`, then `read` calls will block /// indefinitely. An `Err` is returned if the zero `Duration` is /// passed to this method. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); + /// ``` pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { self.0.set_timeout(dur, SO_RCVTIMEO.try_into().unwrap()) } @@ -154,16 +236,44 @@ impl UnixStream { /// If the value specified is `None`, then `write` calls will block /// indefinitely. An `Err` is returned if the zero `Duration` is /// passed to this method. + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); + /// ``` pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { self.0.set_timeout(dur, SO_SNDTIMEO.try_into().unwrap()) } /// Returns the read timeout of this socket. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); + /// assert_eq!(socket.read_timeout().unwrap(), None); + /// ``` pub fn read_timeout(&self) -> io::Result> { self.0.timeout(SO_RCVTIMEO.try_into().unwrap()) } /// Returns the write timeout of this socket. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); + /// assert_eq!(socket.write_timeout().unwrap(), None); + /// ``` pub fn write_timeout(&self) -> io::Result> { self.0.timeout(SO_SNDTIMEO.try_into().unwrap()) } @@ -239,6 +349,33 @@ impl IntoRawSocket for UnixStream { } /// A Unix domain socket server +/// +/// # Examples +/// +/// ```no_run +/// use std::thread; +/// use mio::net::stdnet::{UnixStream, UnixListener}; +/// +/// fn handle_client(stream: UnixStream) { +/// // ... +/// } +/// +/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); +/// +/// // accept connections and process them, spawning a new thread for each one +/// for stream in listener.incoming() { +/// match stream { +/// Ok(stream) => { +/// /* connection succeeded */ +/// thread::spawn(|| handle_client(stream)); +/// } +/// Err(err) => { +/// /* connection failed */ +/// break; +/// } +/// } +/// } +/// ``` pub struct UnixListener(Socket); impl fmt::Debug for UnixListener { @@ -254,6 +391,20 @@ impl fmt::Debug for UnixListener { impl UnixListener { /// Creates a new `UnixListener` bound to the specified socket. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = match UnixListener::bind("/path/to/the/socket") { + /// Ok(sock) => sock, + /// Err(e) => { + /// println!("Couldn't connect: {:?}", e); + /// return + /// } + /// }; + /// ``` pub fn bind>(path: P) -> io::Result { init(); fn inner(path: &Path) -> io::Result { @@ -287,6 +438,19 @@ impl UnixListener { /// the remote peer's address will be returned. /// /// [`UnixStream`]: struct.UnixStream.html + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// match listener.accept() { + /// Ok((socket, addr)) => println!("Got a client: {:?}", addr), + /// Err(e) => println!("accept function failed: {:?}", e), + /// } + /// ``` pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { let mut storage: WinSock::sockaddr_un = unsafe { mem::zeroed() }; let mut len = mem::size_of_val(&storage) as c_int; @@ -300,11 +464,31 @@ impl UnixListener { /// The returned `UnixListener` is a reference to the same socket that this /// object references. Both handles can be used to accept incoming /// connections and options set on one listener will affect the other. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// let listener_copy = listener.try_clone().expect("Couldn't clone socket"); + /// ``` pub fn try_clone(&self) -> io::Result { self.0.duplicate().map(UnixListener) } /// Returns the local socket address of this listener. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// let addr = listener.local_addr().expect("Couldn't get local address"); + /// ``` pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( @@ -316,11 +500,33 @@ impl UnixListener { } /// Moves the socket into or out of nonblocking mode. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// listener.set_nonblocking(true).expect("Couldn't set nonblocking"); + /// ``` pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { self.0.set_nonblocking(nonblocking) } /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/tmp/sock").unwrap(); + /// + /// if let Ok(Some(err)) = listener.take_error() { + /// println!("Got error: {:?}", err); + /// } + /// ``` pub fn take_error(&self) -> io::Result> { self.0.take_error() } @@ -331,6 +537,30 @@ impl UnixListener { /// peer's [`SocketAddr`] structure. /// /// [`SocketAddr`]: struct.SocketAddr.html + /// + /// # Examples + /// + /// ```no_run + /// use std::thread; + /// use mio::net::stdnet::{UnixStream, UnixListener}; + /// + /// fn handle_client(stream: UnixStream) { + /// // ... + /// } + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// for stream in listener.incoming() { + /// match stream { + /// Ok(stream) => { + /// thread::spawn(|| handle_client(stream)); + /// } + /// Err(err) => { + /// break; + /// } + /// } + /// } + /// ``` pub fn incoming<'a>(&'a self) -> Incoming<'a> { Incoming { listener: self } } @@ -370,6 +600,30 @@ impl<'a> IntoIterator for &'a UnixListener { /// It will never return `None`. /// /// [`UnixListener`]: struct.UnixListener.html +/// +/// # Examples +/// +/// ```no_run +/// use std::thread; +/// use mio::net::stdnet::{UnixStream, UnixListener}; +/// +/// fn handle_client(stream: UnixStream) { +/// // ... +/// } +/// +/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); +/// +/// for stream in listener.incoming() { +/// match stream { +/// Ok(stream) => { +/// thread::spawn(|| handle_client(stream)); +/// } +/// Err(err) => { +/// break; +/// } +/// } +/// } +/// ``` #[derive(Debug)] pub struct Incoming<'a> { listener: &'a UnixListener, diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/uds/stdnet/socket.rs index 5a1fccabb..ee0e8766f 100644 --- a/src/sys/windows/uds/stdnet/socket.rs +++ b/src/sys/windows/uds/stdnet/socket.rs @@ -14,33 +14,7 @@ use windows_sys::Win32::Foundation::{ }; use windows_sys::Win32::System::Threading::GetCurrentProcessId; use windows_sys::Win32::System::WindowsProgramming::INFINITE; -use windows_sys::Win32::Networking::WinSock::{ - WSABUF, - AF_UNIX, - FIONBIO, - INVALID_SOCKET, - SD_BOTH, - SD_RECEIVE, - SD_SEND, - SOCKADDR, - SOCKET, - SOCKET_ERROR, - SOCK_STREAM, - SOL_SOCKET, - SO_ERROR, - WSADuplicateSocketW, - WSAPROTOCOL_INFOW, - WSASocketW, - WSA_FLAG_OVERLAPPED, - accept, - closesocket, - getsockopt as c_getsockopt, - ioctlsocket, - recv, - send, - setsockopt as c_setsockopt, - shutdown, -}; +use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_ERROR, WSADuplicateSocketW, WSAPROTOCOL_INFOW, WSASocketW, accept, closesocket, getsockopt as c_getsockopt, ioctlsocket, recv, send, setsockopt as c_setsockopt, shutdown}; #[derive(Debug)] pub struct Socket(SOCKET); @@ -49,12 +23,12 @@ impl Socket { pub fn new() -> io::Result { let socket = wsa_syscall!( WSASocketW( - AF_UNIX.into(), - SOCK_STREAM.into(), + WinSock::AF_UNIX.into(), + WinSock::SOCK_STREAM.into(), 0, ptr::null_mut(), 0, - WSA_FLAG_OVERLAPPED, + WinSock::WSA_FLAG_OVERLAPPED, ), PartialEq::eq, INVALID_SOCKET @@ -124,7 +98,7 @@ impl Socket { pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { let mut total = 0; for slice in &mut *bufs { - let wsa_buf = unsafe { *(slice as *const _ as *const WSABUF) }; + let wsa_buf = unsafe { *(slice as *const _ as *const WinSock::WSABUF) }; let len = wsa_buf.len; let buf = unsafe { std::slice::from_raw_parts_mut(wsa_buf.buf, len.try_into().unwrap()) }; total += self.recv_with_flags(buf, 0)?; @@ -145,7 +119,7 @@ impl Socket { pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { let mut total = 0; for slice in bufs { - let wsa_buf = unsafe { *(slice as *const _ as *const WSABUF) }; + let wsa_buf = unsafe { *(slice as *const _ as *const WinSock::WSABUF) }; let len = wsa_buf.len; let buf = unsafe { std::slice::from_raw_parts(wsa_buf.buf, len.try_into().unwrap()) }; dbg!(buf); @@ -172,7 +146,7 @@ impl Socket { pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { let mut nonblocking: c_ulong = if nonblocking { 1 } else { 0 }; wsa_syscall!( - ioctlsocket(self.0, FIONBIO, &mut nonblocking), + ioctlsocket(self.0, WinSock::FIONBIO, &mut nonblocking), PartialEq::eq, SOCKET_ERROR )?; @@ -181,9 +155,9 @@ impl Socket { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { let how = match how { - Shutdown::Write => SD_SEND, - Shutdown::Read => SD_RECEIVE, - Shutdown::Both => SD_BOTH, + Shutdown::Write => WinSock::SD_SEND, + Shutdown::Read => WinSock::SD_RECEIVE, + Shutdown::Both => WinSock::SD_BOTH, }; wsa_syscall!( shutdown(self.0, how.try_into().unwrap()), From 9015ca296f8ff8eb39a5e25fdcaf87b973e0b973 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Tue, 16 Aug 2022 00:40:23 -0700 Subject: [PATCH 05/34] cleanup --- src/sys/windows/uds/listener.rs | 2 +- src/sys/windows/uds/mod.rs | 4 +--- src/sys/windows/uds/stdnet/mod.rs | 7 ++++--- src/sys/windows/uds/stdnet/net.rs | 17 +++++++++++++++++ src/sys/windows/uds/stdnet/socket.rs | 4 ++-- src/sys/windows/uds/stream.rs | 2 +- tests/unix_listener.rs | 4 ++-- tests/unix_stream.rs | 20 ++++++++++++++------ 8 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index e6759f4e0..39acc1e98 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -12,7 +12,7 @@ use windows_sys::Win32::Networking::WinSock::{ accept as sys_accept }; -use super::{stdnet as net, socket_addr}; +use super::stdnet::{self as net, socket_addr}; use crate::net::{SocketAddr, UnixStream}; use crate::sys::windows::net::{init, new_socket}; diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs index 8e5ed0704..421b2bb2c 100644 --- a/src/sys/windows/uds/mod.rs +++ b/src/sys/windows/uds/mod.rs @@ -1,5 +1,5 @@ pub mod stdnet; -pub use self::stdnet::{path_offset, SocketAddr}; +pub use self::stdnet::SocketAddr; cfg_os_poll! { use std::convert::TryInto; @@ -14,8 +14,6 @@ cfg_os_poll! { pub(crate) mod listener; pub(crate) mod stream; - pub use self::stdnet::socket_addr; - pub(crate) fn local_addr(socket: RawSocket) -> io::Result { SocketAddr::new(|sockaddr, socklen| { wsa_syscall!( diff --git a/src/sys/windows/uds/stdnet/mod.rs b/src/sys/windows/uds/stdnet/mod.rs index e3f15db01..70c5eeda6 100644 --- a/src/sys/windows/uds/stdnet/mod.rs +++ b/src/sys/windows/uds/stdnet/mod.rs @@ -1,3 +1,5 @@ +//! Windows specific networking functionality + use std::ascii; use std::fmt; use std::io; @@ -10,14 +12,14 @@ use windows_sys::Win32::Networking::WinSock::{self, SOCKADDR}; mod net; mod socket; -pub fn path_offset(addr: &WinSock::sockaddr_un) -> usize { +pub(crate) fn path_offset(addr: &WinSock::sockaddr_un) -> usize { // Work with an actual instance of the type since using a null pointer is UB let base = addr as *const _ as usize; let path = &addr.sun_path as *const _ as usize; path - base } -fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { +pub(crate) fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { let sockaddr = mem::MaybeUninit::::zeroed(); // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` @@ -178,4 +180,3 @@ impl<'a> fmt::Display for AsciiEscaped<'a> { } pub use self::net::{UnixListener, UnixStream}; -pub use self::socket::Socket; diff --git a/src/sys/windows/uds/stdnet/net.rs b/src/sys/windows/uds/stdnet/net.rs index 784ffccc5..e56f5a6f8 100644 --- a/src/sys/windows/uds/stdnet/net.rs +++ b/src/sys/windows/uds/stdnet/net.rs @@ -187,6 +187,23 @@ impl UnixStream { self.0.shutdown(how) } + /// Creates an unnamed pair of connected sockets. + /// + /// Returns two `UnixStream`s which are connected to each other. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let (sock1, sock2) = match UnixStream::pair() { + /// Ok((sock1, sock2)) => (sock1, sock2), + /// Err(e) => { + /// println!("Couldn't create a pair of sockets: {e:?}"); + /// return + /// } + /// } + /// ``` pub fn pair() -> io::Result<(Self, Self)> { use std::sync::{Arc, RwLock}; use std::thread::spawn; diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/uds/stdnet/socket.rs index ee0e8766f..0eb7a8d2e 100644 --- a/src/sys/windows/uds/stdnet/socket.rs +++ b/src/sys/windows/uds/stdnet/socket.rs @@ -14,7 +14,7 @@ use windows_sys::Win32::Foundation::{ }; use windows_sys::Win32::System::Threading::GetCurrentProcessId; use windows_sys::Win32::System::WindowsProgramming::INFINITE; -use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOCK_STREAM, SOL_SOCKET, SO_ERROR, WSADuplicateSocketW, WSAPROTOCOL_INFOW, WSASocketW, accept, closesocket, getsockopt as c_getsockopt, ioctlsocket, recv, send, setsockopt as c_setsockopt, shutdown}; +use windows_sys::Win32::Networking::WinSock::{self, INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, WSADuplicateSocketW, WSAPROTOCOL_INFOW, WSASocketW, accept, closesocket, getsockopt as c_getsockopt, ioctlsocket, recv, send, setsockopt as c_setsockopt, shutdown}; #[derive(Debug)] pub struct Socket(SOCKET); @@ -67,7 +67,7 @@ impl Socket { info.iProtocol, &mut info, 0, - WSA_FLAG_OVERLAPPED, + WinSock::WSA_FLAG_OVERLAPPED, ), PartialEq::eq, INVALID_SOCKET diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index 7702e9ca3..6dd9f66ac 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -4,7 +4,7 @@ use std::convert::TryInto; use std::path::Path; use windows_sys::Win32::Networking::WinSock::{self, SOCKET_ERROR, connect as sys_connect, ioctlsocket, FIONBIO}; -use super::{stdnet::{self as net}, socket_addr}; +use super::stdnet::{self as net, socket_addr}; use crate::net::SocketAddr; use crate::sys::windows::net::{init, new_socket}; diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 186afbd6e..4c8d9f072 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,10 +1,10 @@ #![cfg(all(feature = "os-poll", feature = "net"))] +#[cfg(windows)] +use mio::net::stdnet as net; use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; -#[cfg(windows)] -use mio::net::{stdnet as net}; #[cfg(unix)] use std::os::unix::net; use std::path::{Path, PathBuf}; diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index e5128ae3f..8edaf731e 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,11 +1,11 @@ #![cfg(all(feature = "os-poll", feature = "net"))] +#[cfg(windows)] +use mio::net::stdnet as net; use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; -#[cfg(windows)] -use mio::net::{stdnet as net}; #[cfg(unix)] use std::os::unix::net; use std::path::Path; @@ -222,7 +222,11 @@ fn unix_stream_shutdown_write() { // TODO: have to re-register here to reset user_events poll.registry() - .reregister(&mut stream, TOKEN_1, Interest::WRITABLE.add(Interest::READABLE)) + .reregister( + &mut stream, + TOKEN_1, + Interest::WRITABLE.add(Interest::READABLE), + ) .unwrap(); checked_write!(stream.write(DATA1)); @@ -318,8 +322,8 @@ fn unix_stream_shutdown_both() { let err = stream.write(DATA2).unwrap_err(); #[cfg(unix)] assert_eq!(err.kind(), io::ErrorKind::BrokenPipe); - #[cfg(window)] - assert_eq!(err.kind(), io::ErrorKind::ConnectionAbroted); + #[cfg(windows)] + assert_eq!(err.kind(), io::ErrorKind::ConnectionAborted); // Close the connection to allow the remote to shutdown drop(stream); @@ -460,7 +464,11 @@ where assert!(stream.take_error().unwrap().is_none()); // TODO: have to re-register here to reset user_events poll.registry() - .reregister(&mut stream, TOKEN_1, Interest::WRITABLE.add(Interest::READABLE)) + .reregister( + &mut stream, + TOKEN_1, + Interest::WRITABLE.add(Interest::READABLE), + ) .unwrap(); let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; From 7265833fc2dbb712a5315efda77329b57e640eed Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Tue, 16 Aug 2022 11:48:35 -0700 Subject: [PATCH 06/34] remove log statements --- src/sys/windows/selector.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs index 777d12413..cd4d5cb66 100644 --- a/src/sys/windows/selector.rs +++ b/src/sys/windows/selector.rs @@ -197,7 +197,6 @@ impl SockState { // This is the function called from the overlapped using as Arc>. Watch out for reference counting. fn feed_event(&mut self) -> Option { - println!("Feed event..."); self.poll_status = SockPollStatus::Idle; self.pending_evts = 0; @@ -261,9 +260,7 @@ impl SockState { cfg_io_source! { impl SockState { fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { - println!("init state: {raw_socket:?}"); let base = get_base_socket(raw_socket)?; - println!("init state:bas {base:?}"); Ok(SockState { iosb: IoStatusBlock::zeroed(), poll_info: AfdPollInfo::zeroed(), @@ -618,9 +615,7 @@ cfg_io_source! { /// GetQueuedCompletionStatusEx() we tell the kernel about the registered /// socket event(s) immediately. unsafe fn update_sockets_events_if_polling(&self) -> io::Result<()> { - println!("POLLING"); if self.is_polling.load(Ordering::Acquire) { - println!("POLLING IMMEDIATELY"); self.update_sockets_events() } else { Ok(()) @@ -667,7 +662,6 @@ cfg_io_source! { #[allow(dead_code)] fn get_base_socket(raw_socket: RawSocket) -> io::Result { let res = try_get_base_socket(raw_socket, SIO_BASE_HANDLE); - println!("FIRST {res:?}"); if let Ok(base_socket) = res { return Ok(base_socket); } @@ -683,7 +677,6 @@ cfg_io_source! { SIO_BSP_HANDLE, ] { let r = try_get_base_socket(raw_socket, ioctl); - println!("OTHER {r:?}"); if let Ok(base_socket) = r { // Since we know now that we're dealing with an LSP (otherwise // SIO_BASE_HANDLE would't have failed), only return any result From 3884bb58f9d489cff5cd398931c80d6c6c078f22 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Tue, 16 Aug 2022 12:17:36 -0700 Subject: [PATCH 07/34] clean up selector --- src/sys/windows/selector.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/sys/windows/selector.rs b/src/sys/windows/selector.rs index cd4d5cb66..9f3cf68dd 100644 --- a/src/sys/windows/selector.rs +++ b/src/sys/windows/selector.rs @@ -260,12 +260,11 @@ impl SockState { cfg_io_source! { impl SockState { fn new(raw_socket: RawSocket, afd: Arc) -> io::Result { - let base = get_base_socket(raw_socket)?; Ok(SockState { iosb: IoStatusBlock::zeroed(), poll_info: AfdPollInfo::zeroed(), afd, - base_socket: base, + base_socket: get_base_socket(raw_socket)?, user_evts: 0, pending_evts: 0, user_data: 0, @@ -659,7 +658,6 @@ cfg_io_source! { } } - #[allow(dead_code)] fn get_base_socket(raw_socket: RawSocket) -> io::Result { let res = try_get_base_socket(raw_socket, SIO_BASE_HANDLE); if let Ok(base_socket) = res { @@ -676,8 +674,7 @@ cfg_io_source! { SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE, ] { - let r = try_get_base_socket(raw_socket, ioctl); - if let Ok(base_socket) = r { + if let Ok(base_socket) = try_get_base_socket(raw_socket, ioctl) { // Since we know now that we're dealing with an LSP (otherwise // SIO_BASE_HANDLE would't have failed), only return any result // when it is different from the original `raw_socket`. From 5ca8952715781f81d866af9b38a1194558b65560 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Wed, 17 Aug 2022 22:08:20 -0700 Subject: [PATCH 08/34] clean up stream and listener sys logic --- src/sys/windows/uds/listener.rs | 66 +++---------------- src/sys/windows/uds/stdnet/net.rs | 97 +++++++++++++++------------- src/sys/windows/uds/stdnet/socket.rs | 20 +++--- src/sys/windows/uds/stream.rs | 50 +++----------- 4 files changed, 78 insertions(+), 155 deletions(-) diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index 39acc1e98..f3aff5b6a 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -1,68 +1,20 @@ -use std::{io, mem}; -use std::convert::TryInto; -use std::os::windows::io::{AsRawSocket, FromRawSocket}; +use std::io; +use std::os::windows::io::AsRawSocket; use std::path::Path; -use std::os::raw::c_int; -use windows_sys::Win32::Networking::WinSock::{ - self, - SOCKET_ERROR, - INVALID_SOCKET, - bind as sys_bind, - listen, - accept as sys_accept -}; -use super::stdnet::{self as net, socket_addr}; +use super::{stdnet as net}; use crate::net::{SocketAddr, UnixStream}; -use crate::sys::windows::net::{init, new_socket}; pub(crate) fn bind(path: &Path) -> io::Result { - init(); - let socket = new_socket(WinSock::AF_UNIX.into(), WinSock::SOCK_STREAM)?; - let (sockaddr, socklen) = socket_addr(path)?; - let sockaddr = &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR; - - wsa_syscall!(sys_bind(socket, sockaddr, socklen as _), PartialEq::eq, SOCKET_ERROR) - .and_then(|_| wsa_syscall!(listen(socket, 128), PartialEq::eq, SOCKET_ERROR)) - .map_err(|err| { - // Close the socket if we hit an error, ignoring the error from - // closing since we can't pass back two errors. - let _ = unsafe { WinSock::closesocket(socket) }; - err - }) - .map(|_| unsafe { net::UnixListener::from_raw_socket(socket.try_into().unwrap()) }) + let listener = net::UnixListener::bind(path)?; + listener.set_nonblocking(true)?; + Ok(listener) } pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { - let sockaddr = mem::MaybeUninit::::zeroed(); - - // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` - // bytes is properly initialized. - // - // `0` is a valid value for `sockaddr_un::sun_family`; it is - // `WinSock::AF_UNSPEC`. - // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an - // abstract path. - let mut sockaddr = unsafe { sockaddr.assume_init() }; - - sockaddr.sun_family = WinSock::AF_UNIX; - let mut socklen = mem::size_of_val(&sockaddr) as c_int; - - let socket = wsa_syscall!( - sys_accept( - listener.as_raw_socket().try_into().unwrap(), - &mut sockaddr as *mut WinSock::sockaddr_un as *mut WinSock::SOCKADDR, - &mut socklen as _ - ), - PartialEq::eq, - INVALID_SOCKET - ); - - socket - .map(|socket| unsafe { net::UnixStream::from_raw_socket(socket.try_into().unwrap()) }) - .map(UnixStream::from_std) - .map(|stream| (stream, SocketAddr::from_parts(sockaddr, socklen))) + listener.set_nonblocking(true)?; + let es = listener.accept().map(|(stream, addr)| (UnixStream::from_std(stream), addr)); + es } pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { diff --git a/src/sys/windows/uds/stdnet/net.rs b/src/sys/windows/uds/stdnet/net.rs index e56f5a6f8..dc2458a3f 100644 --- a/src/sys/windows/uds/stdnet/net.rs +++ b/src/sys/windows/uds/stdnet/net.rs @@ -10,7 +10,6 @@ use std::time::Duration; use windows_sys::Win32::Networking::WinSock::{self, bind, connect, getpeername, getsockname, listen, SO_RCVTIMEO, SOCKET_ERROR, SO_SNDTIMEO}; -use crate::sys::windows::net::init; use super::socket::Socket; use super::{socket_addr, SocketAddr}; @@ -61,23 +60,23 @@ impl UnixStream { /// }; /// ``` pub fn connect>(path: P) -> io::Result { - init(); - fn inner(path: &Path) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path)?; - - wsa_syscall!( - connect( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as i32, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - Ok(UnixStream(inner)) + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + PartialEq::eq, + SOCKET_ERROR + ) { + Ok(_) => {}, + Err(ref err) if err.raw_os_error() == Some(WinSock::WSAEINPROGRESS) => {}, + Err(e) => return Err(e) } - inner(path.as_ref()) + Ok(UnixStream(inner)) } /// Creates a new independently owned handle to the underlying socket. @@ -212,7 +211,6 @@ impl UnixStream { let file_path = dir.path().join("socket"); let a: Arc>>> = Arc::new(RwLock::new(None)); let ul = UnixListener::bind(&file_path).unwrap(); - ul.set_nonblocking(true)?; let server = { let a = a.clone(); spawn(move || { @@ -222,7 +220,6 @@ impl UnixStream { }) }; let stream1 = UnixStream::connect(&file_path)?; - stream1.set_nonblocking(true)?; server .join() .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; @@ -423,29 +420,24 @@ impl UnixListener { /// }; /// ``` pub fn bind>(path: P) -> io::Result { - init(); - fn inner(path: &Path) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path)?; - - wsa_syscall!( - bind( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as _, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - wsa_syscall!( - listen(inner.as_raw_socket() as _, 128), - PartialEq::eq, - SOCKET_ERROR - )?; - - Ok(UnixListener(inner)) - } - inner(path.as_ref()) + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + wsa_syscall!( + bind( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as _, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + wsa_syscall!( + listen(inner.as_raw_socket() as _, 128), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(UnixListener(inner)) } /// Accepts a new incoming connection to this listener. @@ -469,10 +461,23 @@ impl UnixListener { /// } /// ``` pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - let mut storage: WinSock::sockaddr_un = unsafe { mem::zeroed() }; - let mut len = mem::size_of_val(&storage) as c_int; - let sock = self.0.accept(&mut storage as *mut _ as *mut _, &mut len)?; - let addr = SocketAddr::from_parts(storage, len); + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + + sockaddr.sun_family = WinSock::AF_UNIX; + let mut socklen = mem::size_of_val(&sockaddr) as c_int; + + let sock = self.0.accept(&mut sockaddr as *mut _ as *mut _, &mut socklen)?; + let addr = SocketAddr::from_parts(sockaddr, socklen); Ok((UnixStream(sock), addr)) } diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/uds/stdnet/socket.rs index 0eb7a8d2e..6620e1b26 100644 --- a/src/sys/windows/uds/stdnet/socket.rs +++ b/src/sys/windows/uds/stdnet/socket.rs @@ -16,11 +16,14 @@ use windows_sys::Win32::System::Threading::GetCurrentProcessId; use windows_sys::Win32::System::WindowsProgramming::INFINITE; use windows_sys::Win32::Networking::WinSock::{self, INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, WSADuplicateSocketW, WSAPROTOCOL_INFOW, WSASocketW, accept, closesocket, getsockopt as c_getsockopt, ioctlsocket, recv, send, setsockopt as c_setsockopt, shutdown}; +use crate::sys::windows::net::init; + #[derive(Debug)] pub struct Socket(SOCKET); impl Socket { pub fn new() -> io::Result { + init(); let socket = wsa_syscall!( WSASocketW( WinSock::AF_UNIX.into(), @@ -28,14 +31,12 @@ impl Socket { 0, ptr::null_mut(), 0, - WinSock::WSA_FLAG_OVERLAPPED, + WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, ), PartialEq::eq, INVALID_SOCKET )?; - let socket = Socket(socket); - socket.set_no_inherit()?; - Ok(socket) + Ok(Socket(socket)) } pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { @@ -60,21 +61,19 @@ impl Socket { PartialEq::eq, SOCKET_ERROR )?; - let n = wsa_syscall!( + let socket = wsa_syscall!( WSASocketW( info.iAddressFamily, info.iSocketType, info.iProtocol, &mut info, 0, - WinSock::WSA_FLAG_OVERLAPPED, + WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, ), PartialEq::eq, INVALID_SOCKET )?; - let socket = Socket(n); - socket.set_no_inherit()?; - Ok(socket) + Ok(Socket(socket)) } fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result { @@ -103,7 +102,6 @@ impl Socket { let buf = unsafe { std::slice::from_raw_parts_mut(wsa_buf.buf, len.try_into().unwrap()) }; total += self.recv_with_flags(buf, 0)?; } - println!("Wrote vectored: {total:?}, {bufs:?}"); Ok(total as usize) } @@ -122,7 +120,6 @@ impl Socket { let wsa_buf = unsafe { *(slice as *const _ as *const WinSock::WSABUF) }; let len = wsa_buf.len; let buf = unsafe { std::slice::from_raw_parts(wsa_buf.buf, len.try_into().unwrap()) }; - dbg!(buf); let ret = wsa_syscall!( send(self.0, buf as *const _ as *const _, len as c_int, 0), PartialEq::eq, @@ -130,7 +127,6 @@ impl Socket { )?; total += ret; } - println!("Wrote vectored: {total:?}, {bufs:?}"); Ok(total as usize) } diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index 6dd9f66ac..c70bc1375 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -1,52 +1,22 @@ use std::io; -use std::os::windows::io::{AsRawSocket, FromRawSocket}; -use std::convert::TryInto; +use std::os::windows::io::{AsRawSocket}; use std::path::Path; -use windows_sys::Win32::Networking::WinSock::{self, SOCKET_ERROR, connect as sys_connect, ioctlsocket, FIONBIO}; - -use super::stdnet::{self as net, socket_addr}; +use super::stdnet::{self as net}; use crate::net::SocketAddr; -use crate::sys::windows::net::{init, new_socket}; +use crate::sys::windows::net::init; pub(crate) fn connect(path: &Path) -> io::Result { init(); - let socket = new_socket(WinSock::AF_UNIX.into(), WinSock::SOCK_STREAM)?; - let (sockaddr, socklen) = socket_addr(path)?; - let sockaddr = &sockaddr as *const WinSock::sockaddr_un as *const WinSock::SOCKADDR; - - // Put into blocking mode to connect. - wsa_syscall!( - ioctlsocket(socket, FIONBIO, &mut 0), - PartialEq::eq, - SOCKET_ERROR - )?; - - match wsa_syscall!( - sys_connect(socket, sockaddr, socklen as _), - PartialEq::eq, - SOCKET_ERROR - ) { - Ok(_) => {} - Err(ref err) if err.raw_os_error() == Some(WinSock::WSAEINPROGRESS) => {} - Err(e) => { - // Close the socket if we hit an error, ignoring the error - // from closing since we can't pass back two errors. - let _ = unsafe { WinSock::closesocket(socket) }; - - return Err(e); - } - } - wsa_syscall!( - ioctlsocket(socket, FIONBIO, &mut 1), - PartialEq::eq, - SOCKET_ERROR - )?; - - Ok(unsafe { net::UnixStream::from_raw_socket(socket.try_into().unwrap()) }) + let socket = net::UnixStream::connect(path)?; + socket.set_nonblocking(true)?; + Ok(socket) } pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { - net::UnixStream::pair() + let (stream0, stream1) = net::UnixStream::pair()?; + stream0.set_nonblocking(true)?; + stream1.set_nonblocking(true)?; + Ok((stream0, stream1)) } pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { From 985a145d6fd5b2fb42ef9276635ba1747b702a3e Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Wed, 17 Aug 2022 22:09:00 -0700 Subject: [PATCH 09/34] fix re-registration --- src/sys/windows/mod.rs | 14 +++++--------- tests/unix_stream.rs | 17 ----------------- 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 0817ca4fa..fb09926a2 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -102,15 +102,11 @@ cfg_io_source! { F: FnOnce(&T) -> io::Result, { let result = f(io); - if let Err(ref e) = result { - if e.kind() == io::ErrorKind::WouldBlock { - self.inner.as_ref().map_or(Ok(()), |state| { - state - .selector - .reregister(state.sock_state.clone(), state.token, state.interests) - })?; - } - } + self.inner.as_ref().map_or(Ok(()), |state| { + state + .selector + .reregister(state.sock_state.clone(), state.token, state.interests) + })?; result } diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 8edaf731e..a29376f96 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -220,15 +220,6 @@ fn unix_stream_shutdown_write() { vec![ExpectEvent::new(TOKEN_1, Interest::WRITABLE)], ); - // TODO: have to re-register here to reset user_events - poll.registry() - .reregister( - &mut stream, - TOKEN_1, - Interest::WRITABLE.add(Interest::READABLE), - ) - .unwrap(); - checked_write!(stream.write(DATA1)); expect_events( &mut poll, @@ -462,14 +453,6 @@ where expect_read!(stream.read(&mut buf), DATA1); assert!(stream.take_error().unwrap().is_none()); - // TODO: have to re-register here to reset user_events - poll.registry() - .reregister( - &mut stream, - TOKEN_1, - Interest::WRITABLE.add(Interest::READABLE), - ) - .unwrap(); let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; let wrote = stream.write_vectored(&bufs).unwrap(); From f5ec8cebd2d81cda795fbff06fdb101e62920794 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Wed, 17 Aug 2022 23:11:15 -0700 Subject: [PATCH 10/34] add test for serial calls to listener.accept --- tests/unix_listener.rs | 53 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 4c8d9f072..d7ddf11d5 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -138,7 +138,7 @@ fn unix_listener_deregister() { #[cfg(target_os = "linux")] #[test] -fn unix_listener_abstract_namesapce() { +fn unix_listener_abstract_namespace() { use rand::Rng; let num: u64 = rand::thread_rng().gen(); let name = format!("\u{0000}-mio-abstract-uds-{}", num); @@ -190,6 +190,57 @@ where handle.join().unwrap(); } +#[test] +fn unix_listener_multiple_accepts() { + let (mut poll, mut events) = init_with_poll(); + let barrier = Arc::new(Barrier::new(2)); + let path = temp_file("unix_listener_multiple_accepts"); + let mut buf = [0; DEFAULT_BUF_SIZE]; + + let mut listener = UnixListener::bind(&path).unwrap(); + + assert_socket_non_blocking(&listener); + assert_socket_close_on_exec(&listener); + + poll.registry() + .register( + &mut listener, + TOKEN_1, + Interest::WRITABLE.add(Interest::READABLE), + ) + .unwrap(); + expect_no_events(&mut poll, &mut events); + + let handle = open_connections(path, 2, barrier.clone()); + + // First connection is opened, try to accept and read. + expect_events( + &mut poll, + &mut events, + vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], + ); + + let (mut stream1, _) = listener.accept().unwrap(); + assert_would_block(stream1.read(&mut buf)); + barrier.wait(); + + // Second connection is opened, try to accept and read. + expect_events( + &mut poll, + &mut events, + vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], + ); + + let (mut stream1, _) = listener.accept().unwrap(); + assert_would_block(stream1.read(&mut buf)); + barrier.wait(); + + // We don't expect any more connections. + assert_would_block(listener.accept()); + assert!(listener.take_error().unwrap().is_none()); + handle.join().unwrap(); +} + fn open_connections( path: PathBuf, n_connections: usize, From cee5c6bb4e952463f1fe5acbd651d34bff96bb2b Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Wed, 17 Aug 2022 23:11:36 -0700 Subject: [PATCH 11/34] fix serial calls to accept --- src/net/uds/listener.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 87cdab73c..7129401e0 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -41,7 +41,9 @@ impl UnixListener { /// The call is responsible for ensuring that the listening socket is in /// non-blocking mode. pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - sys::uds::listener::accept(&self.inner) + self.inner.do_io(|inner| { + sys::uds::listener::accept(&*inner) + }) } /// Returns the local socket address of this listener. From 488254d2b509b4b1152ee118aabf8533616b5d47 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Thu, 18 Aug 2022 15:29:04 -0700 Subject: [PATCH 12/34] remove tempfile dependency and fix doc tests --- Cargo.toml | 2 +- src/net/uds/stream.rs | 70 ++++++++++++++++++++++- src/sys/windows/uds/stdnet/net.rs | 93 +++++++++++++++++++++++-------- 3 files changed, 139 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 635e9a3d6..c2a776fe8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ log = "0.4.8" libc = "0.2.121" [target.'cfg(windows)'.dependencies] -tempfile = "3" +rand = "0.8" [target.'cfg(windows)'.dependencies.windows-sys] version = "0.36" diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 4463ad7e8..8cae0cefa 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -92,7 +92,8 @@ impl UnixStream { /// /// # Examples /// - /// ``` + #[cfg_attr(unix, doc = "```")] + #[cfg_attr(windows, doc = "```ignore")] /// # use std::error::Error; /// # /// # fn main() -> Result<(), Box> { @@ -140,6 +141,73 @@ impl UnixStream { /// # Ok(()) /// # } /// ``` + /// + #[cfg_attr(windows, doc = "```")] + #[cfg_attr(unix, doc = "```ignore")] + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box> { + /// use std::io; + /// use std::os::windows::io::AsRawSocket; + /// use std::os::raw::c_int; + /// use mio::net::UnixStream; + /// use windows_sys::Win32::Networking::WinSock; + /// use std::convert::TryInto; + /// + /// let (stream1, stream2) = UnixStream::pair()?; + /// + /// // Wait until the stream is writable... + /// + /// // Write to the stream using a direct WinSock call, of course the + /// // `io::Write` implementation would be easier to use. + /// let buf = b"hello"; + /// let n = stream1.try_io(|| { + /// let res = unsafe { + /// WinSock::send( + /// stream1.as_raw_socket().try_into().unwrap(), + /// &buf as *const _ as *const _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::send, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::from_raw_os_error(unsafe { + /// WinSock::WSAGetLastError() + /// })) + /// } + /// })?; + /// eprintln!("write {} bytes", n); + /// + /// // Wait until the stream is readable... + /// + /// // Read from the stream using a direct WinSock call, of course the + /// // `io::Read` implementation would be easier to use. + /// let mut buf = [0; 512]; + /// let n = stream2.try_io(|| { + /// let res = unsafe { + /// WinSock::recv( + /// stream2.as_raw_socket().try_into().unwrap(), + /// &mut buf as *mut _ as *mut _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::last_os_error()) + /// } + /// })?; + /// eprintln!("read {} bytes", n); + /// # Ok(()) + /// # } + /// ``` pub fn try_io(&self, f: F) -> io::Result where F: FnOnce() -> io::Result, diff --git a/src/sys/windows/uds/stdnet/net.rs b/src/sys/windows/uds/stdnet/net.rs index dc2458a3f..22dda64e3 100644 --- a/src/sys/windows/uds/stdnet/net.rs +++ b/src/sys/windows/uds/stdnet/net.rs @@ -5,13 +5,59 @@ use std::mem; use std::net::Shutdown; use std::os::raw::c_int; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::time::Duration; use windows_sys::Win32::Networking::WinSock::{self, bind, connect, getpeername, getsockname, listen, SO_RCVTIMEO, SOCKET_ERROR, SO_SNDTIMEO}; use super::socket::Socket; use super::{socket_addr, SocketAddr}; +use rand::{distributions::Alphanumeric, Rng}; + +struct TempPath(PathBuf); + +impl TempPath { + fn new(random_len: usize) -> io::Result { + let dir = std::env::temp_dir(); + // Retry a few times in case of collisions + for _ in 0..10 { + let rand_str: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(random_len) + .map(char::from) + .collect(); + let filename = format!(".tmp-{rand_str}.socket"); + let path = dir.join(filename); + if !path.exists() { + return Ok(Self(path)); + } + } + + Err(io::Error::new( + io::ErrorKind::AlreadyExists, + "too many temporary files exist", + )) + } +} + +impl Drop for TempPath { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.0); + } +} + +impl AsRef for TempPath { + fn as_ref(&self) -> &Path { + &self.0 + } +} + +impl std::ops::Deref for TempPath { + type Target = Path; + fn deref(&self) -> &Path { + Path::new(&self.0) + } +} /// A Unix stream socket /// @@ -58,6 +104,7 @@ impl UnixStream { /// return /// } /// }; + /// # drop(socket); // Silence unused variable warning. /// ``` pub fn connect>(path: P) -> io::Result { let inner = Socket::new()?; @@ -93,6 +140,7 @@ impl UnixStream { /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// let sock_copy = socket.try_clone().expect("Couldn't clone socket"); + /// # drop(sock_copy); // Silence unused variable warning. /// ``` pub fn try_clone(&self) -> io::Result { self.0.duplicate().map(UnixStream) @@ -106,6 +154,7 @@ impl UnixStream { /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// let addr = socket.local_addr().expect("Couldn't get local address"); + /// # drop(addr); // Silence unused variable warning. /// ``` pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { @@ -126,6 +175,7 @@ impl UnixStream { /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// let addr = socket.peer_addr().expect("Couldn't get peer address"); + /// # drop(addr); // Silence unused variable warning. /// ``` pub fn peer_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { @@ -201,14 +251,15 @@ impl UnixStream { /// println!("Couldn't create a pair of sockets: {e:?}"); /// return /// } - /// } + /// }; + /// # drop(sock1); // Silence unused variable warning. + /// # drop(sock2); // Silence unused variable warning. /// ``` pub fn pair() -> io::Result<(Self, Self)> { use std::sync::{Arc, RwLock}; use std::thread::spawn; - let dir = tempfile::tempdir()?; - let file_path = dir.path().join("socket"); + let file_path = TempPath::new(10)?; let a: Arc>>> = Arc::new(RwLock::new(None)); let ul = UnixListener::bind(&file_path).unwrap(); let server = { @@ -372,6 +423,7 @@ impl IntoRawSocket for UnixStream { /// /// fn handle_client(stream: UnixStream) { /// // ... +/// # drop(stream); // Silence unused variable warning. /// } /// /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); @@ -385,6 +437,7 @@ impl IntoRawSocket for UnixStream { /// } /// Err(err) => { /// /* connection failed */ +/// eprintln!("connection failed: {err}"); /// break; /// } /// } @@ -418,6 +471,7 @@ impl UnixListener { /// return /// } /// }; + /// # drop(listener); // Silence unused variable warning. /// ``` pub fn bind>(path: P) -> io::Result { let inner = Socket::new()?; @@ -456,7 +510,7 @@ impl UnixListener { /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); /// /// match listener.accept() { - /// Ok((socket, addr)) => println!("Got a client: {:?}", addr), + /// Ok((_socket, addr)) => println!("Got a client: {:?}", addr), /// Err(e) => println!("accept function failed: {:?}", e), /// } /// ``` @@ -495,6 +549,7 @@ impl UnixListener { /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); /// /// let listener_copy = listener.try_clone().expect("Couldn't clone socket"); + /// # drop(listener_copy); // Silence unused variable warning. /// ``` pub fn try_clone(&self) -> io::Result { self.0.duplicate().map(UnixListener) @@ -510,6 +565,7 @@ impl UnixListener { /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); /// /// let addr = listener.local_addr().expect("Couldn't get local address"); + /// # drop(addr); // Silence unused variable warning. /// ``` pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { @@ -568,6 +624,7 @@ impl UnixListener { /// /// fn handle_client(stream: UnixStream) { /// // ... + /// # drop(stream); // Silence unused variable warning. /// } /// /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); @@ -578,6 +635,7 @@ impl UnixListener { /// thread::spawn(|| handle_client(stream)); /// } /// Err(err) => { + /// eprintln!("connection failed: {err}"); /// break; /// } /// } @@ -631,6 +689,7 @@ impl<'a> IntoIterator for &'a UnixListener { /// /// fn handle_client(stream: UnixStream) { /// // ... +/// # drop(stream); // Silence unused variable warning. /// } /// /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); @@ -641,6 +700,7 @@ impl<'a> IntoIterator for &'a UnixListener { /// thread::spawn(|| handle_client(stream)); /// } /// Err(err) => { +/// eprintln!("connection failed: {err}"); /// break; /// } /// } @@ -665,14 +725,9 @@ impl<'a> Iterator for Incoming<'a> { #[cfg(test)] mod test { - use tempfile; - use std::io::{self, Read, Write}; - use std::path::PathBuf; use std::thread; - use self::tempfile::TempDir; - use super::*; macro_rules! or_panic { @@ -684,15 +739,9 @@ mod test { }; } - fn tmpdir() -> Result<(TempDir, PathBuf), io::Error> { - let dir = tempfile::tempdir()?; - let path = dir.path().join("sock"); - Ok((dir, path)) - } - #[test] fn basic() { - let (_dir, socket_path) = or_panic!(tmpdir()); + let socket_path = TempPath::new(10).unwrap(); let msg1 = b"hello"; let msg2 = b"world!"; @@ -721,7 +770,7 @@ mod test { #[test] fn try_clone() { - let (_dir, socket_path) = or_panic!(tmpdir()); + let socket_path = TempPath::new(10).unwrap(); let msg1 = b"hello"; let msg2 = b"world"; @@ -751,7 +800,7 @@ mod test { #[test] fn iter() { - let (_dir, socket_path) = or_panic!(tmpdir()); + let socket_path = TempPath::new(10).unwrap(); let listener = or_panic!(UnixListener::bind(&socket_path)); let thread = thread::spawn(move || { @@ -772,11 +821,7 @@ mod test { #[test] fn long_path() { - let dir = or_panic!(tempfile::tempdir()); - let socket_path = dir.path().join( - "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfa\ - sasdfasdfasdasdfasdfasdfadfasdfasdfasdfasdfasdf", - ); + let socket_path = TempPath::new(100).unwrap(); match UnixStream::connect(&socket_path) { Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} Err(e) => panic!("unexpected error {}", e), From b6bae73d975f5e9f6e657d7a0b4193c1bc4497b6 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Sat, 20 Aug 2022 16:42:58 -0700 Subject: [PATCH 13/34] revert change in draining behavior --- src/sys/windows/mod.rs | 14 +++++++++----- tests/unix_listener.rs | 1 + tests/unix_stream.rs | 2 ++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index fb09926a2..0817ca4fa 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -102,11 +102,15 @@ cfg_io_source! { F: FnOnce(&T) -> io::Result, { let result = f(io); - self.inner.as_ref().map_or(Ok(()), |state| { - state - .selector - .reregister(state.sock_state.clone(), state.token, state.interests) - })?; + if let Err(ref e) = result { + if e.kind() == io::ErrorKind::WouldBlock { + self.inner.as_ref().map_or(Ok(()), |state| { + state + .selector + .reregister(state.sock_state.clone(), state.token, state.interests) + })?; + } + } result } diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index d7ddf11d5..b2de8c9eb 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -222,6 +222,7 @@ fn unix_listener_multiple_accepts() { let (mut stream1, _) = listener.accept().unwrap(); assert_would_block(stream1.read(&mut buf)); + assert_would_block(listener.accept()); barrier.wait(); // Second connection is opened, try to accept and read. diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index a29376f96..205a91a98 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -454,6 +454,8 @@ where assert!(stream.take_error().unwrap().is_none()); + assert_would_block(stream.read(&mut buf)); + let bufs = [IoSlice::new(DATA1), IoSlice::new(DATA2)]; let wrote = stream.write_vectored(&bufs).unwrap(); assert_eq!(wrote, DATA1_LEN + DATA2_LEN); From 2113b9fc8c74a34fedb032c78ecfaf4b6d9e22a1 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Sat, 20 Aug 2022 19:38:32 -0700 Subject: [PATCH 14/34] re-organize stdnet files to mirror std::os::unix::net --- src/net/mod.rs | 9 +- src/sys/mod.rs | 2 +- src/sys/windows/mod.rs | 5 +- src/sys/windows/uds/listener.rs | 3 +- src/sys/windows/uds/mod.rs | 6 +- src/sys/windows/uds/stdnet/addr.rs | 158 +++++ src/sys/windows/uds/stdnet/listener.rs | 318 ++++++++++ src/sys/windows/uds/stdnet/mod.rs | 186 +----- src/sys/windows/uds/stdnet/net.rs | 842 ------------------------- src/sys/windows/uds/stdnet/socket.rs | 15 +- src/sys/windows/uds/stdnet/stream.rs | 411 ++++++++++++ src/sys/windows/uds/stream.rs | 7 +- 12 files changed, 913 insertions(+), 1049 deletions(-) create mode 100644 src/sys/windows/uds/stdnet/addr.rs create mode 100644 src/sys/windows/uds/stdnet/listener.rs delete mode 100644 src/sys/windows/uds/stdnet/net.rs create mode 100644 src/sys/windows/uds/stdnet/stream.rs diff --git a/src/net/mod.rs b/src/net/mod.rs index 51b47b9d2..a985bf64b 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -32,12 +32,13 @@ pub use self::tcp::{TcpListener, TcpStream}; mod udp; #[cfg(not(target_os = "wasi"))] pub use self::udp::UdpSocket; - -#[cfg(any(unix, windows))] +#[cfg(not(target_os = "wasi"))] mod uds; -#[cfg(any(unix, windows))] +#[cfg(not(target_os = "wasi"))] pub use self::uds::{SocketAddr, UnixListener, UnixStream}; + #[cfg(unix)] pub use self::uds::UnixDatagram; + #[cfg(windows)] -pub use self::uds::stdnet; +pub use crate::sys::uds::stdnet; diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 1c5e3ae84..872529c4d 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -59,7 +59,7 @@ cfg_os_poll! { #[cfg(windows)] cfg_os_poll! { - pub mod windows; + mod windows; pub use self::windows::*; } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 0817ca4fa..ea4f4a409 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -33,11 +33,10 @@ cfg_net! { macro_rules! wsa_syscall { ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { $fn($($arg, )*) }; + let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; if $err_test(&res, &$err_value) { - use windows_sys::Win32::Networking::WinSock::WSAGetLastError; Err(io::Error::from_raw_os_error(unsafe { - WSAGetLastError() + windows_sys::Win32::Networking::WinSock::WSAGetLastError() })) } else { Ok(res) diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index f3aff5b6a..4bd5a9f8e 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -13,8 +13,7 @@ pub(crate) fn bind(path: &Path) -> io::Result { pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { listener.set_nonblocking(true)?; - let es = listener.accept().map(|(stream, addr)| (UnixStream::from_std(stream), addr)); - es + listener.accept().map(|(stream, addr)| (UnixStream::from_std(stream), addr)) } pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs index 421b2bb2c..de48e24b1 100644 --- a/src/sys/windows/uds/mod.rs +++ b/src/sys/windows/uds/mod.rs @@ -3,11 +3,7 @@ pub use self::stdnet::SocketAddr; cfg_os_poll! { use std::convert::TryInto; - use windows_sys::Win32::Networking::WinSock::{ - getsockname, - getpeername, - SOCKET_ERROR - }; + use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; use std::os::windows::io::RawSocket; use std::io; diff --git a/src/sys/windows/uds/stdnet/addr.rs b/src/sys/windows/uds/stdnet/addr.rs new file mode 100644 index 000000000..2bc228dde --- /dev/null +++ b/src/sys/windows/uds/stdnet/addr.rs @@ -0,0 +1,158 @@ + +//! Windows specific networking functionality + +use std::ascii; +use std::fmt; +use std::io; +use std::mem; +use std::os::raw::c_int; +use std::path::Path; + +use windows_sys::Win32::Networking::WinSock::{sockaddr_un, AF_UNIX, SOCKADDR}; + +pub(super) fn path_offset(addr: &sockaddr_un) -> usize { + // Work with an actual instance of the type since using a null pointer is UB + let base = addr as *const _ as usize; + let path = &addr.sun_path as *const _ as usize; + path - base +} + +pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + sockaddr.sun_family = AF_UNIX; + + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or(io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ))?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } + + if bytes.len() >= sockaddr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { + *dst = *src as u8; + } + // null byte for pathname addresses is already there because we zeroed the + // struct + + let offset = path_offset(&sockaddr); + let mut socklen = offset + bytes.len(); + + match bytes.get(0) { + // The struct has already been zeroes so the null byte for pathname + // addresses is already there. + Some(&0) | None => {} + Some(_) => socklen += 1, + } + + Ok((sockaddr, socklen as c_int)) +} + +enum AddressKind<'a> { + Unnamed, + Pathname(&'a Path), + // Note: Windows does not support Abstract addresses + // https://github.com/microsoft/WSL/issues/4240#issuecomment-620805115/ + Abstract(&'a [u8]), +} + +struct AsciiEscaped<'a>(&'a [u8]); + +impl<'a> fmt::Display for AsciiEscaped<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "\"")?; + for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { + write!(fmt, "{}", byte as char)?; + } + write!(fmt, "\"") + } +} + +/// An address associated with a Unix socket +#[derive(Copy, Clone)] +pub struct SocketAddr { + addr: sockaddr_un, + len: c_int, +} + +impl SocketAddr { + pub(crate) fn new(f: F) -> io::Result + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, + { + let mut sockaddr = { + let sockaddr = mem::MaybeUninit::::zeroed(); + unsafe { sockaddr.assume_init() } + }; + + let mut len = mem::size_of::() as c_int; + f(&mut sockaddr as *mut _ as *mut _, &mut len)?; + Ok(SocketAddr::from_parts(sockaddr, len)) + } + + pub(crate) fn from_parts(addr: sockaddr_un, mut len: c_int) -> SocketAddr { + if len == 0 { + // When there is a datagram from unnamed unix socket + // linux returns zero bytes of address + len = path_offset(&addr) as c_int; // i.e. zero-length address + } + SocketAddr { addr, len } + } + + /// Returns true if and only if the address is unnamed. + pub fn is_unnamed(&self) -> bool { + matches!(self.address(), AddressKind::Unnamed) + } + + /// Returns the contents of this address if it is a `pathname` address. + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { Some(path) } else { None } + } + + fn address(&self) -> AddressKind<'_> { + let len = self.len as usize - path_offset(&self.addr); + // sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path + + // macOS seems to return a len of 16 and a zeroed sun_path for unnamed addresses + if len == 0 { + AddressKind::Unnamed + } else if self.addr.sun_path[0] == 0 { + AddressKind::Abstract(&self.addr.sun_path[1..len]) + } else { + use std::ffi::CStr; + let pathname = unsafe { CStr::from_bytes_with_nul_unchecked(&self.addr.sun_path[..len]) }; + AddressKind::Pathname(Path::new(pathname.to_str().unwrap())) + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.address() { + AddressKind::Unnamed => write!(fmt, "(unnamed)"), + AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), + AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), + } + } +} diff --git a/src/sys/windows/uds/stdnet/listener.rs b/src/sys/windows/uds/stdnet/listener.rs new file mode 100644 index 000000000..199661887 --- /dev/null +++ b/src/sys/windows/uds/stdnet/listener.rs @@ -0,0 +1,318 @@ +use std::{mem, io, fmt}; +use std::os::raw::c_int; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::path::Path; + +use windows_sys::Win32::Networking::WinSock::{sockaddr_un, AF_UNIX, SOCKET_ERROR}; + +use super::{socket_addr, SocketAddr, socket::Socket, UnixStream}; + +/// A Unix domain socket server +/// +/// # Examples +/// +/// ```no_run +/// use std::thread; +/// use mio::net::stdnet::{UnixStream, UnixListener}; +/// +/// fn handle_client(stream: UnixStream) { +/// // ... +/// # drop(stream); // Silence unused variable warning. +/// } +/// +/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); +/// +/// // accept connections and process them, spawning a new thread for each one +/// for stream in listener.incoming() { +/// match stream { +/// Ok(stream) => { +/// /* connection succeeded */ +/// thread::spawn(|| handle_client(stream)); +/// } +/// Err(err) => { +/// /* connection failed */ +/// eprintln!("connection failed: {err}"); +/// break; +/// } +/// } +/// } +/// ``` +pub struct UnixListener(Socket); + +impl fmt::Debug for UnixListener { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixListener"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + builder.finish() + } +} + +impl UnixListener { + /// Creates a new `UnixListener` bound to the specified socket. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = match UnixListener::bind("/path/to/the/socket") { + /// Ok(sock) => sock, + /// Err(e) => { + /// println!("Couldn't connect: {:?}", e); + /// return + /// } + /// }; + /// # drop(listener); // Silence unused variable warning. + /// ``` + pub fn bind>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + wsa_syscall!( + bind( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as _, + ), + PartialEq::eq, + SOCKET_ERROR + )?; + wsa_syscall!( + listen(inner.as_raw_socket() as _, 128), + PartialEq::eq, + SOCKET_ERROR + )?; + Ok(UnixListener(inner)) + } + + /// Accepts a new incoming connection to this listener. + /// + /// This function will block the calling thread until a new Unix connection + /// is established. When established, the corresponding [`UnixStream`] and + /// the remote peer's address will be returned. + /// + /// [`UnixStream`]: struct.UnixStream.html + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// match listener.accept() { + /// Ok((_socket, addr)) => println!("Got a client: {:?}", addr), + /// Err(e) => println!("accept function failed: {:?}", e), + /// } + /// ``` + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + + sockaddr.sun_family = AF_UNIX; + let mut socklen = mem::size_of_val(&sockaddr) as c_int; + + let sock = self.0.accept(&mut sockaddr as *mut _ as *mut _, &mut socklen)?; + let addr = SocketAddr::from_parts(sockaddr, socklen); + Ok((UnixStream(sock), addr)) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixListener` is a reference to the same socket that this + /// object references. Both handles can be used to accept incoming + /// connections and options set on one listener will affect the other. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// let listener_copy = listener.try_clone().expect("Couldn't clone socket"); + /// # drop(listener_copy); // Silence unused variable warning. + /// ``` + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixListener) + } + + /// Returns the local socket address of this listener. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// let addr = listener.local_addr().expect("Couldn't get local address"); + /// # drop(addr); // Silence unused variable warning. + /// ``` + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + PartialEq::eq, + SOCKET_ERROR + ) + }) + } + + /// Moves the socket into or out of nonblocking mode. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// listener.set_nonblocking(true).expect("Couldn't set nonblocking"); + /// ``` + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixListener; + /// + /// let listener = UnixListener::bind("/tmp/sock").unwrap(); + /// + /// if let Ok(Some(err)) = listener.take_error() { + /// println!("Got error: {:?}", err); + /// } + /// ``` + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Returns an iterator over incoming connections. + /// + /// The iterator will never return `None` and will also not yield the + /// peer's [`SocketAddr`] structure. + /// + /// [`SocketAddr`]: struct.SocketAddr.html + /// + /// # Examples + /// + /// ```no_run + /// use std::thread; + /// use mio::net::stdnet::{UnixStream, UnixListener}; + /// + /// fn handle_client(stream: UnixStream) { + /// // ... + /// # drop(stream); // Silence unused variable warning. + /// } + /// + /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); + /// + /// for stream in listener.incoming() { + /// match stream { + /// Ok(stream) => { + /// thread::spawn(|| handle_client(stream)); + /// } + /// Err(err) => { + /// eprintln!("connection failed: {err}"); + /// break; + /// } + /// } + /// } + /// ``` + pub fn incoming<'a>(&'a self) -> Incoming<'a> { + Incoming { listener: self } + } +} + +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} + +impl<'a> IntoIterator for &'a UnixListener { + type Item = io::Result; + type IntoIter = Incoming<'a>; + + fn into_iter(self) -> Incoming<'a> { + self.incoming() + } +} + +/// An iterator over incoming connections to a [`UnixListener`]. +/// +/// It will never return `None`. +/// +/// [`UnixListener`]: struct.UnixListener.html +/// +/// # Examples +/// +/// ```no_run +/// use std::thread; +/// use mio::net::stdnet::{UnixStream, UnixListener}; +/// +/// fn handle_client(stream: UnixStream) { +/// // ... +/// # drop(stream); // Silence unused variable warning. +/// } +/// +/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); +/// +/// for stream in listener.incoming() { +/// match stream { +/// Ok(stream) => { +/// thread::spawn(|| handle_client(stream)); +/// } +/// Err(err) => { +/// eprintln!("connection failed: {err}"); +/// break; +/// } +/// } +/// } +/// ``` +#[derive(Debug)] +pub struct Incoming<'a> { + listener: &'a UnixListener, +} + +impl<'a> Iterator for Incoming<'a> { + type Item = io::Result; + + fn next(&mut self) -> Option> { + Some(self.listener.accept().map(|s| s.0)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::max_value(), None) + } +} diff --git a/src/sys/windows/uds/stdnet/mod.rs b/src/sys/windows/uds/stdnet/mod.rs index 70c5eeda6..0b2edbbf3 100644 --- a/src/sys/windows/uds/stdnet/mod.rs +++ b/src/sys/windows/uds/stdnet/mod.rs @@ -1,182 +1,10 @@ -//! Windows specific networking functionality +//! Windows specific networking functionality. Mirrors std::os::unix::net. -use std::ascii; -use std::fmt; -use std::io; -use std::mem; -use std::os::raw::c_int; -use std::path::Path; - -use windows_sys::Win32::Networking::WinSock::{self, SOCKADDR}; - -mod net; +mod addr; mod socket; +mod stream; +mod listener; -pub(crate) fn path_offset(addr: &WinSock::sockaddr_un) -> usize { - // Work with an actual instance of the type since using a null pointer is UB - let base = addr as *const _ as usize; - let path = &addr.sun_path as *const _ as usize; - path - base -} - -pub(crate) fn socket_addr(path: &Path) -> io::Result<(WinSock::sockaddr_un, c_int)> { - let sockaddr = mem::MaybeUninit::::zeroed(); - - // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` - // bytes is properly initialized. - // - // `0` is a valid value for `sockaddr_un::sun_family`; it is - // `WinSock::AF_UNSPEC`. - // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an - // abstract path. - let mut sockaddr = unsafe { sockaddr.assume_init() }; - sockaddr.sun_family = WinSock::AF_UNIX; - - // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path - let bytes = path.to_str().map(|s| s.as_bytes()).ok_or(io::Error::new( - io::ErrorKind::InvalidInput, - "path contains invalid characters", - ))?; - - if bytes.contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "paths may not contain interior null bytes", - )); - } - - if bytes.len() >= sockaddr.sun_path.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "path must be shorter than SUN_LEN", - )); - } - for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { - *dst = *src as u8; - } - // null byte for pathname addresses is already there because we zeroed the - // struct - - let offset = path_offset(&sockaddr); - let mut socklen = offset + bytes.len(); - - match bytes.get(0) { - // The struct has already been zeroes so the null byte for pathname - // addresses is already there. - Some(&0) | None => {} - Some(_) => socklen += 1, - } - - Ok((sockaddr, socklen as c_int)) -} - -enum AddressKind<'a> { - Unnamed, - Pathname(&'a Path), - Abstract(&'a [u8]), -} - -/// An address associated with a Unix socket -#[derive(Copy, Clone)] -pub struct SocketAddr { - addr: WinSock::sockaddr_un, - len: c_int, -} - -impl SocketAddr { - pub(crate) fn new(f: F) -> io::Result - where - F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, - { - let mut sockaddr = { - let sockaddr = mem::MaybeUninit::::zeroed(); - unsafe { sockaddr.assume_init() } - }; - - let mut len = mem::size_of::() as c_int; - f(&mut sockaddr as *mut _ as *mut _, &mut len)?; - Ok(SocketAddr::from_parts(sockaddr, len)) - } - - pub(crate) fn from_parts(addr: WinSock::sockaddr_un, mut len: c_int) -> SocketAddr { - if len == 0 { - // When there is a datagram from unnamed unix socket - // linux returns zero bytes of address - len = path_offset(&addr) as c_int; // i.e. zero-length address - } - SocketAddr { addr, len } - } - - /// Returns true if and only if the address is unnamed. - pub fn is_unnamed(&self) -> bool { - if let AddressKind::Unnamed = self.address() { - true - } else { - false - } - } - - /// Returns the contents of this address if it is a `pathname` address. - pub fn as_pathname(&self) -> Option<&Path> { - if let AddressKind::Pathname(path) = self.address() { - Some(path) - } else { - None - } - } - - fn address<'a>(&'a self) -> AddressKind<'a> { - let len = self.len as usize - path_offset(&self.addr); - // WinSock::sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path - - // macOS seems to return a len of 16 and a zeroed sun_path for unnamed addresses - if len == 0 - || (cfg!(not(any(target_os = "linux", target_os = "android"))) - && self.addr.sun_path[0] == 0) - { - AddressKind::Unnamed - } else if self.addr.sun_path[0] == 0 { - AddressKind::Abstract(&self.addr.sun_path[1..len]) - } else { - use std::ffi::CStr; - let pathname = unsafe { CStr::from_bytes_with_nul_unchecked(&self.addr.sun_path[..len]) }; - AddressKind::Pathname(Path::new(pathname.to_str().unwrap())) - } - } -} - -impl fmt::Debug for SocketAddr { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.address() { - AddressKind::Unnamed => write!(fmt, "(unnamed)"), - AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), - AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), - } - } -} - -impl PartialEq for SocketAddr { - fn eq(&self, other: &SocketAddr) -> bool { - let ita = self.addr.sun_path.iter(); - let itb = other.addr.sun_path.iter(); - - self.len == other.len - && self.addr.sun_family == other.addr.sun_family - && ita.zip(itb).all(|(a, b)| a == b) - } -} - -struct AsciiEscaped<'a>(&'a [u8]); - -impl<'a> fmt::Display for AsciiEscaped<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "\"")?; - for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { - write!(fmt, "{}", byte as char)?; - } - write!(fmt, "\"") - } -} - -pub use self::net::{UnixListener, UnixStream}; +pub use self::addr::*; +pub use self::listener::*; +pub use self::stream::*; diff --git a/src/sys/windows/uds/stdnet/net.rs b/src/sys/windows/uds/stdnet/net.rs deleted file mode 100644 index 22dda64e3..000000000 --- a/src/sys/windows/uds/stdnet/net.rs +++ /dev/null @@ -1,842 +0,0 @@ -use std::fmt; -use std::io::{self, IoSlice, IoSliceMut}; -use std::convert::TryInto; -use std::mem; -use std::net::Shutdown; -use std::os::raw::c_int; -use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; -use std::path::{Path, PathBuf}; -use std::time::Duration; - -use windows_sys::Win32::Networking::WinSock::{self, bind, connect, getpeername, getsockname, listen, SO_RCVTIMEO, SOCKET_ERROR, SO_SNDTIMEO}; - -use super::socket::Socket; -use super::{socket_addr, SocketAddr}; -use rand::{distributions::Alphanumeric, Rng}; - -struct TempPath(PathBuf); - -impl TempPath { - fn new(random_len: usize) -> io::Result { - let dir = std::env::temp_dir(); - // Retry a few times in case of collisions - for _ in 0..10 { - let rand_str: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(random_len) - .map(char::from) - .collect(); - let filename = format!(".tmp-{rand_str}.socket"); - let path = dir.join(filename); - if !path.exists() { - return Ok(Self(path)); - } - } - - Err(io::Error::new( - io::ErrorKind::AlreadyExists, - "too many temporary files exist", - )) - } -} - -impl Drop for TempPath { - fn drop(&mut self) { - let _ = std::fs::remove_file(&self.0); - } -} - -impl AsRef for TempPath { - fn as_ref(&self) -> &Path { - &self.0 - } -} - -impl std::ops::Deref for TempPath { - type Target = Path; - fn deref(&self) -> &Path { - Path::new(&self.0) - } -} - -/// A Unix stream socket -/// -/// # Examples -/// -/// ```no_run -/// use mio::net::stdnet::UnixStream; -/// use std::io::prelude::*; -/// -/// let mut stream = UnixStream::connect("/path/to/my/socket").unwrap(); -/// stream.write_all(b"hello world").unwrap(); -/// let mut response = String::new(); -/// stream.read_to_string(&mut response).unwrap(); -/// println!("{}", response); -/// ``` -pub struct UnixStream(Socket); - -impl fmt::Debug for UnixStream { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut builder = fmt.debug_struct("UnixStream"); - builder.field("socket", &self.0.as_raw_socket()); - if let Ok(addr) = self.local_addr() { - builder.field("local", &addr); - } - if let Ok(addr) = self.peer_addr() { - builder.field("peer", &addr); - } - builder.finish() - } -} - -impl UnixStream { - /// Connects to the socket named by `path`. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = match UnixStream::connect("/tmp/sock") { - /// Ok(sock) => sock, - /// Err(e) => { - /// println!("Couldn't connect: {:?}", e); - /// return - /// } - /// }; - /// # drop(socket); // Silence unused variable warning. - /// ``` - pub fn connect>(path: P) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path.as_ref())?; - - match wsa_syscall!( - connect( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as i32, - ), - PartialEq::eq, - SOCKET_ERROR - ) { - Ok(_) => {}, - Err(ref err) if err.raw_os_error() == Some(WinSock::WSAEINPROGRESS) => {}, - Err(e) => return Err(e) - } - Ok(UnixStream(inner)) - } - - /// Creates a new independently owned handle to the underlying socket. - /// - /// The returned `UnixStream` is a reference to the same stream that this - /// object references. Both handles will read and write the same stream of - /// data, and options set on one stream will be propagated to the other - /// stream. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// let sock_copy = socket.try_clone().expect("Couldn't clone socket"); - /// # drop(sock_copy); // Silence unused variable warning. - /// ``` - pub fn try_clone(&self) -> io::Result { - self.0.duplicate().map(UnixStream) - } - - /// Returns the socket address of the local half of this connection. - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// let addr = socket.local_addr().expect("Couldn't get local address"); - /// # drop(addr); // Silence unused variable warning. - /// ``` - pub fn local_addr(&self) -> io::Result { - SocketAddr::new(|addr, len| { - wsa_syscall!( - getsockname(self.0.as_raw_socket() as _, addr, len), - PartialEq::eq, - SOCKET_ERROR - ) - }) - } - - /// Returns the socket address of the remote half of this connection. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// let addr = socket.peer_addr().expect("Couldn't get peer address"); - /// # drop(addr); // Silence unused variable warning. - /// ``` - pub fn peer_addr(&self) -> io::Result { - SocketAddr::new(|addr, len| { - wsa_syscall!( - getpeername(self.0.as_raw_socket() as _, addr, len), - PartialEq::eq, - SOCKET_ERROR - ) - }) - } - - /// Moves the socket into or out of nonblocking mode. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_nonblocking(true).expect("Couldn't set nonblocking"); - /// ``` - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) - } - - /// Returns the value of the `SO_ERROR` option. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// if let Ok(Some(err)) = socket.take_error() { - /// println!("Got error: {:?}", err); - /// } - /// ``` - pub fn take_error(&self) -> io::Result> { - self.0.take_error() - } - - /// Shuts down the read, write, or both halves of this connection. - /// - /// This function will cause all pending and future I/O calls on the - /// specified portions to immediately return with an appropriate value - /// (see the documentation for `Shutdown`). - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// use std::net::Shutdown; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.shutdown(Shutdown::Both).expect("shutdown function failed"); - /// ``` - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.0.shutdown(how) - } - - /// Creates an unnamed pair of connected sockets. - /// - /// Returns two `UnixStream`s which are connected to each other. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let (sock1, sock2) = match UnixStream::pair() { - /// Ok((sock1, sock2)) => (sock1, sock2), - /// Err(e) => { - /// println!("Couldn't create a pair of sockets: {e:?}"); - /// return - /// } - /// }; - /// # drop(sock1); // Silence unused variable warning. - /// # drop(sock2); // Silence unused variable warning. - /// ``` - pub fn pair() -> io::Result<(Self, Self)> { - use std::sync::{Arc, RwLock}; - use std::thread::spawn; - - let file_path = TempPath::new(10)?; - let a: Arc>>> = Arc::new(RwLock::new(None)); - let ul = UnixListener::bind(&file_path).unwrap(); - let server = { - let a = a.clone(); - spawn(move || { - let mut store = a.write().unwrap(); - let stream0 = ul.accept().map(|s| s.0); - *store = Some(stream0); - }) - }; - let stream1 = UnixStream::connect(&file_path)?; - server - .join() - .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; - let stream0 = (*(a.write().unwrap())).take().unwrap()?; - return Ok((stream0, stream1)); - } - - /// Sets the read timeout to the timeout specified. - /// - /// If the value specified is `None`, then `read` calls will block - /// indefinitely. An `Err` is returned if the zero `Duration` is - /// passed to this method. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); - /// ``` - pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_timeout(dur, SO_RCVTIMEO.try_into().unwrap()) - } - - /// Sets the write timeout to the timeout specified. - /// - /// If the value specified is `None`, then `write` calls will block - /// indefinitely. An `Err` is returned if the zero `Duration` is - /// passed to this method. - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); - /// ``` - pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_timeout(dur, SO_SNDTIMEO.try_into().unwrap()) - } - - /// Returns the read timeout of this socket. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); - /// assert_eq!(socket.read_timeout().unwrap(), None); - /// ``` - pub fn read_timeout(&self) -> io::Result> { - self.0.timeout(SO_RCVTIMEO.try_into().unwrap()) - } - - /// Returns the write timeout of this socket. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); - /// assert_eq!(socket.write_timeout().unwrap(), None); - /// ``` - pub fn write_timeout(&self) -> io::Result> { - self.0.timeout(SO_SNDTIMEO.try_into().unwrap()) - } -} - -impl io::Read for UnixStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - io::Read::read(&mut &*self, buf) - } - - fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - io::Read::read_vectored(&mut &*self, bufs) - } -} - -impl<'a> io::Read for &'a UnixStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } - - fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.0.read_vectored(bufs) - } -} - -impl io::Write for UnixStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - io::Write::write(&mut &*self, buf) - } - - fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - io::Write::write_vectored(&mut &*self, bufs) - } - - fn flush(&mut self) -> io::Result<()> { - io::Write::flush(&mut &*self) - } -} - -impl<'a> io::Write for &'a UnixStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) - } - - - fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.0.write_vectored(bufs) - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -impl AsRawSocket for UnixStream { - fn as_raw_socket(&self) -> RawSocket { - self.0.as_raw_socket() - } -} - -impl FromRawSocket for UnixStream { - unsafe fn from_raw_socket(sock: RawSocket) -> Self { - UnixStream(Socket::from_raw_socket(sock)) - } -} - -impl IntoRawSocket for UnixStream { - fn into_raw_socket(self) -> RawSocket { - let ret = self.0.as_raw_socket(); - mem::forget(self); - ret - } -} - -/// A Unix domain socket server -/// -/// # Examples -/// -/// ```no_run -/// use std::thread; -/// use mio::net::stdnet::{UnixStream, UnixListener}; -/// -/// fn handle_client(stream: UnixStream) { -/// // ... -/// # drop(stream); // Silence unused variable warning. -/// } -/// -/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); -/// -/// // accept connections and process them, spawning a new thread for each one -/// for stream in listener.incoming() { -/// match stream { -/// Ok(stream) => { -/// /* connection succeeded */ -/// thread::spawn(|| handle_client(stream)); -/// } -/// Err(err) => { -/// /* connection failed */ -/// eprintln!("connection failed: {err}"); -/// break; -/// } -/// } -/// } -/// ``` -pub struct UnixListener(Socket); - -impl fmt::Debug for UnixListener { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut builder = fmt.debug_struct("UnixListener"); - builder.field("socket", &self.0.as_raw_socket()); - if let Ok(addr) = self.local_addr() { - builder.field("local", &addr); - } - builder.finish() - } -} - -impl UnixListener { - /// Creates a new `UnixListener` bound to the specified socket. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixListener; - /// - /// let listener = match UnixListener::bind("/path/to/the/socket") { - /// Ok(sock) => sock, - /// Err(e) => { - /// println!("Couldn't connect: {:?}", e); - /// return - /// } - /// }; - /// # drop(listener); // Silence unused variable warning. - /// ``` - pub fn bind>(path: P) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path.as_ref())?; - - wsa_syscall!( - bind( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as _, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - wsa_syscall!( - listen(inner.as_raw_socket() as _, 128), - PartialEq::eq, - SOCKET_ERROR - )?; - Ok(UnixListener(inner)) - } - - /// Accepts a new incoming connection to this listener. - /// - /// This function will block the calling thread until a new Unix connection - /// is established. When established, the corresponding [`UnixStream`] and - /// the remote peer's address will be returned. - /// - /// [`UnixStream`]: struct.UnixStream.html - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// match listener.accept() { - /// Ok((_socket, addr)) => println!("Got a client: {:?}", addr), - /// Err(e) => println!("accept function failed: {:?}", e), - /// } - /// ``` - pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - let sockaddr = mem::MaybeUninit::::zeroed(); - - // This is safe to assume because a `WinSock::sockaddr_un` filled with `0` - // bytes is properly initialized. - // - // `0` is a valid value for `sockaddr_un::sun_family`; it is - // `WinSock::AF_UNSPEC`. - // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an - // abstract path. - let mut sockaddr = unsafe { sockaddr.assume_init() }; - - sockaddr.sun_family = WinSock::AF_UNIX; - let mut socklen = mem::size_of_val(&sockaddr) as c_int; - - let sock = self.0.accept(&mut sockaddr as *mut _ as *mut _, &mut socklen)?; - let addr = SocketAddr::from_parts(sockaddr, socklen); - Ok((UnixStream(sock), addr)) - } - - /// Creates a new independently owned handle to the underlying socket. - /// - /// The returned `UnixListener` is a reference to the same socket that this - /// object references. Both handles can be used to accept incoming - /// connections and options set on one listener will affect the other. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// let listener_copy = listener.try_clone().expect("Couldn't clone socket"); - /// # drop(listener_copy); // Silence unused variable warning. - /// ``` - pub fn try_clone(&self) -> io::Result { - self.0.duplicate().map(UnixListener) - } - - /// Returns the local socket address of this listener. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// let addr = listener.local_addr().expect("Couldn't get local address"); - /// # drop(addr); // Silence unused variable warning. - /// ``` - pub fn local_addr(&self) -> io::Result { - SocketAddr::new(|addr, len| { - wsa_syscall!( - getsockname(self.0.as_raw_socket() as _, addr, len), - PartialEq::eq, - SOCKET_ERROR - ) - }) - } - - /// Moves the socket into or out of nonblocking mode. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// listener.set_nonblocking(true).expect("Couldn't set nonblocking"); - /// ``` - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) - } - - /// Returns the value of the `SO_ERROR` option. - /// - /// # Examples - /// - /// ```no_run - /// use mio::net::stdnet::UnixListener; - /// - /// let listener = UnixListener::bind("/tmp/sock").unwrap(); - /// - /// if let Ok(Some(err)) = listener.take_error() { - /// println!("Got error: {:?}", err); - /// } - /// ``` - pub fn take_error(&self) -> io::Result> { - self.0.take_error() - } - - /// Returns an iterator over incoming connections. - /// - /// The iterator will never return `None` and will also not yield the - /// peer's [`SocketAddr`] structure. - /// - /// [`SocketAddr`]: struct.SocketAddr.html - /// - /// # Examples - /// - /// ```no_run - /// use std::thread; - /// use mio::net::stdnet::{UnixStream, UnixListener}; - /// - /// fn handle_client(stream: UnixStream) { - /// // ... - /// # drop(stream); // Silence unused variable warning. - /// } - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// for stream in listener.incoming() { - /// match stream { - /// Ok(stream) => { - /// thread::spawn(|| handle_client(stream)); - /// } - /// Err(err) => { - /// eprintln!("connection failed: {err}"); - /// break; - /// } - /// } - /// } - /// ``` - pub fn incoming<'a>(&'a self) -> Incoming<'a> { - Incoming { listener: self } - } -} - -impl AsRawSocket for UnixListener { - fn as_raw_socket(&self) -> RawSocket { - self.0.as_raw_socket() - } -} - -impl FromRawSocket for UnixListener { - unsafe fn from_raw_socket(sock: RawSocket) -> Self { - UnixListener(Socket::from_raw_socket(sock)) - } -} - -impl IntoRawSocket for UnixListener { - fn into_raw_socket(self) -> RawSocket { - let ret = self.0.as_raw_socket(); - mem::forget(self); - ret - } -} - -impl<'a> IntoIterator for &'a UnixListener { - type Item = io::Result; - type IntoIter = Incoming<'a>; - - fn into_iter(self) -> Incoming<'a> { - self.incoming() - } -} - -/// An iterator over incoming connections to a [`UnixListener`]. -/// -/// It will never return `None`. -/// -/// [`UnixListener`]: struct.UnixListener.html -/// -/// # Examples -/// -/// ```no_run -/// use std::thread; -/// use mio::net::stdnet::{UnixStream, UnixListener}; -/// -/// fn handle_client(stream: UnixStream) { -/// // ... -/// # drop(stream); // Silence unused variable warning. -/// } -/// -/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); -/// -/// for stream in listener.incoming() { -/// match stream { -/// Ok(stream) => { -/// thread::spawn(|| handle_client(stream)); -/// } -/// Err(err) => { -/// eprintln!("connection failed: {err}"); -/// break; -/// } -/// } -/// } -/// ``` -#[derive(Debug)] -pub struct Incoming<'a> { - listener: &'a UnixListener, -} - -impl<'a> Iterator for Incoming<'a> { - type Item = io::Result; - - fn next(&mut self) -> Option> { - Some(self.listener.accept().map(|s| s.0)) - } - - fn size_hint(&self) -> (usize, Option) { - (usize::max_value(), None) - } -} - -#[cfg(test)] -mod test { - use std::io::{self, Read, Write}; - use std::thread; - - use super::*; - - macro_rules! or_panic { - ($e:expr) => { - match $e { - Ok(e) => e, - Err(e) => panic!("{}", e), - } - }; - } - - #[test] - fn basic() { - let socket_path = TempPath::new(10).unwrap(); - let msg1 = b"hello"; - let msg2 = b"world!"; - - let listener = or_panic!(UnixListener::bind(&socket_path)); - let thread = thread::spawn(move || { - let mut stream = or_panic!(listener.accept()).0; - let mut buf = [0; 5]; - or_panic!(stream.read(&mut buf)); - assert_eq!(&msg1[..], &buf[..]); - or_panic!(stream.write_all(msg2)); - }); - - let mut stream = or_panic!(UnixStream::connect(&socket_path)); - assert_eq!( - Some(&*socket_path), - stream.peer_addr().unwrap().as_pathname() - ); - or_panic!(stream.write_all(msg1)); - let mut buf = vec![]; - or_panic!(stream.read_to_end(&mut buf)); - assert_eq!(&msg2[..], &buf[..]); - drop(stream); - - thread.join().unwrap(); - } - - #[test] - fn try_clone() { - let socket_path = TempPath::new(10).unwrap(); - let msg1 = b"hello"; - let msg2 = b"world"; - - let listener = or_panic!(UnixListener::bind(&socket_path)); - let thread = thread::spawn(move || { - #[allow(unused_mut)] - let mut stream = or_panic!(listener.accept()).0; - or_panic!(stream.write_all(msg1)); - or_panic!(stream.write_all(msg2)); - }); - - let mut stream = or_panic!(UnixStream::connect(&socket_path)); - let mut stream2 = or_panic!(stream.try_clone()); - assert_eq!( - Some(&*socket_path), - stream2.peer_addr().unwrap().as_pathname() - ); - - let mut buf = [0; 5]; - or_panic!(stream.read(&mut buf)); - assert_eq!(&msg1[..], &buf[..]); - or_panic!(stream2.read(&mut buf)); - assert_eq!(&msg2[..], &buf[..]); - - thread.join().unwrap(); - } - - #[test] - fn iter() { - let socket_path = TempPath::new(10).unwrap(); - - let listener = or_panic!(UnixListener::bind(&socket_path)); - let thread = thread::spawn(move || { - for stream in listener.incoming().take(2) { - let mut stream = or_panic!(stream); - let mut buf = [0]; - or_panic!(stream.read(&mut buf)); - } - }); - - for _ in 0..2 { - let mut stream = or_panic!(UnixStream::connect(&socket_path)); - or_panic!(stream.write_all(&[0])); - } - - thread.join().unwrap(); - } - - #[test] - fn long_path() { - let socket_path = TempPath::new(100).unwrap(); - match UnixStream::connect(&socket_path) { - Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} - Err(e) => panic!("unexpected error {}", e), - Ok(_) => panic!("unexpected success"), - } - - match UnixListener::bind(&socket_path) { - Err(ref e) if e.kind() == io::ErrorKind::InvalidInput => {} - Err(e) => panic!("unexpected error {}", e), - Ok(_) => panic!("unexpected success"), - } - } - - #[test] - fn abstract_namespace_not_allowed() { - assert!(UnixStream::connect("\0asdf").is_err()); - } -} diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/uds/stdnet/socket.rs index 6620e1b26..14f8cb727 100644 --- a/src/sys/windows/uds/stdnet/socket.rs +++ b/src/sys/windows/uds/stdnet/socket.rs @@ -14,7 +14,7 @@ use windows_sys::Win32::Foundation::{ }; use windows_sys::Win32::System::Threading::GetCurrentProcessId; use windows_sys::Win32::System::WindowsProgramming::INFINITE; -use windows_sys::Win32::Networking::WinSock::{self, INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, WSADuplicateSocketW, WSAPROTOCOL_INFOW, WSASocketW, accept, closesocket, getsockopt as c_getsockopt, ioctlsocket, recv, send, setsockopt as c_setsockopt, shutdown}; +use windows_sys::Win32::Networking::WinSock::{self, INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, closesocket}; use crate::sys::windows::net::init; @@ -51,7 +51,7 @@ impl Socket { } pub fn duplicate(&self) -> io::Result { - let mut info: WSAPROTOCOL_INFOW = unsafe { mem::zeroed() }; + let mut info: WinSock::WSAPROTOCOL_INFOW = unsafe { mem::zeroed() }; wsa_syscall!( WSADuplicateSocketW( self.0, @@ -205,14 +205,13 @@ impl Socket { } } -pub fn setsockopt(sock: &Socket, opt: c_int, val: c_int, payload: T) -> io::Result<()> { - let payload = &payload as *const T as *const _; +fn setsockopt(sock: &Socket, opt: c_int, val: c_int, payload: T) -> io::Result<()> { wsa_syscall!( - c_setsockopt( + setsockopt( sock.as_raw_socket() as usize, opt, val, - payload, + &payload as *const T as *const _, mem::size_of::() as i32, ), PartialEq::eq, @@ -221,11 +220,11 @@ pub fn setsockopt(sock: &Socket, opt: c_int, val: c_int, payload: T) -> io::R Ok(()) } -pub fn getsockopt(sock: &Socket, opt: c_int, val: c_int) -> io::Result { +fn getsockopt(sock: &Socket, opt: c_int, val: c_int) -> io::Result { let mut slot: T = unsafe { mem::zeroed() }; let mut len = mem::size_of::() as i32; wsa_syscall!( - c_getsockopt( + getsockopt( sock.as_raw_socket() as _, opt, val, diff --git a/src/sys/windows/uds/stdnet/stream.rs b/src/sys/windows/uds/stdnet/stream.rs new file mode 100644 index 000000000..4c9d8efeb --- /dev/null +++ b/src/sys/windows/uds/stdnet/stream.rs @@ -0,0 +1,411 @@ +use std::{fmt, mem}; +use std::io::{self, IoSlice, IoSliceMut}; +use std::convert::TryInto; +use std::net::Shutdown; +use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use windows_sys::Win32::Networking::WinSock::{WSAEINPROGRESS, SO_RCVTIMEO, SOCKET_ERROR, SO_SNDTIMEO}; + +use super::{socket_addr, SocketAddr, socket::Socket, UnixListener}; +use rand::{distributions::Alphanumeric, Rng}; + +/// A Unix stream socket +/// +/// # Examples +/// +/// ```no_run +/// use mio::net::stdnet::UnixStream; +/// use std::io::prelude::*; +/// +/// let mut stream = UnixStream::connect("/path/to/my/socket").unwrap(); +/// stream.write_all(b"hello world").unwrap(); +/// let mut response = String::new(); +/// stream.read_to_string(&mut response).unwrap(); +/// println!("{}", response); +/// ``` +pub struct UnixStream(pub(super) Socket); + +impl fmt::Debug for UnixStream { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixStream"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} + +impl UnixStream { + /// Connects to the socket named by `path`. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = match UnixStream::connect("/tmp/sock") { + /// Ok(sock) => sock, + /// Err(e) => { + /// println!("Couldn't connect: {:?}", e); + /// return + /// } + /// }; + /// # drop(socket); // Silence unused variable warning. + /// ``` + pub fn connect>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + PartialEq::eq, + SOCKET_ERROR + ) { + Ok(_) => {}, + Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {}, + Err(e) => return Err(e) + } + Ok(UnixStream(inner)) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixStream` is a reference to the same stream that this + /// object references. Both handles will read and write the same stream of + /// data, and options set on one stream will be propagated to the other + /// stream. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// let sock_copy = socket.try_clone().expect("Couldn't clone socket"); + /// # drop(sock_copy); // Silence unused variable warning. + /// ``` + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixStream) + } + + /// Returns the socket address of the local half of this connection. + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// let addr = socket.local_addr().expect("Couldn't get local address"); + /// # drop(addr); // Silence unused variable warning. + /// ``` + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getsockname(self.0.as_raw_socket() as _, addr, len), + PartialEq::eq, + SOCKET_ERROR + ) + }) + } + + /// Returns the socket address of the remote half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// let addr = socket.peer_addr().expect("Couldn't get peer address"); + /// # drop(addr); // Silence unused variable warning. + /// ``` + pub fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| { + wsa_syscall!( + getpeername(self.0.as_raw_socket() as _, addr, len), + PartialEq::eq, + SOCKET_ERROR + ) + }) + } + + /// Moves the socket into or out of nonblocking mode. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_nonblocking(true).expect("Couldn't set nonblocking"); + /// ``` + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// if let Ok(Some(err)) = socket.take_error() { + /// println!("Got error: {:?}", err); + /// } + /// ``` + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation for `Shutdown`). + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// use std::net::Shutdown; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.shutdown(Shutdown::Both).expect("shutdown function failed"); + /// ``` + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } + + /// Creates an unnamed pair of connected sockets. + /// + /// Returns two `UnixStream`s which are connected to each other. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let (sock1, sock2) = match UnixStream::pair() { + /// Ok((sock1, sock2)) => (sock1, sock2), + /// Err(e) => { + /// println!("Couldn't create a pair of sockets: {e:?}"); + /// return + /// } + /// }; + /// # drop(sock1); // Silence unused variable warning. + /// # drop(sock2); // Silence unused variable warning. + /// ``` + pub fn pair() -> io::Result<(Self, Self)> { + use std::sync::{Arc, RwLock}; + use std::thread::spawn; + + let file_path = TempPath::new(10)?; + let a: Arc>>> = Arc::new(RwLock::new(None)); + let ul = UnixListener::bind(&file_path).unwrap(); + let server = { + let a = a.clone(); + spawn(move || { + let mut store = a.write().unwrap(); + let stream0 = ul.accept().map(|s| s.0); + *store = Some(stream0); + }) + }; + let stream1 = UnixStream::connect(&file_path)?; + server + .join() + .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; + let stream0 = (*(a.write().unwrap())).take().unwrap()?; + return Ok((stream0, stream1)); + } + + /// Sets the read timeout to the timeout specified. + /// + /// If the value specified is `None`, then `read` calls will block + /// indefinitely. An `Err` is returned if the zero `Duration` is + /// passed to this method. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); + /// ``` + pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_RCVTIMEO.try_into().unwrap()) + } + + /// Sets the write timeout to the timeout specified. + /// + /// If the value specified is `None`, then `write` calls will block + /// indefinitely. An `Err` is returned if the zero `Duration` is + /// passed to this method. + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); + /// ``` + pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_SNDTIMEO.try_into().unwrap()) + } + + /// Returns the read timeout of this socket. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); + /// assert_eq!(socket.read_timeout().unwrap(), None); + /// ``` + pub fn read_timeout(&self) -> io::Result> { + self.0.timeout(SO_RCVTIMEO.try_into().unwrap()) + } + + /// Returns the write timeout of this socket. + /// + /// # Examples + /// + /// ```no_run + /// use mio::net::stdnet::UnixStream; + /// + /// let socket = UnixStream::connect("/tmp/sock").unwrap(); + /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); + /// assert_eq!(socket.write_timeout().unwrap(), None); + /// ``` + pub fn write_timeout(&self) -> io::Result> { + self.0.timeout(SO_SNDTIMEO.try_into().unwrap()) + } +} + +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + io::Read::read_vectored(&mut &*self, bufs) + } +} + +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + + fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + self.0.read_vectored(bufs) + } +} + +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + io::Write::write_vectored(&mut &*self, bufs) + } + + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} + +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl AsRawSocket for UnixStream { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixStream { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixStream(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} + +struct TempPath(PathBuf); + +impl TempPath { + fn new(random_len: usize) -> io::Result { + let dir = std::env::temp_dir(); + // Retry a few times in case of collisions + for _ in 0..10 { + let rand_str: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(random_len) + .map(char::from) + .collect(); + let filename = format!(".tmp-{rand_str}.socket"); + let path = dir.join(filename); + if !path.exists() { + return Ok(Self(path)); + } + } + + Err(io::Error::new( + io::ErrorKind::AlreadyExists, + "too many temporary files exist", + )) + } +} + +impl Drop for TempPath { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.0); + } +} + +impl AsRef for TempPath { + fn as_ref(&self) -> &Path { + &self.0 + } +} + +impl std::ops::Deref for TempPath { + type Target = Path; + fn deref(&self) -> &Path { + Path::new(&self.0) + } +} diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index c70bc1375..4df1aa396 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -2,11 +2,8 @@ use std::io; use std::os::windows::io::{AsRawSocket}; use std::path::Path; use super::stdnet::{self as net}; -use crate::net::SocketAddr; -use crate::sys::windows::net::init; pub(crate) fn connect(path: &Path) -> io::Result { - init(); let socket = net::UnixStream::connect(path)?; socket.set_nonblocking(true)?; Ok(socket) @@ -19,10 +16,10 @@ pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { Ok((stream0, stream1)) } -pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { +pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { super::local_addr(socket.as_raw_socket()) } -pub(crate) fn peer_addr(socket: &net::UnixStream) -> io::Result { +pub(crate) fn peer_addr(socket: &net::UnixStream) -> io::Result { super::peer_addr(socket.as_raw_socket()) } From 86c4c9ab608738fe5ef25d7969eb5513b1fd8d30 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Sat, 20 Aug 2022 20:12:30 -0700 Subject: [PATCH 15/34] use single syscall vectored approach from rust-lang/socket2 --- src/sys/windows/uds/stdnet/socket.rs | 106 +++++++++++++++++++-------- src/sys/windows/uds/stdnet/stream.rs | 8 +- 2 files changed, 78 insertions(+), 36 deletions(-) diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/uds/stdnet/socket.rs index 14f8cb727..188c3a035 100644 --- a/src/sys/windows/uds/stdnet/socket.rs +++ b/src/sys/windows/uds/stdnet/socket.rs @@ -1,4 +1,5 @@ use std::io::{self, IoSlice, IoSliceMut}; +use std::cmp::min; use std::convert::TryInto; use std::mem; use std::net::Shutdown; @@ -14,10 +15,24 @@ use windows_sys::Win32::Foundation::{ }; use windows_sys::Win32::System::Threading::GetCurrentProcessId; use windows_sys::Win32::System::WindowsProgramming::INFINITE; -use windows_sys::Win32::Networking::WinSock::{self, INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, closesocket}; +use windows_sys::Win32::Networking::WinSock::{ + self, + WSABUF, + INVALID_SOCKET, + SOCKADDR, + SOCKET, + SOCKET_ERROR, + SOL_SOCKET, + SO_ERROR, + closesocket, + WSAESHUTDOWN +}; use crate::sys::windows::net::init; +/// Maximum size of a buffer passed to system call like `recv` and `send`. +const MAX_BUF_LEN: usize = c_int::MAX as usize; + #[derive(Debug)] pub struct Socket(SOCKET); @@ -76,13 +91,13 @@ impl Socket { Ok(Socket(socket)) } - fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result { + pub fn recv(&self, buf: &mut [u8]) -> io::Result { let ret = wsa_syscall!( recv( self.0, buf.as_mut_ptr() as *mut _, buf.len() as c_int, - flags, + 0, ), PartialEq::eq, SOCKET_ERROR @@ -90,44 +105,71 @@ impl Socket { Ok(ret as usize) } - pub fn read(&self, buf: &mut [u8]) -> io::Result { - self.recv_with_flags(buf, 0) - } - - pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { + pub fn recv_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { let mut total = 0; - for slice in &mut *bufs { - let wsa_buf = unsafe { *(slice as *const _ as *const WinSock::WSABUF) }; - let len = wsa_buf.len; - let buf = unsafe { std::slice::from_raw_parts_mut(wsa_buf.buf, len.try_into().unwrap()) }; - total += self.recv_with_flags(buf, 0)?; + let mut flags: u32 = 0; + let bufs = unsafe { &mut *(bufs as *mut [IoSliceMut<'_>] as *mut [WSABUF]) }; + let res = wsa_syscall!( + WSARecv( + self.0, + bufs.as_mut_ptr().cast(), + min(bufs.len(), u32::MAX as usize) as u32, + &mut total, + &mut flags, + ptr::null_mut(), + None, + ), + PartialEq::eq, + SOCKET_ERROR + ); + match res { + Ok(_) => Ok(total as usize), + Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0), + Err(err) => Err(err), } - Ok(total as usize) } - pub fn write(&self, buf: &[u8]) -> io::Result { - let ret = wsa_syscall!( - send(self.0, buf as *const _ as *const _, buf.len() as c_int, 0), + pub fn send(&self, buf: &[u8]) -> io::Result { + wsa_syscall!( + send( + self.0, + buf.as_ptr().cast(), + min(buf.len(), MAX_BUF_LEN) as c_int, + 0, + ), PartialEq::eq, SOCKET_ERROR - )?; - Ok(ret as usize) + ) + .map(|n| n as usize) } - pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { + pub fn send_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { let mut total = 0; - for slice in bufs { - let wsa_buf = unsafe { *(slice as *const _ as *const WinSock::WSABUF) }; - let len = wsa_buf.len; - let buf = unsafe { std::slice::from_raw_parts(wsa_buf.buf, len.try_into().unwrap()) }; - let ret = wsa_syscall!( - send(self.0, buf as *const _ as *const _, len as c_int, 0), - PartialEq::eq, - SOCKET_ERROR - )?; - total += ret; - } - Ok(total as usize) + wsa_syscall!( + WSASend( + self.0, + // FIXME: From the `WSASend` docs [1]: + // > For a Winsock application, once the WSASend function is called, + // > the system owns these buffers and the application may not + // > access them. + // + // So what we're doing is actually UB as `bufs` needs to be `&mut + // [IoSlice<'_>]`. + // + // See: https://github.com/rust-lang/socket2-rs/issues/129. + // + // [1] https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasend + bufs.as_ptr() as *mut _, + min(bufs.len(), u32::MAX as usize) as u32, + &mut total, + 0, + std::ptr::null_mut(), + None, + ), + PartialEq::eq, + SOCKET_ERROR + ) + .map(|_| total as usize) } fn set_no_inherit(&self) -> io::Result<()> { diff --git a/src/sys/windows/uds/stdnet/stream.rs b/src/sys/windows/uds/stdnet/stream.rs index 4c9d8efeb..7ed07d7a9 100644 --- a/src/sys/windows/uds/stdnet/stream.rs +++ b/src/sys/windows/uds/stdnet/stream.rs @@ -308,11 +308,11 @@ impl io::Read for UnixStream { impl<'a> io::Read for &'a UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) + self.0.recv(buf) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.0.read_vectored(bufs) + self.0.recv_vectored(bufs) } } @@ -332,12 +332,12 @@ impl io::Write for UnixStream { impl<'a> io::Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) + self.0.send(buf) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.0.write_vectored(bufs) + self.0.send_vectored(bufs) } fn flush(&mut self) -> io::Result<()> { From 9f2628638358bd7e1f520b7b9221bdd550452ff3 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Sat, 20 Aug 2022 20:18:24 -0700 Subject: [PATCH 16/34] lint --- src/sys/windows/uds/stdnet/addr.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/sys/windows/uds/stdnet/addr.rs b/src/sys/windows/uds/stdnet/addr.rs index 2bc228dde..e381a4165 100644 --- a/src/sys/windows/uds/stdnet/addr.rs +++ b/src/sys/windows/uds/stdnet/addr.rs @@ -1,6 +1,3 @@ - -//! Windows specific networking functionality - use std::ascii; use std::fmt; use std::io; From 09a9b7945d1f6c97d33a392ee2a6ffc093950d39 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Sun, 21 Aug 2022 14:51:20 -0700 Subject: [PATCH 17/34] improve support across feature matrix --- src/lib.rs | 2 + src/net/mod.rs | 3 - src/net/uds/listener.rs | 9 +- src/net/uds/mod.rs | 3 - src/net/uds/stream.rs | 4 +- src/sys/mod.rs | 7 +- src/sys/shell/mod.rs | 1 - src/sys/shell/uds.rs | 7 + src/sys/windows/mod.rs | 293 ++++++++++--------- src/sys/windows/net.rs | 12 - src/sys/windows/{uds => }/stdnet/addr.rs | 19 +- src/sys/windows/{uds => }/stdnet/listener.rs | 10 +- src/sys/windows/stdnet/mod.rs | 23 ++ src/sys/windows/{uds => }/stdnet/socket.rs | 48 +-- src/sys/windows/{uds => }/stdnet/stream.rs | 19 +- src/sys/windows/tcp.rs | 3 +- src/sys/windows/udp.rs | 3 +- src/sys/windows/uds/listener.rs | 6 +- src/sys/windows/uds/mod.rs | 3 +- src/sys/windows/uds/stdnet/mod.rs | 10 - src/sys/windows/uds/stream.rs | 4 +- tests/unix_listener.rs | 4 +- tests/unix_stream.rs | 4 +- 23 files changed, 256 insertions(+), 241 deletions(-) rename src/sys/windows/{uds => }/stdnet/addr.rs (91%) rename src/sys/windows/{uds => }/stdnet/listener.rs (97%) create mode 100644 src/sys/windows/stdnet/mod.rs rename src/sys/windows/{uds => }/stdnet/socket.rs (91%) rename src/sys/windows/{uds => }/stdnet/stream.rs (97%) delete mode 100644 src/sys/windows/uds/stdnet/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 56a7160be..e1c8b47e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,6 +91,8 @@ pub mod windows { //! Windows only extensions. pub use crate::sys::named_pipe::NamedPipe; + + pub use crate::sys::windows::std; } pub mod features { diff --git a/src/net/mod.rs b/src/net/mod.rs index a985bf64b..dc5d4b388 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -39,6 +39,3 @@ pub use self::uds::{SocketAddr, UnixListener, UnixStream}; #[cfg(unix)] pub use self::uds::UnixDatagram; - -#[cfg(windows)] -pub use crate::sys::uds::stdnet; diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 7129401e0..f6b03e405 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -2,14 +2,14 @@ use crate::io_source::IoSource; use crate::net::{SocketAddr, UnixStream}; use crate::{event, sys, Interest, Registry, Token}; +#[cfg(windows)] +use crate::sys::windows::std::net; #[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; #[cfg(unix)] use std::os::unix::net; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; -#[cfg(windows)] -use crate::sys::uds::{stdnet as net}; use std::path::Path; use std::{fmt, io}; @@ -41,9 +41,8 @@ impl UnixListener { /// The call is responsible for ensuring that the listening socket is in /// non-blocking mode. pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - self.inner.do_io(|inner| { - sys::uds::listener::accept(&*inner) - }) + self.inner + .do_io(|inner| sys::uds::listener::accept(&*inner)) } /// Returns the local socket address of this listener. diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index 332d389d4..c0a77bbf2 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -10,6 +10,3 @@ mod stream; pub use self::stream::UnixStream; pub use crate::sys::SocketAddr; - -#[cfg(windows)] -pub use crate::sys::uds::stdnet; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 8cae0cefa..7172c0d74 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -1,6 +1,8 @@ use crate::io_source::IoSource; use crate::{event, sys, Interest, Registry, Token}; +#[cfg(windows)] +use crate::sys::windows::std::net; use std::fmt; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; @@ -10,8 +12,6 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::os::unix::net; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; -#[cfg(windows)] -use crate::sys::uds::{stdnet as net}; use std::path::Path; /// A non-blocking Unix stream socket. diff --git a/src/sys/mod.rs b/src/sys/mod.rs index 872529c4d..ac9365263 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -59,7 +59,7 @@ cfg_os_poll! { #[cfg(windows)] cfg_os_poll! { - mod windows; + pub(crate) mod windows; pub use self::windows::*; } @@ -84,6 +84,11 @@ cfg_not_os_poll! { pub use self::unix::SocketAddr; } + #[cfg(windows)] + cfg_any_os_ext! { + pub(crate) mod windows; + } + #[cfg(windows)] cfg_net! { pub use self::windows::SocketAddr; diff --git a/src/sys/shell/mod.rs b/src/sys/shell/mod.rs index 8a3175f76..c29bcc9f6 100644 --- a/src/sys/shell/mod.rs +++ b/src/sys/shell/mod.rs @@ -15,7 +15,6 @@ pub(crate) use self::waker::Waker; cfg_net! { pub(crate) mod tcp; pub(crate) mod udp; - #[cfg(unix)] pub(crate) mod uds; } diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index c18aca042..caa23b9cb 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -1,3 +1,4 @@ +#[cfg(unix)] pub(crate) mod datagram { use crate::net::SocketAddr; use std::io; @@ -34,7 +35,10 @@ pub(crate) mod datagram { pub(crate) mod listener { use crate::net::{SocketAddr, UnixStream}; + #[cfg(windows)] + use crate::sys::windows::std::net; use std::io; + #[cfg(unix)] use std::os::unix::net; use std::path::Path; @@ -53,7 +57,10 @@ pub(crate) mod listener { pub(crate) mod stream { use crate::net::SocketAddr; + #[cfg(windows)] + use crate::sys::windows::std::net; use std::io; + #[cfg(unix)] use std::os::unix::net; use std::path::Path; diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index ea4f4a409..ffb967f09 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,169 +1,190 @@ -mod afd; - -pub mod event; -pub use event::{Event, Events}; - -mod handle; -use handle::Handle; - -mod io_status_block; -mod iocp; - -mod overlapped; -use overlapped::Overlapped; - -mod selector; -pub use selector::{Selector, SelectorInner, SockState}; - -// Macros must be defined before the modules that use them -cfg_net! { - /// Helper macro to execute a system call that returns an `io::Result`. - // - // Macro must be defined before any modules that uses them. - macro_rules! syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { $fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } - }}; - } +/// Helper macro to execute a system call that returns an `io::Result`. +// +// Macro must be defined before any modules that uses them. +#[allow(unused_macros)] +macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { $fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} - macro_rules! wsa_syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::from_raw_os_error(unsafe { - windows_sys::Win32::Networking::WinSock::WSAGetLastError() - })) - } else { - Ok(res) - } - }}; - } +/// Helper macro to execute a WinSock system call that returns an `io::Result`. +#[allow(unused_macros)] +macro_rules! wsa_syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + Err(io::Error::from_raw_os_error(unsafe { + windows_sys::Win32::Networking::WinSock::WSAGetLastError() + })) + } else { + Ok(res) + } + }}; +} - mod net; +cfg_any_os_ext! { + mod stdnet; - pub(crate) mod tcp; - pub(crate) mod udp; - pub mod uds; - pub use self::uds::SocketAddr; - #[cfg(all(windows, test))] - pub use self::uds::stdnet; + pub mod std { + //! Windows only std lib modules that cannot be upstreamed. + pub mod net { + //! Internal Windows std net implementation. + pub use crate::sys::windows::stdnet::*; + } + } } -cfg_os_ext! { - pub(crate) mod named_pipe; -} +cfg_os_poll! { + mod afd; -mod waker; -pub(crate) use waker::Waker; + pub mod event; + pub use event::{Event, Events}; -cfg_io_source! { - use std::io; - use std::os::windows::io::RawSocket; - use std::pin::Pin; - use std::sync::{Arc, Mutex}; + mod handle; + use handle::Handle; - use crate::{Interest, Registry, Token}; + mod io_status_block; + mod iocp; - struct InternalState { - selector: Arc, - token: Token, - interests: Interest, - sock_state: Pin>>, - } + mod overlapped; + use overlapped::Overlapped; - impl Drop for InternalState { - fn drop(&mut self) { - let mut sock_state = self.sock_state.lock().unwrap(); - sock_state.mark_delete(); - } + mod selector; + pub use selector::{Selector, SelectorInner, SockState}; + + // Macros must be defined before the modules that use them + cfg_net! { + mod net; + pub(crate) mod tcp; + pub(crate) mod udp; + pub(crate) mod uds; + pub use self::uds::SocketAddr; } - pub struct IoSourceState { - // This is `None` if the socket has not yet been registered. - // - // We box the internal state to not increase the size on the stack as the - // type might move around a lot. - inner: Option>, + cfg_os_ext! { + pub(crate) mod named_pipe; } - impl IoSourceState { - pub fn new() -> IoSourceState { - IoSourceState { inner: None } - } + mod waker; + pub(crate) use waker::Waker; - pub fn do_io(&self, f: F, io: &T) -> io::Result - where - F: FnOnce(&T) -> io::Result, - { - let result = f(io); - if let Err(ref e) = result { - if e.kind() == io::ErrorKind::WouldBlock { - self.inner.as_ref().map_or(Ok(()), |state| { - state - .selector - .reregister(state.sock_state.clone(), state.token, state.interests) - })?; - } - } - result - } + cfg_io_source! { + use ::std::io; + use ::std::os::windows::io::RawSocket; + use ::std::pin::Pin; + use ::std::sync::{Arc, Mutex}; - pub fn register( - &mut self, - registry: &Registry, + use crate::{Interest, Registry, Token}; + + struct InternalState { + selector: Arc, token: Token, interests: Interest, - socket: RawSocket, - ) -> io::Result<()> { - if self.inner.is_some() { - Err(io::ErrorKind::AlreadyExists.into()) - } else { - registry - .selector() - .register(socket, token, interests) - .map(|state| { - self.inner = Some(Box::new(state)); - }) + sock_state: Pin>>, + } + + impl Drop for InternalState { + fn drop(&mut self) { + let mut sock_state = self.sock_state.lock().unwrap(); + sock_state.mark_delete(); } } - pub fn reregister( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - ) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { + pub struct IoSourceState { + // This is `None` if the socket has not yet been registered. + // + // We box the internal state to not increase the size on the stack as the + // type might move around a lot. + inner: Option>, + } + + impl IoSourceState { + pub fn new() -> IoSourceState { + IoSourceState { inner: None } + } + + pub fn do_io(&self, f: F, io: &T) -> io::Result + where + F: FnOnce(&T) -> io::Result, + { + let result = f(io); + if let Err(ref e) = result { + if e.kind() == io::ErrorKind::WouldBlock { + self.inner.as_ref().map_or(Ok(()), |state| { + state + .selector + .reregister(state.sock_state.clone(), state.token, state.interests) + })?; + } + } + result + } + + pub fn register( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + socket: RawSocket, + ) -> io::Result<()> { + if self.inner.is_some() { + Err(io::ErrorKind::AlreadyExists.into()) + } else { registry .selector() - .reregister(state.sock_state.clone(), token, interests) - .map(|()| { - state.token = token; - state.interests = interests; + .register(socket, token, interests) + .map(|state| { + self.inner = Some(Box::new(state)); }) } - None => Err(io::ErrorKind::NotFound.into()), } - } - pub fn deregister(&mut self) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { - { - let mut sock_state = state.sock_state.lock().unwrap(); - sock_state.mark_delete(); + pub fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + ) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + registry + .selector() + .reregister(state.sock_state.clone(), token, interests) + .map(|()| { + state.token = token; + state.interests = interests; + }) } - self.inner = None; - Ok(()) + None => Err(io::ErrorKind::NotFound.into()), + } + } + + pub fn deregister(&mut self) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + { + let mut sock_state = state.sock_state.lock().unwrap(); + sock_state.mark_delete(); + } + self.inner = None; + Ok(()) + } + None => Err(io::ErrorKind::NotFound.into()), } - None => Err(io::ErrorKind::NotFound.into()), } } } } + +cfg_not_os_poll! { + cfg_net! { + mod uds; + pub use self::uds::SocketAddr; + } +} diff --git a/src/sys/windows/net.rs b/src/sys/windows/net.rs index 102ba7979..d114da408 100644 --- a/src/sys/windows/net.rs +++ b/src/sys/windows/net.rs @@ -1,24 +1,12 @@ use std::io; use std::mem; use std::net::SocketAddr; -use std::sync::Once; use windows_sys::Win32::Networking::WinSock::{ ioctlsocket, socket, AF_INET, AF_INET6, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0, SOCKADDR, SOCKADDR_IN, SOCKADDR_IN6, SOCKADDR_IN6_0, SOCKET, }; -/// Initialise the network stack for Windows. -pub(crate) fn init() { - static INIT: Once = Once::new(); - INIT.call_once(|| { - // Let standard library call `WSAStartup` for us, we can't do it - // ourselves because otherwise using any type in `std::net` would panic - // when it tries to call `WSAStartup` a second time. - drop(std::net::UdpSocket::bind("127.0.0.1:0")); - }); -} - /// Create a new non-blocking socket. pub(crate) fn new_ip_socket(addr: SocketAddr, socket_type: u16) -> io::Result { let domain = match addr { diff --git a/src/sys/windows/uds/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs similarity index 91% rename from src/sys/windows/uds/stdnet/addr.rs rename to src/sys/windows/stdnet/addr.rs index e381a4165..d737e8430 100644 --- a/src/sys/windows/uds/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -29,10 +29,12 @@ pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { sockaddr.sun_family = AF_UNIX; // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path - let bytes = path.to_str().map(|s| s.as_bytes()).ok_or(io::Error::new( - io::ErrorKind::InvalidInput, - "path contains invalid characters", - ))?; + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ) + })?; if bytes.contains(&0) { return Err(io::Error::new( @@ -124,7 +126,11 @@ impl SocketAddr { /// Returns the contents of this address if it is a `pathname` address. pub fn as_pathname(&self) -> Option<&Path> { - if let AddressKind::Pathname(path) = self.address() { Some(path) } else { None } + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } } fn address(&self) -> AddressKind<'_> { @@ -138,7 +144,8 @@ impl SocketAddr { AddressKind::Abstract(&self.addr.sun_path[1..len]) } else { use std::ffi::CStr; - let pathname = unsafe { CStr::from_bytes_with_nul_unchecked(&self.addr.sun_path[..len]) }; + let pathname = + unsafe { CStr::from_bytes_with_nul_unchecked(&self.addr.sun_path[..len]) }; AddressKind::Pathname(Path::new(pathname.to_str().unwrap())) } } diff --git a/src/sys/windows/uds/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs similarity index 97% rename from src/sys/windows/uds/stdnet/listener.rs rename to src/sys/windows/stdnet/listener.rs index 199661887..f4ad91c5f 100644 --- a/src/sys/windows/uds/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -1,11 +1,11 @@ -use std::{mem, io, fmt}; use std::os::raw::c_int; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::path::Path; +use std::{fmt, io, mem}; use windows_sys::Win32::Networking::WinSock::{sockaddr_un, AF_UNIX, SOCKET_ERROR}; -use super::{socket_addr, SocketAddr, socket::Socket, UnixStream}; +use super::{socket::Socket, socket_addr, SocketAddr, UnixStream}; /// A Unix domain socket server /// @@ -124,7 +124,9 @@ impl UnixListener { sockaddr.sun_family = AF_UNIX; let mut socklen = mem::size_of_val(&sockaddr) as c_int; - let sock = self.0.accept(&mut sockaddr as *mut _ as *mut _, &mut socklen)?; + let sock = self + .0 + .accept(&mut sockaddr as *mut _ as *mut _, &mut socklen)?; let addr = SocketAddr::from_parts(sockaddr, socklen); Ok((UnixStream(sock), addr)) } @@ -235,7 +237,7 @@ impl UnixListener { /// } /// } /// ``` - pub fn incoming<'a>(&'a self) -> Incoming<'a> { + pub fn incoming(&self) -> Incoming<'_> { Incoming { listener: self } } } diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs new file mode 100644 index 000000000..9dbaf719f --- /dev/null +++ b/src/sys/windows/stdnet/mod.rs @@ -0,0 +1,23 @@ +//! Windows specific networking functionality. Mirrors std::os::unix::net. + +mod addr; +mod listener; +mod socket; +mod stream; + +pub use self::addr::*; +pub use self::listener::*; +pub use self::stream::*; + +use std::sync::Once; + +/// Initialise the network stack for Windows. +pub(crate) fn init() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + // Let standard library call `WSAStartup` for us, we can't do it + // ourselves because otherwise using any type in `std::net` would panic + // when it tries to call `WSAStartup` a second time. + drop(std::net::UdpSocket::bind("127.0.0.1:0")); + }); +} diff --git a/src/sys/windows/uds/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs similarity index 91% rename from src/sys/windows/uds/stdnet/socket.rs rename to src/sys/windows/stdnet/socket.rs index 188c3a035..ef97cea0c 100644 --- a/src/sys/windows/uds/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -1,6 +1,6 @@ -use std::io::{self, IoSlice, IoSliceMut}; use std::cmp::min; use std::convert::TryInto; +use std::io::{self, IoSlice, IoSliceMut}; use std::mem; use std::net::Shutdown; use std::os::raw::{c_int, c_ulong}; @@ -8,27 +8,14 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket} use std::ptr; use std::time::Duration; -use windows_sys::Win32::Foundation::{ - HANDLE, - SetHandleInformation, - HANDLE_FLAG_INHERIT +use super::init; +use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; +use windows_sys::Win32::Networking::WinSock::{ + self, closesocket, INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, + WSABUF, WSAESHUTDOWN, }; use windows_sys::Win32::System::Threading::GetCurrentProcessId; use windows_sys::Win32::System::WindowsProgramming::INFINITE; -use windows_sys::Win32::Networking::WinSock::{ - self, - WSABUF, - INVALID_SOCKET, - SOCKADDR, - SOCKET, - SOCKET_ERROR, - SOL_SOCKET, - SO_ERROR, - closesocket, - WSAESHUTDOWN -}; - -use crate::sys::windows::net::init; /// Maximum size of a buffer passed to system call like `recv` and `send`. const MAX_BUF_LEN: usize = c_int::MAX as usize; @@ -55,11 +42,7 @@ impl Socket { } pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { - let socket = wsa_syscall!( - accept(self.0, storage, len), - PartialEq::eq, - INVALID_SOCKET - )?; + let socket = wsa_syscall!(accept(self.0, storage, len), PartialEq::eq, INVALID_SOCKET)?; let socket = Socket(socket); socket.set_no_inherit()?; Ok(socket) @@ -68,11 +51,7 @@ impl Socket { pub fn duplicate(&self) -> io::Result { let mut info: WinSock::WSAPROTOCOL_INFOW = unsafe { mem::zeroed() }; wsa_syscall!( - WSADuplicateSocketW( - self.0, - GetCurrentProcessId(), - &mut info, - ), + WSADuplicateSocketW(self.0, GetCurrentProcessId(), &mut info,), PartialEq::eq, SOCKET_ERROR )?; @@ -81,7 +60,7 @@ impl Socket { info.iAddressFamily, info.iSocketType, info.iProtocol, - &mut info, + &info, 0, WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, ), @@ -93,12 +72,7 @@ impl Socket { pub fn recv(&self, buf: &mut [u8]) -> io::Result { let ret = wsa_syscall!( - recv( - self.0, - buf.as_mut_ptr() as *mut _, - buf.len() as c_int, - 0, - ), + recv(self.0, buf.as_mut_ptr() as *mut _, buf.len() as c_int, 0,), PartialEq::eq, SOCKET_ERROR )?; @@ -209,7 +183,7 @@ impl Socket { let raw = getsockopt::( self, SOL_SOCKET.try_into().unwrap(), - SO_ERROR.try_into().unwrap() + SO_ERROR.try_into().unwrap(), )?; if raw == 0 { Ok(None) diff --git a/src/sys/windows/uds/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs similarity index 97% rename from src/sys/windows/uds/stdnet/stream.rs rename to src/sys/windows/stdnet/stream.rs index 7ed07d7a9..49e4ff525 100644 --- a/src/sys/windows/uds/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -1,14 +1,16 @@ -use std::{fmt, mem}; -use std::io::{self, IoSlice, IoSliceMut}; use std::convert::TryInto; +use std::io::{self, IoSlice, IoSliceMut}; use std::net::Shutdown; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::path::{Path, PathBuf}; use std::time::Duration; +use std::{fmt, mem}; -use windows_sys::Win32::Networking::WinSock::{WSAEINPROGRESS, SO_RCVTIMEO, SOCKET_ERROR, SO_SNDTIMEO}; +use windows_sys::Win32::Networking::WinSock::{ + SOCKET_ERROR, SO_RCVTIMEO, SO_SNDTIMEO, WSAEINPROGRESS, +}; -use super::{socket_addr, SocketAddr, socket::Socket, UnixListener}; +use super::{socket::Socket, socket_addr, SocketAddr, UnixListener}; use rand::{distributions::Alphanumeric, Rng}; /// A Unix stream socket @@ -71,9 +73,9 @@ impl UnixStream { PartialEq::eq, SOCKET_ERROR ) { - Ok(_) => {}, - Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {}, - Err(e) => return Err(e) + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(e) => return Err(e), } Ok(UnixStream(inner)) } @@ -227,7 +229,7 @@ impl UnixStream { .join() .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; let stream0 = (*(a.write().unwrap())).take().unwrap()?; - return Ok((stream0, stream1)); + Ok((stream0, stream1)) } /// Sets the read timeout to the timeout specified. @@ -335,7 +337,6 @@ impl<'a> io::Write for &'a UnixStream { self.0.send(buf) } - fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { self.0.send_vectored(bufs) } diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index 533074be9..a2f123e36 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -6,7 +6,8 @@ use windows_sys::Win32::Networking::WinSock::{ self, AF_INET, AF_INET6, SOCKET, SOCKET_ERROR, SOCK_STREAM, }; -use crate::sys::windows::net::{init, new_socket, socket_addr}; +use crate::sys::windows::net::{new_socket, socket_addr}; +use crate::sys::windows::std::net::init; pub(crate) fn new_for_addr(address: SocketAddr) -> io::Result { init(); diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index 91516ccc2..e07b967c4 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -4,7 +4,8 @@ use std::net::{self, SocketAddr}; use std::os::windows::io::{AsRawSocket, FromRawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // windows-sys uses usize, stdlib uses u32/u64. -use crate::sys::windows::net::{init, new_ip_socket, socket_addr}; +use crate::sys::windows::net::{new_ip_socket, socket_addr}; +use crate::sys::windows::std::net::init; use windows_sys::Win32::Networking::WinSock::{ bind as win_bind, closesocket, getsockopt, IPPROTO_IPV6, IPV6_V6ONLY, SOCKET_ERROR, SOCK_DGRAM, }; diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index 4bd5a9f8e..4cc393ffc 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -2,8 +2,8 @@ use std::io; use std::os::windows::io::AsRawSocket; use std::path::Path; -use super::{stdnet as net}; use crate::net::{SocketAddr, UnixStream}; +use crate::sys::windows::std::net; pub(crate) fn bind(path: &Path) -> io::Result { let listener = net::UnixListener::bind(path)?; @@ -13,7 +13,9 @@ pub(crate) fn bind(path: &Path) -> io::Result { pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { listener.set_nonblocking(true)?; - listener.accept().map(|(stream, addr)| (UnixStream::from_std(stream), addr)) + listener + .accept() + .map(|(stream, addr)| (UnixStream::from_std(stream), addr)) } pub(crate) fn local_addr(listener: &net::UnixListener) -> io::Result { diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs index de48e24b1..d0b56eb9d 100644 --- a/src/sys/windows/uds/mod.rs +++ b/src/sys/windows/uds/mod.rs @@ -1,5 +1,4 @@ -pub mod stdnet; -pub use self::stdnet::SocketAddr; +pub use super::stdnet::SocketAddr; cfg_os_poll! { use std::convert::TryInto; diff --git a/src/sys/windows/uds/stdnet/mod.rs b/src/sys/windows/uds/stdnet/mod.rs deleted file mode 100644 index 0b2edbbf3..000000000 --- a/src/sys/windows/uds/stdnet/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! Windows specific networking functionality. Mirrors std::os::unix::net. - -mod addr; -mod socket; -mod stream; -mod listener; - -pub use self::addr::*; -pub use self::listener::*; -pub use self::stream::*; diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index 4df1aa396..c59a1f95c 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -1,7 +1,7 @@ +use crate::sys::windows::std::net; use std::io; -use std::os::windows::io::{AsRawSocket}; +use std::os::windows::io::AsRawSocket; use std::path::Path; -use super::stdnet::{self as net}; pub(crate) fn connect(path: &Path) -> io::Result { let socket = net::UnixStream::connect(path)?; diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index b2de8c9eb..3314d424a 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,7 +1,7 @@ -#![cfg(all(feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] #[cfg(windows)] -use mio::net::stdnet as net; +use mio::net::windows::std::net; use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 205a91a98..b62bcc147 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,7 +1,7 @@ -#![cfg(all(feature = "os-poll", feature = "net"))] +#![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] #[cfg(windows)] -use mio::net::stdnet as net; +use mio::net::windows::std::net; use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; From bb914db2d129edba9c5b005ac57bf53fc3d963f0 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Sun, 21 Aug 2022 18:27:06 -0700 Subject: [PATCH 18/34] fix doc tests --- src/sys/windows/stdnet/listener.rs | 18 +++++++++--------- src/sys/windows/stdnet/stream.rs | 26 +++++++++++++------------- tests/unix_listener.rs | 2 +- tests/unix_stream.rs | 2 +- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs index f4ad91c5f..611c86d15 100644 --- a/src/sys/windows/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -13,7 +13,7 @@ use super::{socket::Socket, socket_addr, SocketAddr, UnixStream}; /// /// ```no_run /// use std::thread; -/// use mio::net::stdnet::{UnixStream, UnixListener}; +/// use mio::windows::std::net::{UnixStream, UnixListener}; /// /// fn handle_client(stream: UnixStream) { /// // ... @@ -56,7 +56,7 @@ impl UnixListener { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixListener; + /// use mio::windows::std::net::UnixListener; /// /// let listener = match UnixListener::bind("/path/to/the/socket") { /// Ok(sock) => sock, @@ -99,7 +99,7 @@ impl UnixListener { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixListener; + /// use mio::windows::std::net::UnixListener; /// /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); /// @@ -140,7 +140,7 @@ impl UnixListener { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixListener; + /// use mio::windows::std::net::UnixListener; /// /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); /// @@ -156,7 +156,7 @@ impl UnixListener { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixListener; + /// use mio::windows::std::net::UnixListener; /// /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); /// @@ -178,7 +178,7 @@ impl UnixListener { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixListener; + /// use mio::windows::std::net::UnixListener; /// /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); /// @@ -193,7 +193,7 @@ impl UnixListener { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixListener; + /// use mio::windows::std::net::UnixListener; /// /// let listener = UnixListener::bind("/tmp/sock").unwrap(); /// @@ -216,7 +216,7 @@ impl UnixListener { /// /// ```no_run /// use std::thread; - /// use mio::net::stdnet::{UnixStream, UnixListener}; + /// use mio::windows::std::net::{UnixStream, UnixListener}; /// /// fn handle_client(stream: UnixStream) { /// // ... @@ -281,7 +281,7 @@ impl<'a> IntoIterator for &'a UnixListener { /// /// ```no_run /// use std::thread; -/// use mio::net::stdnet::{UnixStream, UnixListener}; +/// use mio::windows::std::net::{UnixStream, UnixListener}; /// /// fn handle_client(stream: UnixStream) { /// // ... diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 49e4ff525..9946e27fe 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -18,7 +18,7 @@ use rand::{distributions::Alphanumeric, Rng}; /// # Examples /// /// ```no_run -/// use mio::net::stdnet::UnixStream; +/// use mio::windows::std::net::UnixStream; /// use std::io::prelude::*; /// /// let mut stream = UnixStream::connect("/path/to/my/socket").unwrap(); @@ -49,7 +49,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = match UnixStream::connect("/tmp/sock") { /// Ok(sock) => sock, @@ -90,7 +90,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// let sock_copy = socket.try_clone().expect("Couldn't clone socket"); @@ -104,7 +104,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// let addr = socket.local_addr().expect("Couldn't get local address"); @@ -125,7 +125,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// let addr = socket.peer_addr().expect("Couldn't get peer address"); @@ -146,7 +146,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// socket.set_nonblocking(true).expect("Couldn't set nonblocking"); @@ -160,7 +160,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// if let Ok(Some(err)) = socket.take_error() { @@ -180,7 +180,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// use std::net::Shutdown; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); @@ -197,7 +197,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let (sock1, sock2) = match UnixStream::pair() { /// Ok((sock1, sock2)) => (sock1, sock2), @@ -241,7 +241,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); @@ -258,7 +258,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); @@ -272,7 +272,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); @@ -287,7 +287,7 @@ impl UnixStream { /// # Examples /// /// ```no_run - /// use mio::net::stdnet::UnixStream; + /// use mio::windows::std::net::UnixStream; /// /// let socket = UnixStream::connect("/tmp/sock").unwrap(); /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 3314d424a..734b471f1 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,7 +1,7 @@ #![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] #[cfg(windows)] -use mio::net::windows::std::net; +use mio::windows::std::net; use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index b62bcc147..b4e4a0243 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,7 +1,7 @@ #![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] #[cfg(windows)] -use mio::net::windows::std::net; +use mio::windows::std::net; use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; From 648855dda3b652a3f6278db64d88bd4d052b1c76 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Sun, 21 Aug 2022 18:53:15 -0700 Subject: [PATCH 19/34] use bcrypt instead of rand --- Cargo.toml | 4 +-- src/sys/windows/stdnet/stream.rs | 42 +++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c2a776fe8..ba09d9d19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,15 +48,13 @@ log = "0.4.8" [target.'cfg(unix)'.dependencies] libc = "0.2.121" -[target.'cfg(windows)'.dependencies] -rand = "0.8" - [target.'cfg(windows)'.dependencies.windows-sys] version = "0.36" features = [ "Win32_Storage_FileSystem", # Enables NtCreateFile "Win32_Foundation", # Basic types eg HANDLE "Win32_Networking_WinSock", # winsock2 types/functions + "Win32_Security_Cryptography", # Random number generation "Win32_System_IO", # IO types like OVERLAPPED etc "Win32_System_Threading", # Process utilities "Win32_System_WindowsProgramming", # General future used for various types/funcs diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 9946e27fe..23352fa53 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -9,9 +9,13 @@ use std::{fmt, mem}; use windows_sys::Win32::Networking::WinSock::{ SOCKET_ERROR, SO_RCVTIMEO, SO_SNDTIMEO, WSAEINPROGRESS, }; +use windows_sys::Win32::Security::Cryptography::{ + BCryptGenRandom, + BCRYPT_USE_SYSTEM_PREFERRED_RNG +}; +use windows_sys::Win32::Foundation::STATUS_SUCCESS; use super::{socket::Socket, socket_addr, SocketAddr, UnixListener}; -use rand::{distributions::Alphanumeric, Rng}; /// A Unix stream socket /// @@ -368,16 +372,42 @@ impl IntoRawSocket for UnixStream { struct TempPath(PathBuf); +fn sample_ascii_string(len: usize) -> io::Result { + const RANGE: u32 = 26 + 26 + 10; + const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ + abcdefghijklmnopqrstuvwxyz\ + 0123456789"; + let mut result = String::with_capacity(len); + for _ in 0..len { + // We pick from 62 characters. This is so close to a power of 2, 64, + // that we can efficiently use a simple bitshift and rejection sampling. + let mut var = RANGE; + while var >= RANGE { + let mut buf = [0; 4]; + syscall!( + BCryptGenRandom( + 0, + &mut buf as *mut _, + buf.len() as u32, + BCRYPT_USE_SYSTEM_PREFERRED_RNG, + ), + PartialEq::ne, + STATUS_SUCCESS + )?; + var = u32::from_le_bytes(buf) >> (32 - 6); + } + let c = char::from(GEN_ASCII_STR_CHARSET[var as usize]); + result.push(c); + } + Ok(result) +} + impl TempPath { fn new(random_len: usize) -> io::Result { let dir = std::env::temp_dir(); // Retry a few times in case of collisions for _ in 0..10 { - let rand_str: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(random_len) - .map(char::from) - .collect(); + let rand_str = sample_ascii_string(random_len)?; let filename = format!(".tmp-{rand_str}.socket"); let path = dir.join(filename); if !path.exists() { From 3dd3c0f9695c6a3745a5351fc4d869b3df6069d5 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 22 Aug 2022 10:02:55 -0700 Subject: [PATCH 20/34] add -_ to random char set to avoid rejection sampling --- src/sys/windows/stdnet/stream.rs | 36 +++++++++++++------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 23352fa53..5937f342d 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -373,31 +373,25 @@ impl IntoRawSocket for UnixStream { struct TempPath(PathBuf); fn sample_ascii_string(len: usize) -> io::Result { - const RANGE: u32 = 26 + 26 + 10; const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ abcdefghijklmnopqrstuvwxyz\ - 0123456789"; + 0123456789-_"; let mut result = String::with_capacity(len); + let mut buf = [0; 4]; for _ in 0..len { - // We pick from 62 characters. This is so close to a power of 2, 64, - // that we can efficiently use a simple bitshift and rejection sampling. - let mut var = RANGE; - while var >= RANGE { - let mut buf = [0; 4]; - syscall!( - BCryptGenRandom( - 0, - &mut buf as *mut _, - buf.len() as u32, - BCRYPT_USE_SYSTEM_PREFERRED_RNG, - ), - PartialEq::ne, - STATUS_SUCCESS - )?; - var = u32::from_le_bytes(buf) >> (32 - 6); - } - let c = char::from(GEN_ASCII_STR_CHARSET[var as usize]); - result.push(c); + syscall!( + BCryptGenRandom( + 0, + &mut buf as *mut _, + buf.len() as u32, + BCRYPT_USE_SYSTEM_PREFERRED_RNG, + ), + PartialEq::ne, + STATUS_SUCCESS + )?; + // We pick from 64=2^6 characters so we can use a simple bitshift. + let var = u32::from_le_bytes(buf) >> (32 - 6); + result.push(char::from(GEN_ASCII_STR_CHARSET[var as usize])); } Ok(result) } From 2283d391f8adc97c56abff490bcb585753f6a1c1 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 22 Aug 2022 10:40:50 -0700 Subject: [PATCH 21/34] optimize rng syscall logic --- src/sys/windows/stdnet/stream.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 5937f342d..799eb2e8c 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -376,23 +376,27 @@ fn sample_ascii_string(len: usize) -> io::Result { const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ abcdefghijklmnopqrstuvwxyz\ 0123456789-_"; - let mut result = String::with_capacity(len); - let mut buf = [0; 4]; - for _ in 0..len { + let mut buf: Vec = vec![0; len]; + for chunk in buf.chunks_mut(u32::max_value() as usize) { syscall!( BCryptGenRandom( 0, - &mut buf as *mut _, - buf.len() as u32, + chunk.as_mut_ptr(), + chunk.len() as u32, BCRYPT_USE_SYSTEM_PREFERRED_RNG, ), PartialEq::ne, STATUS_SUCCESS )?; - // We pick from 64=2^6 characters so we can use a simple bitshift. - let var = u32::from_le_bytes(buf) >> (32 - 6); - result.push(char::from(GEN_ASCII_STR_CHARSET[var as usize])); } + let result: String = buf + .into_iter() + .map(|r| { + // We pick from 64=2^6 characters so we can use a simple bitshift. + let idx = r >> (8 - 6); + char::from(GEN_ASCII_STR_CHARSET[idx as usize]) + }) + .collect(); Ok(result) } From 569de7288f6c436c7b5dffe9fd96efa5642583e1 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 22 Aug 2022 11:52:42 -0700 Subject: [PATCH 22/34] fix lint and fmt --- src/net/tcp/stream.rs | 20 ++++++++++---------- src/net/uds/listener.rs | 3 +-- src/net/uds/stream.rs | 20 ++++++++++---------- src/sys/unix/pipe.rs | 20 ++++++++++---------- src/sys/unix/uds/mod.rs | 4 ++-- src/sys/windows/iocp.rs | 2 +- src/sys/windows/stdnet/stream.rs | 5 ++--- tests/close_on_drop.rs | 2 +- tests/unix_listener.rs | 2 +- tests/unix_pipe.rs | 4 ++-- tests/unix_stream.rs | 2 +- 11 files changed, 41 insertions(+), 43 deletions(-) diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 532e7d9b6..a7a9aa1ba 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -269,49 +269,49 @@ impl TcpStream { impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl<'a> Read for &'a TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } impl<'a> Write for &'a TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index f6b03e405..e26d5fffd 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -41,8 +41,7 @@ impl UnixListener { /// The call is responsible for ensuring that the listening socket is in /// non-blocking mode. pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - self.inner - .do_io(|inner| sys::uds::listener::accept(&*inner)) + self.inner.do_io(sys::uds::listener::accept) } /// Returns the local socket address of this listener. diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 7172c0d74..963997655 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -218,49 +218,49 @@ impl UnixStream { impl Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl<'a> Read for &'a UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl Write for UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } impl<'a> Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } diff --git a/src/sys/unix/pipe.rs b/src/sys/unix/pipe.rs index b2865cda7..7a95b9697 100644 --- a/src/sys/unix/pipe.rs +++ b/src/sys/unix/pipe.rs @@ -313,29 +313,29 @@ impl event::Source for Sender { impl Write for Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write(buf)) + self.inner.do_io(|mut sender| sender.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) + self.inner.do_io(|mut sender| sender.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|sender| (&*sender).flush()) + self.inner.do_io(|mut sender| sender.flush()) } } impl Write for &Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write(buf)) + self.inner.do_io(|mut sender| sender.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) + self.inner.do_io(|mut sender| sender.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|sender| (&*sender).flush()) + self.inner.do_io(|mut sender| sender.flush()) } } @@ -478,21 +478,21 @@ impl event::Source for Receiver { impl Read for Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read(buf)) + self.inner.do_io(|mut sender| sender.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) + self.inner.do_io(|mut sender| sender.read_vectored(bufs)) } } impl Read for &Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read(buf)) + self.inner.do_io(|mut sender| sender.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) + self.inner.do_io(|mut sender| sender.read_vectored(bufs)) } } diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index 8e28a9573..526bbdfd0 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -40,7 +40,7 @@ cfg_os_poll! { sockaddr.sun_family = libc::AF_UNIX as libc::sa_family_t; let bytes = path.as_os_str().as_bytes(); - match (bytes.get(0), bytes.len().cmp(&sockaddr.sun_path.len())) { + match (bytes.first(), bytes.len().cmp(&sockaddr.sun_path.len())) { // Abstract paths don't need a null terminator (Some(&0), Ordering::Greater) => { return Err(io::Error::new( @@ -64,7 +64,7 @@ cfg_os_poll! { let offset = path_offset(&sockaddr); let mut socklen = offset + bytes.len(); - match bytes.get(0) { + match bytes.first() { // The struct has already been zeroes so the null byte for pathname // addresses is already there. Some(&0) | None => {} diff --git a/src/sys/windows/iocp.rs b/src/sys/windows/iocp.rs index d75f3826e..f7651daa8 100644 --- a/src/sys/windows/iocp.rs +++ b/src/sys/windows/iocp.rs @@ -260,6 +260,6 @@ mod tests { } assert_eq!(s[2].bytes_transferred(), 0); assert_eq!(s[2].token(), 0); - assert_eq!(s[2].overlapped(), 0 as *mut _); + assert_eq!(s[2].overlapped(), std::ptr::null_mut()); } } diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 799eb2e8c..9958d6711 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -6,14 +6,13 @@ use std::path::{Path, PathBuf}; use std::time::Duration; use std::{fmt, mem}; +use windows_sys::Win32::Foundation::STATUS_SUCCESS; use windows_sys::Win32::Networking::WinSock::{ SOCKET_ERROR, SO_RCVTIMEO, SO_SNDTIMEO, WSAEINPROGRESS, }; use windows_sys::Win32::Security::Cryptography::{ - BCryptGenRandom, - BCRYPT_USE_SYSTEM_PREFERRED_RNG + BCryptGenRandom, BCRYPT_USE_SYSTEM_PREFERRED_RNG, }; -use windows_sys::Win32::Foundation::STATUS_SUCCESS; use super::{socket::Socket, socket_addr, SocketAddr, UnixListener}; diff --git a/tests/close_on_drop.rs b/tests/close_on_drop.rs index 8d9eefcca..a2e88d9de 100644 --- a/tests/close_on_drop.rs +++ b/tests/close_on_drop.rs @@ -58,7 +58,7 @@ impl TestHandler { AfterRead => {} } - let mut buf = Vec::with_capacity(1024); + let mut buf = vec![0; 1024]; match self.cli.read(&mut buf) { Ok(0) => self.shutdown = true, diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 734b471f1..2874845c7 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,8 +1,8 @@ #![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] +use mio::net::UnixListener; #[cfg(windows)] use mio::windows::std::net; -use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; #[cfg(unix)] diff --git a/tests/unix_pipe.rs b/tests/unix_pipe.rs index a83e3833b..f8e6464c9 100644 --- a/tests/unix_pipe.rs +++ b/tests/unix_pipe.rs @@ -49,7 +49,7 @@ fn smoke() { ); let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], &*DATA1); + assert_eq!(&buf[..n], DATA1); } #[test] @@ -162,7 +162,7 @@ fn from_child_process_io() { let mut buf = [0; 20]; let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], &*DATA1); + assert_eq!(&buf[..n], DATA1); drop(sender); diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index b4e4a0243..3930d608a 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,8 +1,8 @@ #![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] +use mio::net::UnixStream; #[cfg(windows)] use mio::windows::std::net; -use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; From bdc6933ae58be016fd5249aa3c75606a04816377 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 22 Aug 2022 13:20:16 -0700 Subject: [PATCH 23/34] remove unused functions --- Cargo.toml | 1 - src/lib.rs | 2 - src/net/uds/listener.rs | 8 + src/net/uds/stream.rs | 8 + src/sys/windows/mod.rs | 12 +- src/sys/windows/stdnet/addr.rs | 25 +- src/sys/windows/stdnet/listener.rs | 283 ++-------------------- src/sys/windows/stdnet/mod.rs | 12 +- src/sys/windows/stdnet/socket.rs | 242 ++++++------------- src/sys/windows/stdnet/stream.rs | 361 ++++++----------------------- src/sys/windows/uds/mod.rs | 2 - tests/unix_listener.rs | 3 +- tests/unix_stream.rs | 37 ++- 13 files changed, 245 insertions(+), 751 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ba09d9d19..9f62ac3b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,6 @@ features = [ "Win32_Networking_WinSock", # winsock2 types/functions "Win32_Security_Cryptography", # Random number generation "Win32_System_IO", # IO types like OVERLAPPED etc - "Win32_System_Threading", # Process utilities "Win32_System_WindowsProgramming", # General future used for various types/funcs ] diff --git a/src/lib.rs b/src/lib.rs index e1c8b47e3..56a7160be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,8 +91,6 @@ pub mod windows { //! Windows only extensions. pub use crate::sys::named_pipe::NamedPipe; - - pub use crate::sys::windows::std; } pub mod features { diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index e26d5fffd..928a85dfc 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -30,12 +30,20 @@ impl UnixListener { /// standard library in the Mio equivalent. The conversion assumes nothing /// about the underlying listener; it is left up to the user to set it in /// non-blocking mode. + #[cfg(unix)] pub fn from_std(listener: net::UnixListener) -> UnixListener { UnixListener { inner: IoSource::new(listener), } } + #[cfg(windows)] + pub(crate) fn from_std(listener: net::UnixListener) -> UnixListener { + UnixListener { + inner: IoSource::new(listener), + } + } + /// Accepts a new incoming connection to this listener. /// /// The call is responsible for ensuring that the listening socket is in diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 963997655..375cb4d2d 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -40,11 +40,19 @@ impl UnixStream { /// The Unix stream here will not have `connect` called on it, so it /// should already be connected via some other means (be it manually, or /// the standard library). + #[cfg(unix)] pub fn from_std(stream: net::UnixStream) -> UnixStream { UnixStream { inner: IoSource::new(stream), } } + + #[cfg(windows)] + pub(crate) fn from_std(stream: net::UnixStream) -> UnixStream { + UnixStream { + inner: IoSource::new(stream), + } + } /// Creates an unnamed pair of connected sockets. /// diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index ffb967f09..09bb1c928 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -16,9 +16,9 @@ macro_rules! syscall { /// Helper macro to execute a WinSock system call that returns an `io::Result`. #[allow(unused_macros)] macro_rules! wsa_syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + ($fn: ident ( $($arg: expr),* $(,)* ), $err_value: expr) => {{ let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; - if $err_test(&res, &$err_value) { + if PartialEq::eq(&res, &$err_value) { Err(io::Error::from_raw_os_error(unsafe { windows_sys::Win32::Networking::WinSock::WSAGetLastError() })) @@ -28,14 +28,14 @@ macro_rules! wsa_syscall { }}; } -cfg_any_os_ext! { +cfg_net! { mod stdnet; - pub mod std { + pub(crate) mod std { //! Windows only std lib modules that cannot be upstreamed. - pub mod net { + pub(crate) mod net { //! Internal Windows std net implementation. - pub use crate::sys::windows::stdnet::*; + pub(crate) use crate::sys::windows::stdnet::*; } } } diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index d737e8430..8f52562f3 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -5,15 +5,18 @@ use std::mem; use std::os::raw::c_int; use std::path::Path; -use windows_sys::Win32::Networking::WinSock::{sockaddr_un, AF_UNIX, SOCKADDR}; +use windows_sys::Win32::Networking::WinSock::{sockaddr_un, SOCKADDR}; -pub(super) fn path_offset(addr: &sockaddr_un) -> usize { +fn path_offset(addr: &sockaddr_un) -> usize { // Work with an actual instance of the type since using a null pointer is UB let base = addr as *const _ as usize; let path = &addr.sun_path as *const _ as usize; path - base } +cfg_os_poll! { +use windows_sys::Win32::Networking::WinSock::AF_UNIX; + pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { let sockaddr = mem::MaybeUninit::::zeroed(); @@ -52,8 +55,6 @@ pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { *dst = *src as u8; } - // null byte for pathname addresses is already there because we zeroed the - // struct let offset = path_offset(&sockaddr); let mut socklen = offset + bytes.len(); @@ -67,6 +68,7 @@ pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { Ok((sockaddr, socklen as c_int)) } +} enum AddressKind<'a> { Unnamed, @@ -96,9 +98,9 @@ pub struct SocketAddr { } impl SocketAddr { - pub(crate) fn new(f: F) -> io::Result + pub(crate) fn init(f: F) -> io::Result<(T, SocketAddr)> where - F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, { let mut sockaddr = { let sockaddr = mem::MaybeUninit::::zeroed(); @@ -106,8 +108,15 @@ impl SocketAddr { }; let mut len = mem::size_of::() as c_int; - f(&mut sockaddr as *mut _ as *mut _, &mut len)?; - Ok(SocketAddr::from_parts(sockaddr, len)) + let result = f(&mut sockaddr as *mut _ as *mut _, &mut len)?; + Ok((result, SocketAddr::from_parts(sockaddr, len))) + } + + pub(crate) fn new(f: F) -> io::Result + where + F: FnOnce(*mut SOCKADDR, *mut c_int) -> io::Result, + { + SocketAddr::init(f).map(|(_, addr)| addr) } pub(crate) fn from_parts(addr: sockaddr_un, mut len: c_int) -> SocketAddr { diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs index 611c86d15..2d8965631 100644 --- a/src/sys/windows/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -1,43 +1,11 @@ -use std::os::raw::c_int; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; -use std::path::Path; use std::{fmt, io, mem}; -use windows_sys::Win32::Networking::WinSock::{sockaddr_un, AF_UNIX, SOCKET_ERROR}; +use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; -use super::{socket::Socket, socket_addr, SocketAddr, UnixStream}; +use super::{socket::Socket, SocketAddr}; -/// A Unix domain socket server -/// -/// # Examples -/// -/// ```no_run -/// use std::thread; -/// use mio::windows::std::net::{UnixStream, UnixListener}; -/// -/// fn handle_client(stream: UnixStream) { -/// // ... -/// # drop(stream); // Silence unused variable warning. -/// } -/// -/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); -/// -/// // accept connections and process them, spawning a new thread for each one -/// for stream in listener.incoming() { -/// match stream { -/// Ok(stream) => { -/// /* connection succeeded */ -/// thread::spawn(|| handle_client(stream)); -/// } -/// Err(err) => { -/// /* connection failed */ -/// eprintln!("connection failed: {err}"); -/// break; -/// } -/// } -/// } -/// ``` -pub struct UnixListener(Socket); +pub(crate) struct UnixListener(Socket); impl fmt::Debug for UnixListener { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -51,195 +19,18 @@ impl fmt::Debug for UnixListener { } impl UnixListener { - /// Creates a new `UnixListener` bound to the specified socket. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixListener; - /// - /// let listener = match UnixListener::bind("/path/to/the/socket") { - /// Ok(sock) => sock, - /// Err(e) => { - /// println!("Couldn't connect: {:?}", e); - /// return - /// } - /// }; - /// # drop(listener); // Silence unused variable warning. - /// ``` - pub fn bind>(path: P) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path.as_ref())?; - - wsa_syscall!( - bind( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as _, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - wsa_syscall!( - listen(inner.as_raw_socket() as _, 128), - PartialEq::eq, - SOCKET_ERROR - )?; - Ok(UnixListener(inner)) - } - - /// Accepts a new incoming connection to this listener. - /// - /// This function will block the calling thread until a new Unix connection - /// is established. When established, the corresponding [`UnixStream`] and - /// the remote peer's address will be returned. - /// - /// [`UnixStream`]: struct.UnixStream.html - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// match listener.accept() { - /// Ok((_socket, addr)) => println!("Got a client: {:?}", addr), - /// Err(e) => println!("accept function failed: {:?}", e), - /// } - /// ``` - pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - let sockaddr = mem::MaybeUninit::::zeroed(); - - // This is safe to assume because a `sockaddr_un` filled with `0` - // bytes is properly initialized. - // - // `0` is a valid value for `sockaddr_un::sun_family`; it is - // `WinSock::AF_UNSPEC`. - // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an - // abstract path. - let mut sockaddr = unsafe { sockaddr.assume_init() }; - - sockaddr.sun_family = AF_UNIX; - let mut socklen = mem::size_of_val(&sockaddr) as c_int; - - let sock = self - .0 - .accept(&mut sockaddr as *mut _ as *mut _, &mut socklen)?; - let addr = SocketAddr::from_parts(sockaddr, socklen); - Ok((UnixStream(sock), addr)) - } - - /// Creates a new independently owned handle to the underlying socket. - /// - /// The returned `UnixListener` is a reference to the same socket that this - /// object references. Both handles can be used to accept incoming - /// connections and options set on one listener will affect the other. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// let listener_copy = listener.try_clone().expect("Couldn't clone socket"); - /// # drop(listener_copy); // Silence unused variable warning. - /// ``` - pub fn try_clone(&self) -> io::Result { - self.0.duplicate().map(UnixListener) - } - - /// Returns the local socket address of this listener. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// let addr = listener.local_addr().expect("Couldn't get local address"); - /// # drop(addr); // Silence unused variable warning. - /// ``` pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getsockname(self.0.as_raw_socket() as _, addr, len), - PartialEq::eq, SOCKET_ERROR ) }) } - /// Moves the socket into or out of nonblocking mode. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixListener; - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// listener.set_nonblocking(true).expect("Couldn't set nonblocking"); - /// ``` - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) - } - - /// Returns the value of the `SO_ERROR` option. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixListener; - /// - /// let listener = UnixListener::bind("/tmp/sock").unwrap(); - /// - /// if let Ok(Some(err)) = listener.take_error() { - /// println!("Got error: {:?}", err); - /// } - /// ``` pub fn take_error(&self) -> io::Result> { self.0.take_error() } - - /// Returns an iterator over incoming connections. - /// - /// The iterator will never return `None` and will also not yield the - /// peer's [`SocketAddr`] structure. - /// - /// [`SocketAddr`]: struct.SocketAddr.html - /// - /// # Examples - /// - /// ```no_run - /// use std::thread; - /// use mio::windows::std::net::{UnixStream, UnixListener}; - /// - /// fn handle_client(stream: UnixStream) { - /// // ... - /// # drop(stream); // Silence unused variable warning. - /// } - /// - /// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); - /// - /// for stream in listener.incoming() { - /// match stream { - /// Ok(stream) => { - /// thread::spawn(|| handle_client(stream)); - /// } - /// Err(err) => { - /// eprintln!("connection failed: {err}"); - /// break; - /// } - /// } - /// } - /// ``` - pub fn incoming(&self) -> Incoming<'_> { - Incoming { listener: self } - } } impl AsRawSocket for UnixListener { @@ -262,59 +53,31 @@ impl IntoRawSocket for UnixListener { } } -impl<'a> IntoIterator for &'a UnixListener { - type Item = io::Result; - type IntoIter = Incoming<'a>; +cfg_os_poll! { +use std::path::Path; - fn into_iter(self) -> Incoming<'a> { - self.incoming() - } -} +use super::{socket_addr, UnixStream}; -/// An iterator over incoming connections to a [`UnixListener`]. -/// -/// It will never return `None`. -/// -/// [`UnixListener`]: struct.UnixListener.html -/// -/// # Examples -/// -/// ```no_run -/// use std::thread; -/// use mio::windows::std::net::{UnixStream, UnixListener}; -/// -/// fn handle_client(stream: UnixStream) { -/// // ... -/// # drop(stream); // Silence unused variable warning. -/// } -/// -/// let listener = UnixListener::bind("/path/to/the/socket").unwrap(); -/// -/// for stream in listener.incoming() { -/// match stream { -/// Ok(stream) => { -/// thread::spawn(|| handle_client(stream)); -/// } -/// Err(err) => { -/// eprintln!("connection failed: {err}"); -/// break; -/// } -/// } -/// } -/// ``` -#[derive(Debug)] -pub struct Incoming<'a> { - listener: &'a UnixListener, -} +impl UnixListener { + pub fn bind>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; -impl<'a> Iterator for Incoming<'a> { - type Item = io::Result; + wsa_syscall!( + bind(inner.as_raw_socket() as _, &addr as *const _ as *const _, len as _), + SOCKET_ERROR + )?; + wsa_syscall!(listen(inner.as_raw_socket() as _, 128), SOCKET_ERROR)?; + Ok(UnixListener(inner)) + } - fn next(&mut self) -> Option> { - Some(self.listener.accept().map(|s| s.0)) + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + SocketAddr::init(|addr, len| self.0.accept(addr, len)) + .map(|(sock, addr)| (UnixStream(sock), addr)) } - fn size_hint(&self) -> (usize, Option) { - (usize::max_value(), None) + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) } } +} diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs index 9dbaf719f..62cf5cba1 100644 --- a/src/sys/windows/stdnet/mod.rs +++ b/src/sys/windows/stdnet/mod.rs @@ -1,13 +1,14 @@ -//! Windows specific networking functionality. Mirrors std::os::unix::net. - mod addr; mod listener; mod socket; mod stream; -pub use self::addr::*; -pub use self::listener::*; -pub use self::stream::*; +pub use self::addr::SocketAddr; +pub(crate) use self::listener::UnixListener; +pub(crate) use self::stream::UnixStream; + +cfg_os_poll! { +pub(self) use self::addr::socket_addr; use std::sync::Once; @@ -21,3 +22,4 @@ pub(crate) fn init() { drop(std::net::UdpSocket::bind("127.0.0.1:0")); }); } +} diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index ef97cea0c..deac3c031 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -3,77 +3,25 @@ use std::convert::TryInto; use std::io::{self, IoSlice, IoSliceMut}; use std::mem; use std::net::Shutdown; -use std::os::raw::{c_int, c_ulong}; +use std::os::raw::c_int; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::ptr; -use std::time::Duration; -use super::init; -use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; use windows_sys::Win32::Networking::WinSock::{ - self, closesocket, INVALID_SOCKET, SOCKADDR, SOCKET, SOCKET_ERROR, SOL_SOCKET, SO_ERROR, - WSABUF, WSAESHUTDOWN, + self, closesocket, SOCKET, SOCKET_ERROR, WSABUF, }; -use windows_sys::Win32::System::Threading::GetCurrentProcessId; -use windows_sys::Win32::System::WindowsProgramming::INFINITE; /// Maximum size of a buffer passed to system call like `recv` and `send`. const MAX_BUF_LEN: usize = c_int::MAX as usize; #[derive(Debug)] -pub struct Socket(SOCKET); +pub(crate) struct Socket(SOCKET); impl Socket { - pub fn new() -> io::Result { - init(); - let socket = wsa_syscall!( - WSASocketW( - WinSock::AF_UNIX.into(), - WinSock::SOCK_STREAM.into(), - 0, - ptr::null_mut(), - 0, - WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, - ), - PartialEq::eq, - INVALID_SOCKET - )?; - Ok(Socket(socket)) - } - - pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { - let socket = wsa_syscall!(accept(self.0, storage, len), PartialEq::eq, INVALID_SOCKET)?; - let socket = Socket(socket); - socket.set_no_inherit()?; - Ok(socket) - } - - pub fn duplicate(&self) -> io::Result { - let mut info: WinSock::WSAPROTOCOL_INFOW = unsafe { mem::zeroed() }; - wsa_syscall!( - WSADuplicateSocketW(self.0, GetCurrentProcessId(), &mut info,), - PartialEq::eq, - SOCKET_ERROR - )?; - let socket = wsa_syscall!( - WSASocketW( - info.iAddressFamily, - info.iSocketType, - info.iProtocol, - &info, - 0, - WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, - ), - PartialEq::eq, - INVALID_SOCKET - )?; - Ok(Socket(socket)) - } pub fn recv(&self, buf: &mut [u8]) -> io::Result { let ret = wsa_syscall!( recv(self.0, buf.as_mut_ptr() as *mut _, buf.len() as c_int, 0,), - PartialEq::eq, SOCKET_ERROR )?; Ok(ret as usize) @@ -93,12 +41,11 @@ impl Socket { ptr::null_mut(), None, ), - PartialEq::eq, SOCKET_ERROR ); match res { Ok(_) => Ok(total as usize), - Err(ref err) if err.raw_os_error() == Some(WSAESHUTDOWN as i32) => Ok(0), + Err(ref err) if err.raw_os_error() == Some(WinSock::WSAESHUTDOWN as i32) => Ok(0), Err(err) => Err(err), } } @@ -111,7 +58,6 @@ impl Socket { min(buf.len(), MAX_BUF_LEN) as c_int, 0, ), - PartialEq::eq, SOCKET_ERROR ) .map(|n| n as usize) @@ -140,31 +86,11 @@ impl Socket { std::ptr::null_mut(), None, ), - PartialEq::eq, SOCKET_ERROR ) .map(|_| total as usize) } - fn set_no_inherit(&self) -> io::Result<()> { - syscall!( - SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), - PartialEq::eq, - 0 - )?; - Ok(()) - } - - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - let mut nonblocking: c_ulong = if nonblocking { 1 } else { 0 }; - wsa_syscall!( - ioctlsocket(self.0, WinSock::FIONBIO, &mut nonblocking), - PartialEq::eq, - SOCKET_ERROR - )?; - Ok(()) - } - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { let how = match how { Shutdown::Write => WinSock::SD_SEND, @@ -173,115 +99,34 @@ impl Socket { }; wsa_syscall!( shutdown(self.0, how.try_into().unwrap()), - PartialEq::eq, SOCKET_ERROR )?; Ok(()) } pub fn take_error(&self) -> io::Result> { - let raw = getsockopt::( - self, - SOL_SOCKET.try_into().unwrap(), - SO_ERROR.try_into().unwrap(), + let mut val: mem::MaybeUninit = mem::MaybeUninit::uninit(); + let mut len = mem::size_of::() as i32; + wsa_syscall!( + getsockopt( + self.0 as _, + WinSock::SOL_SOCKET.try_into().unwrap(), + WinSock::SO_ERROR.try_into().unwrap(), + &mut val as *mut _ as *mut _, + &mut len, + ), + SOCKET_ERROR )?; - if raw == 0 { - Ok(None) - } else { - Ok(Some(io::Error::from_raw_os_error(raw as i32))) - } - } - - pub fn set_timeout(&self, dur: Option, kind: c_int) -> io::Result<()> { - let timeout = match dur { - Some(dur) => { - let timeout = dur2timeout(dur); - if timeout == 0 { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "cannot set a 0 duration timeout", - )); - } - timeout - } - None => 0, - }; - setsockopt(self, SOL_SOCKET.try_into().unwrap(), kind, timeout) - } - - pub fn timeout(&self, kind: c_int) -> io::Result> { - let raw: u32 = getsockopt(self, SOL_SOCKET.try_into().unwrap(), kind)?; - if raw == 0 { + assert_eq!(len as usize, mem::size_of::()); + let val = unsafe { val.assume_init() }; + if val == 0 { Ok(None) } else { - let secs = raw / 1000; - let nsec = (raw % 1000) * 1000000; - Ok(Some(Duration::new(secs as u64, nsec as u32))) + Ok(Some(io::Error::from_raw_os_error(val as i32))) } } } -fn setsockopt(sock: &Socket, opt: c_int, val: c_int, payload: T) -> io::Result<()> { - wsa_syscall!( - setsockopt( - sock.as_raw_socket() as usize, - opt, - val, - &payload as *const T as *const _, - mem::size_of::() as i32, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - Ok(()) -} - -fn getsockopt(sock: &Socket, opt: c_int, val: c_int) -> io::Result { - let mut slot: T = unsafe { mem::zeroed() }; - let mut len = mem::size_of::() as i32; - wsa_syscall!( - getsockopt( - sock.as_raw_socket() as _, - opt, - val, - &mut slot as *mut _ as *mut _, - &mut len, - ), - PartialEq::eq, - SOCKET_ERROR - )?; - assert_eq!(len as usize, mem::size_of::()); - Ok(slot) -} - -fn dur2timeout(dur: Duration) -> u32 { - // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the - // timeouts in windows APIs are typically u32 milliseconds. To translate, we - // have two pieces to take care of: - // - // * Nanosecond precision is rounded up - // * Greater than u32::MAX milliseconds (50 days) is rounded up to INFINITE - // (never time out). - dur.as_secs() - .checked_mul(1000) - .and_then(|ms| ms.checked_add((dur.subsec_nanos() as u64) / 1_000_000)) - .and_then(|ms| { - ms.checked_add(if dur.subsec_nanos() % 1_000_000 > 0 { - 1 - } else { - 0 - }) - }) - .map(|ms| { - if ms > ::max_value() as u64 { - INFINITE - } else { - ms as u32 - } - }) - .unwrap_or(INFINITE) -} - impl Drop for Socket { fn drop(&mut self) { let _ = unsafe { closesocket(self.0) }; @@ -307,3 +152,52 @@ impl IntoRawSocket for Socket { ret } } + +cfg_os_poll! { +use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; +use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; +use super::init; + +impl Socket { + pub fn new() -> io::Result { + init(); + let socket = wsa_syscall!( + WSASocketW( + WinSock::AF_UNIX.into(), + WinSock::SOCK_STREAM.into(), + 0, + ptr::null_mut(), + 0, + WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, + ), + INVALID_SOCKET + )?; + Ok(Socket(socket)) + } + + pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { + let socket = wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET)?; + let socket = Socket(socket); + socket.set_no_inherit()?; + Ok(socket) + } + + fn set_no_inherit(&self) -> io::Result<()> { + syscall!( + SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), + PartialEq::eq, + 0 + )?; + Ok(()) + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = if nonblocking { 1 } else { 0 }; + wsa_syscall!( + ioctlsocket(self.0, WinSock::FIONBIO, &mut nonblocking), + SOCKET_ERROR + )?; + Ok(()) + } +} +} diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 9958d6711..ed8b3314d 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -1,36 +1,13 @@ -use std::convert::TryInto; use std::io::{self, IoSlice, IoSliceMut}; use std::net::Shutdown; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; -use std::path::{Path, PathBuf}; -use std::time::Duration; use std::{fmt, mem}; -use windows_sys::Win32::Foundation::STATUS_SUCCESS; -use windows_sys::Win32::Networking::WinSock::{ - SOCKET_ERROR, SO_RCVTIMEO, SO_SNDTIMEO, WSAEINPROGRESS, -}; -use windows_sys::Win32::Security::Cryptography::{ - BCryptGenRandom, BCRYPT_USE_SYSTEM_PREFERRED_RNG, -}; +use windows_sys::Win32::Networking::WinSock::SOCKET_ERROR; -use super::{socket::Socket, socket_addr, SocketAddr, UnixListener}; +use super::{socket::Socket, SocketAddr}; -/// A Unix stream socket -/// -/// # Examples -/// -/// ```no_run -/// use mio::windows::std::net::UnixStream; -/// use std::io::prelude::*; -/// -/// let mut stream = UnixStream::connect("/path/to/my/socket").unwrap(); -/// stream.write_all(b"hello world").unwrap(); -/// let mut response = String::new(); -/// stream.read_to_string(&mut response).unwrap(); -/// println!("{}", response); -/// ``` -pub struct UnixStream(pub(super) Socket); +pub(crate) struct UnixStream(pub(super) Socket); impl fmt::Debug for UnixStream { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -47,258 +24,31 @@ impl fmt::Debug for UnixStream { } impl UnixStream { - /// Connects to the socket named by `path`. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = match UnixStream::connect("/tmp/sock") { - /// Ok(sock) => sock, - /// Err(e) => { - /// println!("Couldn't connect: {:?}", e); - /// return - /// } - /// }; - /// # drop(socket); // Silence unused variable warning. - /// ``` - pub fn connect>(path: P) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path.as_ref())?; - - match wsa_syscall!( - connect( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as i32, - ), - PartialEq::eq, - SOCKET_ERROR - ) { - Ok(_) => {} - Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} - Err(e) => return Err(e), - } - Ok(UnixStream(inner)) - } - - /// Creates a new independently owned handle to the underlying socket. - /// - /// The returned `UnixStream` is a reference to the same stream that this - /// object references. Both handles will read and write the same stream of - /// data, and options set on one stream will be propagated to the other - /// stream. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// let sock_copy = socket.try_clone().expect("Couldn't clone socket"); - /// # drop(sock_copy); // Silence unused variable warning. - /// ``` - pub fn try_clone(&self) -> io::Result { - self.0.duplicate().map(UnixStream) - } - - /// Returns the socket address of the local half of this connection. - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// let addr = socket.local_addr().expect("Couldn't get local address"); - /// # drop(addr); // Silence unused variable warning. - /// ``` pub fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getsockname(self.0.as_raw_socket() as _, addr, len), - PartialEq::eq, SOCKET_ERROR ) }) } - /// Returns the socket address of the remote half of this connection. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// let addr = socket.peer_addr().expect("Couldn't get peer address"); - /// # drop(addr); // Silence unused variable warning. - /// ``` pub fn peer_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getpeername(self.0.as_raw_socket() as _, addr, len), - PartialEq::eq, SOCKET_ERROR ) }) } - /// Moves the socket into or out of nonblocking mode. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_nonblocking(true).expect("Couldn't set nonblocking"); - /// ``` - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) - } - - /// Returns the value of the `SO_ERROR` option. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// if let Ok(Some(err)) = socket.take_error() { - /// println!("Got error: {:?}", err); - /// } - /// ``` pub fn take_error(&self) -> io::Result> { self.0.take_error() } - /// Shuts down the read, write, or both halves of this connection. - /// - /// This function will cause all pending and future I/O calls on the - /// specified portions to immediately return with an appropriate value - /// (see the documentation for `Shutdown`). - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// use std::net::Shutdown; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.shutdown(Shutdown::Both).expect("shutdown function failed"); - /// ``` pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.0.shutdown(how) } - - /// Creates an unnamed pair of connected sockets. - /// - /// Returns two `UnixStream`s which are connected to each other. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let (sock1, sock2) = match UnixStream::pair() { - /// Ok((sock1, sock2)) => (sock1, sock2), - /// Err(e) => { - /// println!("Couldn't create a pair of sockets: {e:?}"); - /// return - /// } - /// }; - /// # drop(sock1); // Silence unused variable warning. - /// # drop(sock2); // Silence unused variable warning. - /// ``` - pub fn pair() -> io::Result<(Self, Self)> { - use std::sync::{Arc, RwLock}; - use std::thread::spawn; - - let file_path = TempPath::new(10)?; - let a: Arc>>> = Arc::new(RwLock::new(None)); - let ul = UnixListener::bind(&file_path).unwrap(); - let server = { - let a = a.clone(); - spawn(move || { - let mut store = a.write().unwrap(); - let stream0 = ul.accept().map(|s| s.0); - *store = Some(stream0); - }) - }; - let stream1 = UnixStream::connect(&file_path)?; - server - .join() - .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; - let stream0 = (*(a.write().unwrap())).take().unwrap()?; - Ok((stream0, stream1)) - } - - /// Sets the read timeout to the timeout specified. - /// - /// If the value specified is `None`, then `read` calls will block - /// indefinitely. An `Err` is returned if the zero `Duration` is - /// passed to this method. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); - /// ``` - pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_timeout(dur, SO_RCVTIMEO.try_into().unwrap()) - } - - /// Sets the write timeout to the timeout specified. - /// - /// If the value specified is `None`, then `write` calls will block - /// indefinitely. An `Err` is returned if the zero `Duration` is - /// passed to this method. - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); - /// ``` - pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.0.set_timeout(dur, SO_SNDTIMEO.try_into().unwrap()) - } - - /// Returns the read timeout of this socket. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_read_timeout(None).expect("Couldn't set read timeout"); - /// assert_eq!(socket.read_timeout().unwrap(), None); - /// ``` - pub fn read_timeout(&self) -> io::Result> { - self.0.timeout(SO_RCVTIMEO.try_into().unwrap()) - } - - /// Returns the write timeout of this socket. - /// - /// # Examples - /// - /// ```no_run - /// use mio::windows::std::net::UnixStream; - /// - /// let socket = UnixStream::connect("/tmp/sock").unwrap(); - /// socket.set_write_timeout(None).expect("Couldn't set write timeout"); - /// assert_eq!(socket.write_timeout().unwrap(), None); - /// ``` - pub fn write_timeout(&self) -> io::Result> { - self.0.timeout(SO_SNDTIMEO.try_into().unwrap()) - } } impl io::Read for UnixStream { @@ -369,7 +119,64 @@ impl IntoRawSocket for UnixStream { } } -struct TempPath(PathBuf); +cfg_os_poll! { +use std::path::{Path, PathBuf}; +use windows_sys::Win32::Foundation::STATUS_SUCCESS; +use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; +use windows_sys::Win32::Security::Cryptography::{ + BCryptGenRandom, BCRYPT_USE_SYSTEM_PREFERRED_RNG, +}; + +use super::{socket_addr, UnixListener}; + +impl UnixStream { + pub fn connect>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + SOCKET_ERROR + ) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(e) => return Err(e), + } + Ok(UnixStream(inner)) + } + + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + pub fn pair() -> io::Result<(Self, Self)> { + use std::sync::{Arc, RwLock}; + use std::thread::spawn; + + let file_path = temp_path(10)?; + let a: Arc>>> = Arc::new(RwLock::new(None)); + let ul = UnixListener::bind(&file_path).unwrap(); + let server = { + let a = a.clone(); + spawn(move || { + let mut store = a.write().unwrap(); + let stream0 = ul.accept().map(|s| s.0); + *store = Some(stream0); + }) + }; + let stream1 = UnixStream::connect(&file_path)?; + server + .join() + .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; + let stream0 = (*(a.write().unwrap())).take().unwrap()?; + let _ = std::fs::remove_file(&file_path); + Ok((stream0, stream1)) + } +} fn sample_ascii_string(len: usize) -> io::Result { const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ @@ -399,41 +206,21 @@ fn sample_ascii_string(len: usize) -> io::Result { Ok(result) } -impl TempPath { - fn new(random_len: usize) -> io::Result { - let dir = std::env::temp_dir(); - // Retry a few times in case of collisions - for _ in 0..10 { - let rand_str = sample_ascii_string(random_len)?; - let filename = format!(".tmp-{rand_str}.socket"); - let path = dir.join(filename); - if !path.exists() { - return Ok(Self(path)); - } +fn temp_path(len: usize) -> io::Result { + let dir = std::env::temp_dir(); + // Retry a few times in case of collisions + for _ in 0..10 { + let rand_str = sample_ascii_string(len)?; + let filename = format!(".tmp-{rand_str}.socket"); + let path = dir.join(filename); + if !path.exists() { + return Ok(path); } - - Err(io::Error::new( - io::ErrorKind::AlreadyExists, - "too many temporary files exist", - )) } -} -impl Drop for TempPath { - fn drop(&mut self) { - let _ = std::fs::remove_file(&self.0); - } + Err(io::Error::new( + io::ErrorKind::AlreadyExists, + "too many temporary files exist", + )) } - -impl AsRef for TempPath { - fn as_ref(&self) -> &Path { - &self.0 - } -} - -impl std::ops::Deref for TempPath { - type Target = Path; - fn deref(&self) -> &Path { - Path::new(&self.0) - } } diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs index d0b56eb9d..13569e104 100644 --- a/src/sys/windows/uds/mod.rs +++ b/src/sys/windows/uds/mod.rs @@ -13,7 +13,6 @@ cfg_os_poll! { SocketAddr::new(|sockaddr, socklen| { wsa_syscall!( getsockname(socket.try_into().unwrap(), sockaddr, socklen), - PartialEq::eq, SOCKET_ERROR ) }) @@ -23,7 +22,6 @@ cfg_os_poll! { SocketAddr::new(|sockaddr, socklen| { wsa_syscall!( getpeername(socket.try_into().unwrap(), sockaddr, socklen), - PartialEq::eq, SOCKET_ERROR ) }) diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 2874845c7..60b8bd4f7 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -2,7 +2,7 @@ use mio::net::UnixListener; #[cfg(windows)] -use mio::windows::std::net; +use mio::net; use mio::{Interest, Token}; use std::io::{self, Read}; #[cfg(unix)] @@ -33,6 +33,7 @@ fn unix_listener_smoke() { smoke_test(|path| UnixListener::bind(path), "unix_listener_smoke"); } +#[cfg(unix)] #[test] fn unix_listener_from_std() { smoke_test( diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 3930d608a..408d9d064 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -1,8 +1,8 @@ -#![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] +#![cfg(all(feature = "os-poll", feature = "net"))] -use mio::net::UnixStream; #[cfg(windows)] -use mio::windows::std::net; +use mio::net; +use mio::net::UnixStream; use mio::{Interest, Token}; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; @@ -12,6 +12,8 @@ use std::path::Path; use std::sync::mpsc::channel; use std::sync::{Arc, Barrier}; use std::thread; +#[cfg(windows)] +use std::time::Duration; #[macro_use] mod util; @@ -80,6 +82,7 @@ fn unix_stream_connect() { handle.join().unwrap(); } +#[cfg(unix)] #[test] fn unix_stream_from_std() { smoke_test( @@ -488,12 +491,24 @@ fn new_echo_listener( let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); - let listener = net::UnixListener::bind(path).unwrap(); + // We use mio's non-blocking listener here for windows, since there is no listener in std + // yet. We must be sure to poll before listener I/O. + let mut listener = net::UnixListener::bind(path).unwrap(); + #[cfg(windows)] + let (mut poll, mut events) = init_with_poll(); + #[cfg(windows)] + poll.registry() + .register(&mut listener, TOKEN_1, Interest::READABLE) + .unwrap(); let local_addr = listener.local_addr().unwrap(); addr_sender.send(local_addr).unwrap(); for _ in 0..connections { + #[cfg(windows)] + poll.poll(&mut events, Some(Duration::from_millis(500))).unwrap(); let (mut stream, _) = listener.accept().unwrap(); + #[cfg(windows)] + assert_would_block(listener.accept()); // On Linux based system it will cause a connection reset // error when the reading side of the peer connection is @@ -534,12 +549,24 @@ fn new_noop_listener( let (sender, receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); - let listener = net::UnixListener::bind(path).unwrap(); + // We use mio's non-blocking listener here for windows, since there is no listener in std + // yet. We must be sure to poll before listener I/O. + let mut listener = net::UnixListener::bind(path).unwrap(); + #[cfg(windows)] + let (mut poll, mut events) = init_with_poll(); + #[cfg(windows)] + poll.registry() + .register(&mut listener, TOKEN_1, Interest::READABLE) + .unwrap(); let local_addr = listener.local_addr().unwrap(); sender.send(local_addr).unwrap(); for _ in 0..connections { + #[cfg(windows)] + poll.poll(&mut events, Some(Duration::from_millis(500))).unwrap(); let (stream, _) = listener.accept().unwrap(); + #[cfg(windows)] + assert_would_block(listener.accept()); barrier.wait(); stream.shutdown(Shutdown::Write).unwrap(); barrier.wait(); From b07b4f1c8f2c5afa11a3e5729292ca5c8a26fd8f Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Fri, 26 Aug 2022 00:57:40 -0700 Subject: [PATCH 24/34] fmt --- src/net/uds/stream.rs | 2 +- src/sys/windows/stdnet/socket.rs | 10 ++-------- tests/unix_listener.rs | 2 +- tests/unix_stream.rs | 6 ++++-- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 375cb4d2d..fe2f16121 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -46,7 +46,7 @@ impl UnixStream { inner: IoSource::new(stream), } } - + #[cfg(windows)] pub(crate) fn from_std(stream: net::UnixStream) -> UnixStream { UnixStream { diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index deac3c031..de6501e4b 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -7,9 +7,7 @@ use std::os::raw::c_int; use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use std::ptr; -use windows_sys::Win32::Networking::WinSock::{ - self, closesocket, SOCKET, SOCKET_ERROR, WSABUF, -}; +use windows_sys::Win32::Networking::WinSock::{self, closesocket, SOCKET, SOCKET_ERROR, WSABUF}; /// Maximum size of a buffer passed to system call like `recv` and `send`. const MAX_BUF_LEN: usize = c_int::MAX as usize; @@ -18,7 +16,6 @@ const MAX_BUF_LEN: usize = c_int::MAX as usize; pub(crate) struct Socket(SOCKET); impl Socket { - pub fn recv(&self, buf: &mut [u8]) -> io::Result { let ret = wsa_syscall!( recv(self.0, buf.as_mut_ptr() as *mut _, buf.len() as c_int, 0,), @@ -97,10 +94,7 @@ impl Socket { Shutdown::Read => WinSock::SD_RECEIVE, Shutdown::Both => WinSock::SD_BOTH, }; - wsa_syscall!( - shutdown(self.0, how.try_into().unwrap()), - SOCKET_ERROR - )?; + wsa_syscall!(shutdown(self.0, how.try_into().unwrap()), SOCKET_ERROR)?; Ok(()) } diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 60b8bd4f7..30fa926b3 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,8 +1,8 @@ #![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] -use mio::net::UnixListener; #[cfg(windows)] use mio::net; +use mio::net::UnixListener; use mio::{Interest, Token}; use std::io::{self, Read}; #[cfg(unix)] diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 408d9d064..d48006a74 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -505,7 +505,8 @@ fn new_echo_listener( for _ in 0..connections { #[cfg(windows)] - poll.poll(&mut events, Some(Duration::from_millis(500))).unwrap(); + poll.poll(&mut events, Some(Duration::from_millis(500))) + .unwrap(); let (mut stream, _) = listener.accept().unwrap(); #[cfg(windows)] assert_would_block(listener.accept()); @@ -563,7 +564,8 @@ fn new_noop_listener( for _ in 0..connections { #[cfg(windows)] - poll.poll(&mut events, Some(Duration::from_millis(500))).unwrap(); + poll.poll(&mut events, Some(Duration::from_millis(500))) + .unwrap(); let (stream, _) = listener.accept().unwrap(); #[cfg(windows)] assert_would_block(listener.accept()); From a2831eab34f8265f0dcb3f4f42226b858052de32 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Fri, 26 Aug 2022 01:28:46 -0700 Subject: [PATCH 25/34] simplify windows mod --- src/net/uds/listener.rs | 2 +- src/net/uds/stream.rs | 2 +- src/sys/shell/uds.rs | 4 +- src/sys/windows/mod.rs | 260 +++++++++++++++----------------- src/sys/windows/stdnet/addr.rs | 1 - src/sys/windows/tcp.rs | 2 +- src/sys/windows/udp.rs | 2 +- src/sys/windows/uds/listener.rs | 2 +- src/sys/windows/uds/stream.rs | 2 +- 9 files changed, 133 insertions(+), 144 deletions(-) diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 928a85dfc..0265048aa 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -3,7 +3,7 @@ use crate::net::{SocketAddr, UnixStream}; use crate::{event, sys, Interest, Registry, Token}; #[cfg(windows)] -use crate::sys::windows::std::net; +use crate::sys::windows::stdnet as net; #[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; #[cfg(unix)] diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index fe2f16121..9c73dafa4 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -2,7 +2,7 @@ use crate::io_source::IoSource; use crate::{event, sys, Interest, Registry, Token}; #[cfg(windows)] -use crate::sys::windows::std::net; +use crate::sys::windows::stdnet as net; use std::fmt; use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::Shutdown; diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index caa23b9cb..48e568b28 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -36,7 +36,7 @@ pub(crate) mod datagram { pub(crate) mod listener { use crate::net::{SocketAddr, UnixStream}; #[cfg(windows)] - use crate::sys::windows::std::net; + use crate::sys::windows::stdnet as net; use std::io; #[cfg(unix)] use std::os::unix::net; @@ -58,7 +58,7 @@ pub(crate) mod listener { pub(crate) mod stream { use crate::net::SocketAddr; #[cfg(windows)] - use crate::sys::windows::std::net; + use crate::sys::windows::stdnet as net; use std::io; #[cfg(unix)] use std::os::unix::net; diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 09bb1c928..940522c7c 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,18 +1,3 @@ -/// Helper macro to execute a system call that returns an `io::Result`. -// -// Macro must be defined before any modules that uses them. -#[allow(unused_macros)] -macro_rules! syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { $fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } - }}; -} - /// Helper macro to execute a WinSock system call that returns an `io::Result`. #[allow(unused_macros)] macro_rules! wsa_syscall { @@ -28,163 +13,168 @@ macro_rules! wsa_syscall { }}; } -cfg_net! { - mod stdnet; - - pub(crate) mod std { - //! Windows only std lib modules that cannot be upstreamed. - pub(crate) mod net { - //! Internal Windows std net implementation. - pub(crate) use crate::sys::windows::stdnet::*; - } - } -} - cfg_os_poll! { - mod afd; +mod afd; - pub mod event; - pub use event::{Event, Events}; +pub mod event; +pub use event::{Event, Events}; - mod handle; - use handle::Handle; +mod handle; +use handle::Handle; - mod io_status_block; - mod iocp; +mod io_status_block; +mod iocp; - mod overlapped; - use overlapped::Overlapped; +mod overlapped; +use overlapped::Overlapped; - mod selector; - pub use selector::{Selector, SelectorInner, SockState}; +mod selector; +pub use selector::{Selector, SelectorInner, SockState}; - // Macros must be defined before the modules that use them - cfg_net! { - mod net; - pub(crate) mod tcp; - pub(crate) mod udp; - pub(crate) mod uds; - pub use self::uds::SocketAddr; +// Macros must be defined before the modules that use them +cfg_net! { + /// Helper macro to execute a system call that returns an `io::Result`. + // + // Macro must be defined before any modules that uses them. + macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { $fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } + }}; } - cfg_os_ext! { - pub(crate) mod named_pipe; - } + mod net; - mod waker; - pub(crate) use waker::Waker; + pub(crate) mod stdnet; + pub(crate) mod tcp; + pub(crate) mod udp; + pub(crate) mod uds; + pub use self::uds::SocketAddr; +} - cfg_io_source! { - use ::std::io; - use ::std::os::windows::io::RawSocket; - use ::std::pin::Pin; - use ::std::sync::{Arc, Mutex}; +cfg_os_ext! { + pub(crate) mod named_pipe; +} - use crate::{Interest, Registry, Token}; +mod waker; +pub(crate) use waker::Waker; - struct InternalState { - selector: Arc, - token: Token, - interests: Interest, - sock_state: Pin>>, - } +cfg_io_source! { + use std::io; + use std::os::windows::io::RawSocket; + use std::pin::Pin; + use std::sync::{Arc, Mutex}; - impl Drop for InternalState { - fn drop(&mut self) { - let mut sock_state = self.sock_state.lock().unwrap(); - sock_state.mark_delete(); - } + use crate::{Interest, Registry, Token}; + + struct InternalState { + selector: Arc, + token: Token, + interests: Interest, + sock_state: Pin>>, + } + + impl Drop for InternalState { + fn drop(&mut self) { + let mut sock_state = self.sock_state.lock().unwrap(); + sock_state.mark_delete(); } + } + + pub struct IoSourceState { + // This is `None` if the socket has not yet been registered. + // + // We box the internal state to not increase the size on the stack as the + // type might move around a lot. + inner: Option>, + } - pub struct IoSourceState { - // This is `None` if the socket has not yet been registered. - // - // We box the internal state to not increase the size on the stack as the - // type might move around a lot. - inner: Option>, + impl IoSourceState { + pub fn new() -> IoSourceState { + IoSourceState { inner: None } } - impl IoSourceState { - pub fn new() -> IoSourceState { - IoSourceState { inner: None } + pub fn do_io(&self, f: F, io: &T) -> io::Result + where + F: FnOnce(&T) -> io::Result, + { + let result = f(io); + if let Err(ref e) = result { + if e.kind() == io::ErrorKind::WouldBlock { + self.inner.as_ref().map_or(Ok(()), |state| { + state + .selector + .reregister(state.sock_state.clone(), state.token, state.interests) + })?; + } } + result + } - pub fn do_io(&self, f: F, io: &T) -> io::Result - where - F: FnOnce(&T) -> io::Result, - { - let result = f(io); - if let Err(ref e) = result { - if e.kind() == io::ErrorKind::WouldBlock { - self.inner.as_ref().map_or(Ok(()), |state| { - state - .selector - .reregister(state.sock_state.clone(), state.token, state.interests) - })?; - } - } - result + pub fn register( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + socket: RawSocket, + ) -> io::Result<()> { + if self.inner.is_some() { + Err(io::ErrorKind::AlreadyExists.into()) + } else { + registry + .selector() + .register(socket, token, interests) + .map(|state| { + self.inner = Some(Box::new(state)); + }) } + } - pub fn register( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - socket: RawSocket, - ) -> io::Result<()> { - if self.inner.is_some() { - Err(io::ErrorKind::AlreadyExists.into()) - } else { + pub fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + ) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { registry .selector() - .register(socket, token, interests) - .map(|state| { - self.inner = Some(Box::new(state)); + .reregister(state.sock_state.clone(), token, interests) + .map(|()| { + state.token = token; + state.interests = interests; }) } + None => Err(io::ErrorKind::NotFound.into()), } + } - pub fn reregister( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - ) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { - registry - .selector() - .reregister(state.sock_state.clone(), token, interests) - .map(|()| { - state.token = token; - state.interests = interests; - }) - } - None => Err(io::ErrorKind::NotFound.into()), - } - } - - pub fn deregister(&mut self) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { - { - let mut sock_state = state.sock_state.lock().unwrap(); - sock_state.mark_delete(); - } - self.inner = None; - Ok(()) + pub fn deregister(&mut self) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + { + let mut sock_state = state.sock_state.lock().unwrap(); + sock_state.mark_delete(); } - None => Err(io::ErrorKind::NotFound.into()), + self.inner = None; + Ok(()) } + None => Err(io::ErrorKind::NotFound.into()), } } } } +} cfg_not_os_poll! { cfg_net! { - mod uds; + pub(crate) mod stdnet; + pub(crate) mod uds; pub use self::uds::SocketAddr; } } diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index 8f52562f3..1ec24cb4c 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -16,7 +16,6 @@ fn path_offset(addr: &sockaddr_un) -> usize { cfg_os_poll! { use windows_sys::Win32::Networking::WinSock::AF_UNIX; - pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { let sockaddr = mem::MaybeUninit::::zeroed(); diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index a2f123e36..af0e25106 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -7,7 +7,7 @@ use windows_sys::Win32::Networking::WinSock::{ }; use crate::sys::windows::net::{new_socket, socket_addr}; -use crate::sys::windows::std::net::init; +use crate::sys::windows::stdnet::init; pub(crate) fn new_for_addr(address: SocketAddr) -> io::Result { init(); diff --git a/src/sys/windows/udp.rs b/src/sys/windows/udp.rs index e07b967c4..213f2d329 100644 --- a/src/sys/windows/udp.rs +++ b/src/sys/windows/udp.rs @@ -5,7 +5,7 @@ use std::os::windows::io::{AsRawSocket, FromRawSocket}; use std::os::windows::raw::SOCKET as StdSocket; // windows-sys uses usize, stdlib uses u32/u64. use crate::sys::windows::net::{new_ip_socket, socket_addr}; -use crate::sys::windows::std::net::init; +use crate::sys::windows::stdnet::init; use windows_sys::Win32::Networking::WinSock::{ bind as win_bind, closesocket, getsockopt, IPPROTO_IPV6, IPV6_V6ONLY, SOCKET_ERROR, SOCK_DGRAM, }; diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index 4cc393ffc..d6a7a25cb 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -3,7 +3,7 @@ use std::os::windows::io::AsRawSocket; use std::path::Path; use crate::net::{SocketAddr, UnixStream}; -use crate::sys::windows::std::net; +use crate::sys::windows::stdnet as net; pub(crate) fn bind(path: &Path) -> io::Result { let listener = net::UnixListener::bind(path)?; diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index c59a1f95c..45e4dd720 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -1,4 +1,4 @@ -use crate::sys::windows::std::net; +use crate::sys::windows::stdnet as net; use std::io; use std::os::windows::io::AsRawSocket; use std::path::Path; From 73d5faeb901b41735a1fbb09a69cc842ae76eadb Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Fri, 26 Aug 2022 02:10:17 -0700 Subject: [PATCH 26/34] clean up tests --- tests/unix_listener.rs | 54 +-------------------------- tests/unix_stream.rs | 85 ++++++++++++++++++++++++++---------------- 2 files changed, 53 insertions(+), 86 deletions(-) diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 30fa926b3..90a316599 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -139,7 +139,7 @@ fn unix_listener_deregister() { #[cfg(target_os = "linux")] #[test] -fn unix_listener_abstract_namespace() { +fn unix_listener_abstract_namesapce() { use rand::Rng; let num: u64 = rand::thread_rng().gen(); let name = format!("\u{0000}-mio-abstract-uds-{}", num); @@ -191,58 +191,6 @@ where handle.join().unwrap(); } -#[test] -fn unix_listener_multiple_accepts() { - let (mut poll, mut events) = init_with_poll(); - let barrier = Arc::new(Barrier::new(2)); - let path = temp_file("unix_listener_multiple_accepts"); - let mut buf = [0; DEFAULT_BUF_SIZE]; - - let mut listener = UnixListener::bind(&path).unwrap(); - - assert_socket_non_blocking(&listener); - assert_socket_close_on_exec(&listener); - - poll.registry() - .register( - &mut listener, - TOKEN_1, - Interest::WRITABLE.add(Interest::READABLE), - ) - .unwrap(); - expect_no_events(&mut poll, &mut events); - - let handle = open_connections(path, 2, barrier.clone()); - - // First connection is opened, try to accept and read. - expect_events( - &mut poll, - &mut events, - vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], - ); - - let (mut stream1, _) = listener.accept().unwrap(); - assert_would_block(stream1.read(&mut buf)); - assert_would_block(listener.accept()); - barrier.wait(); - - // Second connection is opened, try to accept and read. - expect_events( - &mut poll, - &mut events, - vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)], - ); - - let (mut stream1, _) = listener.accept().unwrap(); - assert_would_block(stream1.read(&mut buf)); - barrier.wait(); - - // We don't expect any more connections. - assert_would_block(listener.accept()); - assert!(listener.take_error().unwrap().is_none()); - handle.join().unwrap(); -} - fn open_connections( path: PathBuf, n_connections: usize, diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index d48006a74..46c6a935f 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -484,33 +484,72 @@ where handle.join().unwrap(); } -fn new_echo_listener( +#[cfg(windows)] +fn new_listener( connections: usize, test_name: &'static str, -) -> (thread::JoinHandle<()>, net::SocketAddr) { + handle_stream: F +) -> (thread::JoinHandle<()>, net::SocketAddr) +where + F: Fn(&net::UnixStream) + std::marker::Send + 'static +{ let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { let path = temp_file(test_name); // We use mio's non-blocking listener here for windows, since there is no listener in std // yet. We must be sure to poll before listener I/O. let mut listener = net::UnixListener::bind(path).unwrap(); - #[cfg(windows)] let (mut poll, mut events) = init_with_poll(); - #[cfg(windows)] poll.registry() .register(&mut listener, TOKEN_1, Interest::READABLE) .unwrap(); + let local_addr = listener.local_addr().unwrap(); addr_sender.send(local_addr).unwrap(); for _ in 0..connections { - #[cfg(windows)] poll.poll(&mut events, Some(Duration::from_millis(500))) .unwrap(); - let (mut stream, _) = listener.accept().unwrap(); - #[cfg(windows)] + let (stream, _) = listener.accept().unwrap(); assert_would_block(listener.accept()); + handle_stream(&stream); + } + }); + (handle, addr_receiver.recv().unwrap()) +} + +#[cfg(unix)] +fn new_listener( + connections: usize, + test_name: &'static str, + handle_stream: F +) -> (thread::JoinHandle<()>, net::SocketAddr) +where + F: Fn(&net::UnixStream) + std::marker::Send + 'static +{ + let (addr_sender, addr_receiver) = channel(); + let handle = thread::spawn(move || { + let path = temp_file(test_name); + let listener = net::UnixListener::bind(path).unwrap(); + let local_addr = listener.local_addr().unwrap(); + addr_sender.send(local_addr).unwrap(); + + for _ in 0..connections { + let (stream, _) = listener.accept().unwrap(); + handle_stream(stream); + } + }); + (handle, addr_receiver.recv().unwrap()) +} +fn new_echo_listener( + connections: usize, + test_name: &'static str, +) -> (thread::JoinHandle<()>, net::SocketAddr) { + new_listener( + connections, + test_name, + |mut stream| { // On Linux based system it will cause a connection reset // error when the reading side of the peer connection is // shutdown, we don't consider it an actual here. @@ -538,8 +577,7 @@ fn new_echo_listener( } assert_eq!(read, written, "unequal reads and writes"); } - }); - (handle, addr_receiver.recv().unwrap()) + ) } fn new_noop_listener( @@ -547,33 +585,14 @@ fn new_noop_listener( barrier: Arc, test_name: &'static str, ) -> (thread::JoinHandle<()>, net::SocketAddr) { - let (sender, receiver) = channel(); - let handle = thread::spawn(move || { - let path = temp_file(test_name); - // We use mio's non-blocking listener here for windows, since there is no listener in std - // yet. We must be sure to poll before listener I/O. - let mut listener = net::UnixListener::bind(path).unwrap(); - #[cfg(windows)] - let (mut poll, mut events) = init_with_poll(); - #[cfg(windows)] - poll.registry() - .register(&mut listener, TOKEN_1, Interest::READABLE) - .unwrap(); - let local_addr = listener.local_addr().unwrap(); - sender.send(local_addr).unwrap(); - - for _ in 0..connections { - #[cfg(windows)] - poll.poll(&mut events, Some(Duration::from_millis(500))) - .unwrap(); - let (stream, _) = listener.accept().unwrap(); - #[cfg(windows)] - assert_would_block(listener.accept()); + new_listener( + connections, + test_name, + move |stream| { barrier.wait(); stream.shutdown(Shutdown::Write).unwrap(); barrier.wait(); drop(stream); } - }); - (handle, receiver.recv().unwrap()) + ) } From 0baf11238401e82b8e6935f4ca325daa7ab8c21d Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Fri, 26 Aug 2022 18:21:28 +0200 Subject: [PATCH 27/34] fix indentation, imports, address other comments --- src/sys/windows/mod.rs | 252 ++++++++++++++--------------- src/sys/windows/stdnet/addr.rs | 96 +++++------ src/sys/windows/stdnet/listener.rs | 40 ++--- src/sys/windows/stdnet/mod.rs | 25 +-- src/sys/windows/stdnet/socket.rs | 84 +++++----- src/sys/windows/stdnet/stream.rs | 190 +++++++++++----------- tests/close_on_drop.rs | 2 +- tests/unix_listener.rs | 2 +- tests/unix_stream.rs | 6 +- 9 files changed, 349 insertions(+), 348 deletions(-) diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 940522c7c..0cb58bc64 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -4,7 +4,7 @@ macro_rules! wsa_syscall { ($fn: ident ( $($arg: expr),* $(,)* ), $err_value: expr) => {{ let res = unsafe { windows_sys::Win32::Networking::WinSock::$fn($($arg, )*) }; if PartialEq::eq(&res, &$err_value) { - Err(io::Error::from_raw_os_error(unsafe { + Err(std::io::Error::from_raw_os_error(unsafe { windows_sys::Win32::Networking::WinSock::WSAGetLastError() })) } else { @@ -14,162 +14,162 @@ macro_rules! wsa_syscall { } cfg_os_poll! { -mod afd; - -pub mod event; -pub use event::{Event, Events}; - -mod handle; -use handle::Handle; - -mod io_status_block; -mod iocp; - -mod overlapped; -use overlapped::Overlapped; - -mod selector; -pub use selector::{Selector, SelectorInner, SockState}; - -// Macros must be defined before the modules that use them -cfg_net! { - /// Helper macro to execute a system call that returns an `io::Result`. - // - // Macro must be defined before any modules that uses them. - macro_rules! syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { $fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } - }}; - } + mod afd; - mod net; + pub mod event; + pub use event::{Event, Events}; - pub(crate) mod stdnet; - pub(crate) mod tcp; - pub(crate) mod udp; - pub(crate) mod uds; - pub use self::uds::SocketAddr; -} + mod handle; + use handle::Handle; -cfg_os_ext! { - pub(crate) mod named_pipe; -} + mod io_status_block; + mod iocp; -mod waker; -pub(crate) use waker::Waker; + mod overlapped; + use overlapped::Overlapped; -cfg_io_source! { - use std::io; - use std::os::windows::io::RawSocket; - use std::pin::Pin; - use std::sync::{Arc, Mutex}; + mod selector; + pub use selector::{Selector, SelectorInner, SockState}; - use crate::{Interest, Registry, Token}; + // Macros must be defined before the modules that use them + cfg_net! { + /// Helper macro to execute a system call that returns an `io::Result`. + // + // Macro must be defined before any modules that uses them. + macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { $fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } + }}; + } - struct InternalState { - selector: Arc, - token: Token, - interests: Interest, - sock_state: Pin>>, - } + mod net; - impl Drop for InternalState { - fn drop(&mut self) { - let mut sock_state = self.sock_state.lock().unwrap(); - sock_state.mark_delete(); - } + pub(crate) mod stdnet; + pub(crate) mod tcp; + pub(crate) mod udp; + pub(crate) mod uds; + pub use self::uds::SocketAddr; } - pub struct IoSourceState { - // This is `None` if the socket has not yet been registered. - // - // We box the internal state to not increase the size on the stack as the - // type might move around a lot. - inner: Option>, + cfg_os_ext! { + pub(crate) mod named_pipe; } - impl IoSourceState { - pub fn new() -> IoSourceState { - IoSourceState { inner: None } - } + mod waker; + pub(crate) use waker::Waker; - pub fn do_io(&self, f: F, io: &T) -> io::Result - where - F: FnOnce(&T) -> io::Result, - { - let result = f(io); - if let Err(ref e) = result { - if e.kind() == io::ErrorKind::WouldBlock { - self.inner.as_ref().map_or(Ok(()), |state| { - state - .selector - .reregister(state.sock_state.clone(), state.token, state.interests) - })?; - } - } - result - } + cfg_io_source! { + use std::io; + use std::os::windows::io::RawSocket; + use std::pin::Pin; + use std::sync::{Arc, Mutex}; + + use crate::{Interest, Registry, Token}; - pub fn register( - &mut self, - registry: &Registry, + struct InternalState { + selector: Arc, token: Token, interests: Interest, - socket: RawSocket, - ) -> io::Result<()> { - if self.inner.is_some() { - Err(io::ErrorKind::AlreadyExists.into()) - } else { - registry - .selector() - .register(socket, token, interests) - .map(|state| { - self.inner = Some(Box::new(state)); - }) + sock_state: Pin>>, + } + + impl Drop for InternalState { + fn drop(&mut self) { + let mut sock_state = self.sock_state.lock().unwrap(); + sock_state.mark_delete(); } } - pub fn reregister( - &mut self, - registry: &Registry, - token: Token, - interests: Interest, - ) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { + pub struct IoSourceState { + // This is `None` if the socket has not yet been registered. + // + // We box the internal state to not increase the size on the stack as the + // type might move around a lot. + inner: Option>, + } + + impl IoSourceState { + pub fn new() -> IoSourceState { + IoSourceState { inner: None } + } + + pub fn do_io(&self, f: F, io: &T) -> io::Result + where + F: FnOnce(&T) -> io::Result, + { + let result = f(io); + if let Err(ref e) = result { + if e.kind() == io::ErrorKind::WouldBlock { + self.inner.as_ref().map_or(Ok(()), |state| { + state + .selector + .reregister(state.sock_state.clone(), state.token, state.interests) + })?; + } + } + result + } + + pub fn register( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + socket: RawSocket, + ) -> io::Result<()> { + if self.inner.is_some() { + Err(io::ErrorKind::AlreadyExists.into()) + } else { registry .selector() - .reregister(state.sock_state.clone(), token, interests) - .map(|()| { - state.token = token; - state.interests = interests; + .register(socket, token, interests) + .map(|state| { + self.inner = Some(Box::new(state)); }) } - None => Err(io::ErrorKind::NotFound.into()), } - } - pub fn deregister(&mut self) -> io::Result<()> { - match self.inner.as_mut() { - Some(state) => { - { - let mut sock_state = state.sock_state.lock().unwrap(); - sock_state.mark_delete(); + pub fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: Interest, + ) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + registry + .selector() + .reregister(state.sock_state.clone(), token, interests) + .map(|()| { + state.token = token; + state.interests = interests; + }) } - self.inner = None; - Ok(()) + None => Err(io::ErrorKind::NotFound.into()), + } + } + + pub fn deregister(&mut self) -> io::Result<()> { + match self.inner.as_mut() { + Some(state) => { + { + let mut sock_state = state.sock_state.lock().unwrap(); + sock_state.mark_delete(); + } + self.inner = None; + Ok(()) + } + None => Err(io::ErrorKind::NotFound.into()), } - None => Err(io::ErrorKind::NotFound.into()), } } } } -} cfg_not_os_poll! { cfg_net! { diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index 1ec24cb4c..55433db26 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -15,58 +15,58 @@ fn path_offset(addr: &sockaddr_un) -> usize { } cfg_os_poll! { -use windows_sys::Win32::Networking::WinSock::AF_UNIX; -pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { - let sockaddr = mem::MaybeUninit::::zeroed(); - - // This is safe to assume because a `sockaddr_un` filled with `0` - // bytes is properly initialized. - // - // `0` is a valid value for `sockaddr_un::sun_family`; it is - // `WinSock::AF_UNSPEC`. - // - // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an - // abstract path. - let mut sockaddr = unsafe { sockaddr.assume_init() }; - sockaddr.sun_family = AF_UNIX; - - // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path - let bytes = path.to_str().map(|s| s.as_bytes()).ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "path contains invalid characters", - ) - })?; - - if bytes.contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "paths may not contain interior null bytes", - )); - } + use windows_sys::Win32::Networking::WinSock::AF_UNIX; + pub(super) fn socket_addr(path: &Path) -> io::Result<(sockaddr_un, c_int)> { + let sockaddr = mem::MaybeUninit::::zeroed(); + + // This is safe to assume because a `sockaddr_un` filled with `0` + // bytes is properly initialized. + // + // `0` is a valid value for `sockaddr_un::sun_family`; it is + // `WinSock::AF_UNSPEC`. + // + // `[0; 108]` is a valid value for `sockaddr_un::sun_path`; it begins an + // abstract path. + let mut sockaddr = unsafe { sockaddr.assume_init() }; + sockaddr.sun_family = AF_UNIX; + + // Winsock2 expects 'sun_path' to be a Win32 UTF-8 file system path + let bytes = path.to_str().map(|s| s.as_bytes()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "path contains invalid characters", + ) + })?; + + if bytes.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "paths may not contain interior null bytes", + )); + } - if bytes.len() >= sockaddr.sun_path.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "path must be shorter than SUN_LEN", - )); - } - for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { - *dst = *src as u8; - } + if bytes.len() >= sockaddr.sun_path.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "path must be shorter than SUN_LEN", + )); + } + for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { + *dst = *src as u8; + } - let offset = path_offset(&sockaddr); - let mut socklen = offset + bytes.len(); + let offset = path_offset(&sockaddr); + let mut socklen = offset + bytes.len(); - match bytes.get(0) { - // The struct has already been zeroes so the null byte for pathname - // addresses is already there. - Some(&0) | None => {} - Some(_) => socklen += 1, - } + match bytes.get(0) { + // The struct has already been zeroes so the null byte for pathname + // addresses is already there. + Some(&0) | None => {} + Some(_) => socklen += 1, + } - Ok((sockaddr, socklen as c_int)) -} + Ok((sockaddr, socklen as c_int)) + } } enum AddressKind<'a> { diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs index 2d8965631..4d849d10d 100644 --- a/src/sys/windows/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -54,30 +54,30 @@ impl IntoRawSocket for UnixListener { } cfg_os_poll! { -use std::path::Path; + use std::path::Path; -use super::{socket_addr, UnixStream}; + use super::{socket_addr, UnixStream}; -impl UnixListener { - pub fn bind>(path: P) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path.as_ref())?; + impl UnixListener { + pub fn bind>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; - wsa_syscall!( - bind(inner.as_raw_socket() as _, &addr as *const _ as *const _, len as _), - SOCKET_ERROR - )?; - wsa_syscall!(listen(inner.as_raw_socket() as _, 128), SOCKET_ERROR)?; - Ok(UnixListener(inner)) - } + wsa_syscall!( + bind(inner.as_raw_socket() as _, &addr as *const _ as *const _, len as _), + SOCKET_ERROR + )?; + wsa_syscall!(listen(inner.as_raw_socket() as _, 128), SOCKET_ERROR)?; + Ok(UnixListener(inner)) + } - pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - SocketAddr::init(|addr, len| self.0.accept(addr, len)) - .map(|(sock, addr)| (UnixStream(sock), addr)) - } + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + SocketAddr::init(|addr, len| self.0.accept(addr, len)) + .map(|(sock, addr)| (UnixStream(sock), addr)) + } - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } } } -} diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs index 62cf5cba1..4764c7456 100644 --- a/src/sys/windows/stdnet/mod.rs +++ b/src/sys/windows/stdnet/mod.rs @@ -1,3 +1,4 @@ +//! Implementation of blocking UDS types for windows, mirrors std::os::unix::net. mod addr; mod listener; mod socket; @@ -8,18 +9,18 @@ pub(crate) use self::listener::UnixListener; pub(crate) use self::stream::UnixStream; cfg_os_poll! { -pub(self) use self::addr::socket_addr; + pub(self) use self::addr::socket_addr; -use std::sync::Once; + use std::sync::Once; -/// Initialise the network stack for Windows. -pub(crate) fn init() { - static INIT: Once = Once::new(); - INIT.call_once(|| { - // Let standard library call `WSAStartup` for us, we can't do it - // ourselves because otherwise using any type in `std::net` would panic - // when it tries to call `WSAStartup` a second time. - drop(std::net::UdpSocket::bind("127.0.0.1:0")); - }); -} + /// Initialise the network stack for Windows. + pub(crate) fn init() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + // Let standard library call `WSAStartup` for us, we can't do it + // ourselves because otherwise using any type in `std::net` would panic + // when it tries to call `WSAStartup` a second time. + drop(std::net::UdpSocket::bind("127.0.0.1:0")); + }); + } } diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index de6501e4b..869b5e60c 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -148,50 +148,50 @@ impl IntoRawSocket for Socket { } cfg_os_poll! { -use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; -use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; -use super::init; - -impl Socket { - pub fn new() -> io::Result { - init(); - let socket = wsa_syscall!( - WSASocketW( - WinSock::AF_UNIX.into(), - WinSock::SOCK_STREAM.into(), - 0, - ptr::null_mut(), - 0, - WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, - ), - INVALID_SOCKET - )?; - Ok(Socket(socket)) - } + use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; + use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; + use super::init; + + impl Socket { + pub fn new() -> io::Result { + init(); + let socket = wsa_syscall!( + WSASocketW( + WinSock::AF_UNIX.into(), + WinSock::SOCK_STREAM.into(), + 0, + ptr::null_mut(), + 0, + WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, + ), + INVALID_SOCKET + )?; + Ok(Socket(socket)) + } - pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { - let socket = wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET)?; - let socket = Socket(socket); - socket.set_no_inherit()?; - Ok(socket) - } + pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { + let socket = wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET)?; + let socket = Socket(socket); + socket.set_no_inherit()?; + Ok(socket) + } - fn set_no_inherit(&self) -> io::Result<()> { - syscall!( - SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), - PartialEq::eq, - 0 - )?; - Ok(()) - } + fn set_no_inherit(&self) -> io::Result<()> { + syscall!( + SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), + PartialEq::eq, + 0 + )?; + Ok(()) + } - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - let mut nonblocking = if nonblocking { 1 } else { 0 }; - wsa_syscall!( - ioctlsocket(self.0, WinSock::FIONBIO, &mut nonblocking), - SOCKET_ERROR - )?; - Ok(()) + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = if nonblocking { 1 } else { 0 }; + wsa_syscall!( + ioctlsocket(self.0, WinSock::FIONBIO, &mut nonblocking), + SOCKET_ERROR + )?; + Ok(()) + } } } -} diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index ed8b3314d..9f9548f8b 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -120,107 +120,107 @@ impl IntoRawSocket for UnixStream { } cfg_os_poll! { -use std::path::{Path, PathBuf}; -use windows_sys::Win32::Foundation::STATUS_SUCCESS; -use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; -use windows_sys::Win32::Security::Cryptography::{ - BCryptGenRandom, BCRYPT_USE_SYSTEM_PREFERRED_RNG, -}; - -use super::{socket_addr, UnixListener}; + use std::path::{Path, PathBuf}; + use windows_sys::Win32::Foundation::STATUS_SUCCESS; + use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; + use windows_sys::Win32::Security::Cryptography::{ + BCryptGenRandom, BCRYPT_USE_SYSTEM_PREFERRED_RNG, + }; + + use super::{socket_addr, UnixListener}; + + impl UnixStream { + pub fn connect>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + SOCKET_ERROR + ) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(e) => return Err(e), + } + Ok(UnixStream(inner)) + } -impl UnixStream { - pub fn connect>(path: P) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path.as_ref())?; - - match wsa_syscall!( - connect( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as i32, - ), - SOCKET_ERROR - ) { - Ok(_) => {} - Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} - Err(e) => return Err(e), + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) } - Ok(UnixStream(inner)) - } - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) + pub fn pair() -> io::Result<(Self, Self)> { + use std::sync::{Arc, RwLock}; + use std::thread::spawn; + + let file_path = temp_path(10)?; + let a: Arc>>> = Arc::new(RwLock::new(None)); + let ul = UnixListener::bind(&file_path).unwrap(); + let server = { + let a = a.clone(); + spawn(move || { + let mut store = a.write().unwrap(); + let stream0 = ul.accept().map(|s| s.0); + *store = Some(stream0); + }) + }; + let stream1 = UnixStream::connect(&file_path)?; + server + .join() + .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; + let stream0 = (*(a.write().unwrap())).take().unwrap()?; + let _ = std::fs::remove_file(&file_path); + Ok((stream0, stream1)) + } } - pub fn pair() -> io::Result<(Self, Self)> { - use std::sync::{Arc, RwLock}; - use std::thread::spawn; - - let file_path = temp_path(10)?; - let a: Arc>>> = Arc::new(RwLock::new(None)); - let ul = UnixListener::bind(&file_path).unwrap(); - let server = { - let a = a.clone(); - spawn(move || { - let mut store = a.write().unwrap(); - let stream0 = ul.accept().map(|s| s.0); - *store = Some(stream0); + fn sample_ascii_string(len: usize) -> io::Result { + const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ + abcdefghijklmnopqrstuvwxyz\ + 0123456789-_"; + let mut buf: Vec = vec![0; len]; + for chunk in buf.chunks_mut(u32::max_value() as usize) { + syscall!( + BCryptGenRandom( + 0, + chunk.as_mut_ptr(), + chunk.len() as u32, + BCRYPT_USE_SYSTEM_PREFERRED_RNG, + ), + PartialEq::ne, + STATUS_SUCCESS + )?; + } + let result: String = buf + .into_iter() + .map(|r| { + // We pick from 64=2^6 characters so we can use a simple bitshift. + let idx = r >> (8 - 6); + char::from(GEN_ASCII_STR_CHARSET[idx as usize]) }) - }; - let stream1 = UnixStream::connect(&file_path)?; - server - .join() - .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; - let stream0 = (*(a.write().unwrap())).take().unwrap()?; - let _ = std::fs::remove_file(&file_path); - Ok((stream0, stream1)) - } -} - -fn sample_ascii_string(len: usize) -> io::Result { - const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ - abcdefghijklmnopqrstuvwxyz\ - 0123456789-_"; - let mut buf: Vec = vec![0; len]; - for chunk in buf.chunks_mut(u32::max_value() as usize) { - syscall!( - BCryptGenRandom( - 0, - chunk.as_mut_ptr(), - chunk.len() as u32, - BCRYPT_USE_SYSTEM_PREFERRED_RNG, - ), - PartialEq::ne, - STATUS_SUCCESS - )?; - } - let result: String = buf - .into_iter() - .map(|r| { - // We pick from 64=2^6 characters so we can use a simple bitshift. - let idx = r >> (8 - 6); - char::from(GEN_ASCII_STR_CHARSET[idx as usize]) - }) - .collect(); - Ok(result) -} - -fn temp_path(len: usize) -> io::Result { - let dir = std::env::temp_dir(); - // Retry a few times in case of collisions - for _ in 0..10 { - let rand_str = sample_ascii_string(len)?; - let filename = format!(".tmp-{rand_str}.socket"); - let path = dir.join(filename); - if !path.exists() { - return Ok(path); + .collect(); + Ok(result) + } + + fn temp_path(len: usize) -> io::Result { + let dir = std::env::temp_dir(); + // Retry a few times in case of collisions + for _ in 0..10 { + let rand_str = sample_ascii_string(len)?; + let filename = format!(".tmp-{rand_str}.socket"); + let path = dir.join(filename); + if !path.exists() { + return Ok(path); + } } - } - Err(io::Error::new( - io::ErrorKind::AlreadyExists, - "too many temporary files exist", - )) -} + Err(io::Error::new( + io::ErrorKind::AlreadyExists, + "too many temporary files exist", + )) + } } diff --git a/tests/close_on_drop.rs b/tests/close_on_drop.rs index a2e88d9de..8d9eefcca 100644 --- a/tests/close_on_drop.rs +++ b/tests/close_on_drop.rs @@ -58,7 +58,7 @@ impl TestHandler { AfterRead => {} } - let mut buf = vec![0; 1024]; + let mut buf = Vec::with_capacity(1024); match self.cli.read(&mut buf) { Ok(0) => self.shutdown = true, diff --git a/tests/unix_listener.rs b/tests/unix_listener.rs index 90a316599..c131497cc 100644 --- a/tests/unix_listener.rs +++ b/tests/unix_listener.rs @@ -1,4 +1,4 @@ -#![cfg(all(feature = "os-poll", feature = "net", any(unix, feature = "os-ext")))] +#![cfg(all(feature = "os-poll", feature = "net"))] #[cfg(windows)] use mio::net; diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 46c6a935f..4eebc9c19 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -491,7 +491,7 @@ fn new_listener( handle_stream: F ) -> (thread::JoinHandle<()>, net::SocketAddr) where - F: Fn(&net::UnixStream) + std::marker::Send + 'static + F: Fn(net::UnixStream) + std::marker::Send + 'static { let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { @@ -512,7 +512,7 @@ where .unwrap(); let (stream, _) = listener.accept().unwrap(); assert_would_block(listener.accept()); - handle_stream(&stream); + handle_stream(stream); } }); (handle, addr_receiver.recv().unwrap()) @@ -525,7 +525,7 @@ fn new_listener( handle_stream: F ) -> (thread::JoinHandle<()>, net::SocketAddr) where - F: Fn(&net::UnixStream) + std::marker::Send + 'static + F: Fn(net::UnixStream) + std::marker::Send + 'static { let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { From d9d4bb3424b94b1da94164299fcc975231bfd91e Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Fri, 26 Aug 2022 18:26:23 +0200 Subject: [PATCH 28/34] fmt --- tests/unix_stream.rs | 78 ++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 43 deletions(-) diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index 4eebc9c19..a5ae8fa1e 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -488,10 +488,10 @@ where fn new_listener( connections: usize, test_name: &'static str, - handle_stream: F + handle_stream: F, ) -> (thread::JoinHandle<()>, net::SocketAddr) where - F: Fn(net::UnixStream) + std::marker::Send + 'static + F: Fn(net::UnixStream) + std::marker::Send + 'static, { let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { @@ -522,10 +522,10 @@ where fn new_listener( connections: usize, test_name: &'static str, - handle_stream: F + handle_stream: F, ) -> (thread::JoinHandle<()>, net::SocketAddr) where - F: Fn(net::UnixStream) + std::marker::Send + 'static + F: Fn(net::UnixStream) + std::marker::Send + 'static, { let (addr_sender, addr_receiver) = channel(); let handle = thread::spawn(move || { @@ -546,38 +546,34 @@ fn new_echo_listener( connections: usize, test_name: &'static str, ) -> (thread::JoinHandle<()>, net::SocketAddr) { - new_listener( - connections, - test_name, - |mut stream| { - // On Linux based system it will cause a connection reset - // error when the reading side of the peer connection is - // shutdown, we don't consider it an actual here. - let (mut read, mut written) = (0, 0); - let mut buf = [0; DEFAULT_BUF_SIZE]; - loop { - let n = match stream.read(&mut buf) { - Ok(amount) => { - read += amount; - amount - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, - Err(err) => panic!("{}", err), - }; - if n == 0 { - break; + new_listener(connections, test_name, |mut stream| { + // On Linux based system it will cause a connection reset + // error when the reading side of the peer connection is + // shutdown, we don't consider it an actual here. + let (mut read, mut written) = (0, 0); + let mut buf = [0; DEFAULT_BUF_SIZE]; + loop { + let n = match stream.read(&mut buf) { + Ok(amount) => { + read += amount; + amount } - match stream.write(&buf[..n]) { - Ok(amount) => written += amount, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, - Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, - Err(err) => panic!("{}", err), - }; + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::ConnectionReset => break, + Err(err) => panic!("{}", err), + }; + if n == 0 { + break; } - assert_eq!(read, written, "unequal reads and writes"); + match stream.write(&buf[..n]) { + Ok(amount) => written += amount, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => continue, + Err(ref err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => panic!("{}", err), + }; } - ) + assert_eq!(read, written, "unequal reads and writes"); + }) } fn new_noop_listener( @@ -585,14 +581,10 @@ fn new_noop_listener( barrier: Arc, test_name: &'static str, ) -> (thread::JoinHandle<()>, net::SocketAddr) { - new_listener( - connections, - test_name, - move |stream| { - barrier.wait(); - stream.shutdown(Shutdown::Write).unwrap(); - barrier.wait(); - drop(stream); - } - ) + new_listener(connections, test_name, move |stream| { + barrier.wait(); + stream.shutdown(Shutdown::Write).unwrap(); + barrier.wait(); + drop(stream); + }) } From 82b17e83e870260007d244ad0147a42f8c7dbef8 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 12 Sep 2022 14:58:10 -0700 Subject: [PATCH 29/34] remove unrelated code changes --- Cargo.toml | 1 - src/net/tcp/stream.rs | 20 ++-- src/net/uds/listener.rs | 7 ++ src/net/uds/mod.rs | 1 + src/net/uds/stream.rs | 89 +++++++++------ src/sys/unix/pipe.rs | 20 ++-- src/sys/unix/uds/mod.rs | 4 +- src/sys/windows/iocp.rs | 2 +- src/sys/windows/mod.rs | 45 ++++---- src/sys/windows/stdnet/addr.rs | 25 ++--- src/sys/windows/stdnet/listener.rs | 74 ++++++------ src/sys/windows/stdnet/socket.rs | 52 ++++----- src/sys/windows/stdnet/stream.rs | 173 ++++++++--------------------- src/sys/windows/uds/listener.rs | 1 - src/sys/windows/uds/stream.rs | 7 -- tests/unix_pipe.rs | 4 +- tests/unix_stream.rs | 2 + 17 files changed, 231 insertions(+), 296 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9f62ac3b4..8433f91ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,6 @@ features = [ "Win32_Storage_FileSystem", # Enables NtCreateFile "Win32_Foundation", # Basic types eg HANDLE "Win32_Networking_WinSock", # winsock2 types/functions - "Win32_Security_Cryptography", # Random number generation "Win32_System_IO", # IO types like OVERLAPPED etc "Win32_System_WindowsProgramming", # General future used for various types/funcs ] diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index a7a9aa1ba..532e7d9b6 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -269,49 +269,49 @@ impl TcpStream { impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl<'a> Read for &'a TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } impl<'a> Write for &'a TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 0265048aa..365b39c54 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -31,6 +31,7 @@ impl UnixListener { /// about the underlying listener; it is left up to the user to set it in /// non-blocking mode. #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(listener: net::UnixListener) -> UnixListener { UnixListener { inner: IoSource::new(listener), @@ -94,6 +95,7 @@ impl fmt::Debug for UnixListener { } #[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl IntoRawFd for UnixListener { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() @@ -101,6 +103,7 @@ impl IntoRawFd for UnixListener { } #[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl AsRawFd for UnixListener { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() @@ -108,6 +111,7 @@ impl AsRawFd for UnixListener { } #[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl FromRawFd for UnixListener { /// Converts a `RawFd` to a `UnixListener`. /// @@ -121,6 +125,7 @@ impl FromRawFd for UnixListener { } #[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] impl IntoRawSocket for UnixListener { fn into_raw_socket(self) -> RawSocket { self.inner.into_inner().into_raw_socket() @@ -128,6 +133,7 @@ impl IntoRawSocket for UnixListener { } #[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] impl AsRawSocket for UnixListener { fn as_raw_socket(&self) -> RawSocket { self.inner.as_raw_socket() @@ -135,6 +141,7 @@ impl AsRawSocket for UnixListener { } #[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] impl FromRawSocket for UnixListener { unsafe fn from_raw_socket(sock: RawSocket) -> Self { UnixListener::from_std(FromRawSocket::from_raw_socket(sock)) diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index c0a77bbf2..a48a713fe 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -1,6 +1,7 @@ #[cfg(unix)] mod datagram; #[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] pub use self::datagram::UnixDatagram; mod listener; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 9c73dafa4..2fd95a8c8 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -41,6 +41,7 @@ impl UnixStream { /// should already be connected via some other means (be it manually, or /// the standard library). #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn from_std(stream: net::UnixStream) -> UnixStream { UnixStream { inner: IoSource::new(stream), @@ -57,6 +58,8 @@ impl UnixStream { /// Creates an unnamed pair of connected sockets. /// /// Returns two `UnixStream`s which are connected to each other. + #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn pair() -> io::Result<(UnixStream, UnixStream)> { sys::uds::stream::pair().map(|(stream1, stream2)| { (UnixStream::from_std(stream1), UnixStream::from_std(stream2)) @@ -158,11 +161,43 @@ impl UnixStream { /// use std::io; /// use std::os::windows::io::AsRawSocket; /// use std::os::raw::c_int; - /// use mio::net::UnixStream; + /// use mio::net::{UnixStream, UnixListener}; /// use windows_sys::Win32::Networking::WinSock; /// use std::convert::TryInto; /// - /// let (stream1, stream2) = UnixStream::pair()?; + /// let file_path = std::env::temp_dir().join("server.sock"); + /// # let _ = std::fs::remove_file(&file_path); + /// let server = UnixListener::bind(&file_path).unwrap(); + /// + /// let handle = std::thread::spawn(move || { + /// if let Ok((stream2, _)) = server.accept() { + /// // Wait until the stream is readable... + /// + /// // Read from the stream using a direct WinSock call, of course the + /// // `io::Read` implementation would be easier to use. + /// let mut buf = [0; 512]; + /// let n = stream2.try_io(|| { + /// let res = unsafe { + /// WinSock::recv( + /// stream2.as_raw_socket().try_into().unwrap(), + /// &mut buf as *mut _ as *mut _, + /// buf.len() as c_int, + /// 0 + /// ) + /// }; + /// if res != WinSock::SOCKET_ERROR { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::last_os_error()) + /// } + /// }).unwrap(); + /// eprintln!("read {} bytes", n); + /// } + /// }); + /// + /// let stream1 = UnixStream::connect(&file_path).unwrap(); /// /// // Wait until the stream is writable... /// @@ -190,29 +225,7 @@ impl UnixStream { /// })?; /// eprintln!("write {} bytes", n); /// - /// // Wait until the stream is readable... - /// - /// // Read from the stream using a direct WinSock call, of course the - /// // `io::Read` implementation would be easier to use. - /// let mut buf = [0; 512]; - /// let n = stream2.try_io(|| { - /// let res = unsafe { - /// WinSock::recv( - /// stream2.as_raw_socket().try_into().unwrap(), - /// &mut buf as *mut _ as *mut _, - /// buf.len() as c_int, - /// 0 - /// ) - /// }; - /// if res != WinSock::SOCKET_ERROR { - /// Ok(res as usize) - /// } else { - /// // If EAGAIN or EWOULDBLOCK is set by WinSock::recv, the closure - /// // should return `WouldBlock` error. - /// Err(io::Error::last_os_error()) - /// } - /// })?; - /// eprintln!("read {} bytes", n); + /// # handle.join().unwrap(); /// # Ok(()) /// # } /// ``` @@ -226,49 +239,49 @@ impl UnixStream { impl Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl<'a> Read for &'a UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.read(buf)) + self.inner.do_io(|inner| (&*inner).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.read_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) } } impl Write for UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } impl<'a> Write for &'a UnixStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut inner| inner.write(buf)) + self.inner.do_io(|inner| (&*inner).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut inner| inner.write_vectored(bufs)) + self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut inner| inner.flush()) + self.inner.do_io(|inner| (&*inner).flush()) } } @@ -303,6 +316,7 @@ impl fmt::Debug for UnixStream { } #[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl IntoRawFd for UnixStream { fn into_raw_fd(self) -> RawFd { self.inner.into_inner().into_raw_fd() @@ -310,6 +324,7 @@ impl IntoRawFd for UnixStream { } #[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl AsRawFd for UnixStream { fn as_raw_fd(&self) -> RawFd { self.inner.as_raw_fd() @@ -317,6 +332,7 @@ impl AsRawFd for UnixStream { } #[cfg(unix)] +#[cfg_attr(docsrs, doc(cfg(unix)))] impl FromRawFd for UnixStream { /// Converts a `RawFd` to a `UnixStream`. /// @@ -330,6 +346,7 @@ impl FromRawFd for UnixStream { } #[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] impl IntoRawSocket for UnixStream { fn into_raw_socket(self) -> RawSocket { self.inner.into_inner().into_raw_socket() @@ -337,6 +354,7 @@ impl IntoRawSocket for UnixStream { } #[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] impl AsRawSocket for UnixStream { fn as_raw_socket(&self) -> RawSocket { self.inner.as_raw_socket() @@ -344,6 +362,7 @@ impl AsRawSocket for UnixStream { } #[cfg(windows)] +#[cfg_attr(docsrs, doc(cfg(windows)))] impl FromRawSocket for UnixStream { unsafe fn from_raw_socket(sock: RawSocket) -> Self { UnixStream::from_std(FromRawSocket::from_raw_socket(sock)) diff --git a/src/sys/unix/pipe.rs b/src/sys/unix/pipe.rs index 7a95b9697..b2865cda7 100644 --- a/src/sys/unix/pipe.rs +++ b/src/sys/unix/pipe.rs @@ -313,29 +313,29 @@ impl event::Source for Sender { impl Write for Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.write(buf)) + self.inner.do_io(|sender| (&*sender).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.write_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut sender| sender.flush()) + self.inner.do_io(|sender| (&*sender).flush()) } } impl Write for &Sender { fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.write(buf)) + self.inner.do_io(|sender| (&*sender).write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.write_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|mut sender| sender.flush()) + self.inner.do_io(|sender| (&*sender).flush()) } } @@ -478,21 +478,21 @@ impl event::Source for Receiver { impl Read for Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.read(buf)) + self.inner.do_io(|sender| (&*sender).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.read_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) } } impl Read for &Receiver { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.do_io(|mut sender| sender.read(buf)) + self.inner.do_io(|sender| (&*sender).read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { - self.inner.do_io(|mut sender| sender.read_vectored(bufs)) + self.inner.do_io(|sender| (&*sender).read_vectored(bufs)) } } diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index 526bbdfd0..8e28a9573 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -40,7 +40,7 @@ cfg_os_poll! { sockaddr.sun_family = libc::AF_UNIX as libc::sa_family_t; let bytes = path.as_os_str().as_bytes(); - match (bytes.first(), bytes.len().cmp(&sockaddr.sun_path.len())) { + match (bytes.get(0), bytes.len().cmp(&sockaddr.sun_path.len())) { // Abstract paths don't need a null terminator (Some(&0), Ordering::Greater) => { return Err(io::Error::new( @@ -64,7 +64,7 @@ cfg_os_poll! { let offset = path_offset(&sockaddr); let mut socklen = offset + bytes.len(); - match bytes.first() { + match bytes.get(0) { // The struct has already been zeroes so the null byte for pathname // addresses is already there. Some(&0) | None => {} diff --git a/src/sys/windows/iocp.rs b/src/sys/windows/iocp.rs index f7651daa8..d75f3826e 100644 --- a/src/sys/windows/iocp.rs +++ b/src/sys/windows/iocp.rs @@ -260,6 +260,6 @@ mod tests { } assert_eq!(s[2].bytes_transferred(), 0); assert_eq!(s[2].token(), 0); - assert_eq!(s[2].overlapped(), std::ptr::null_mut()); + assert_eq!(s[2].overlapped(), 0 as *mut _); } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index 0cb58bc64..fb596b967 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -1,3 +1,17 @@ +// Macro must be defined before any modules that uses them. +/// Helper macro to execute a system call that returns an `io::Result`. +#[allow(unused_macros)] +macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + let res = unsafe { $fn($($arg, )*) }; + if $err_test(&res, &$err_value) { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} + /// Helper macro to execute a WinSock system call that returns an `io::Result`. #[allow(unused_macros)] macro_rules! wsa_syscall { @@ -13,6 +27,12 @@ macro_rules! wsa_syscall { }}; } +cfg_net! { + pub(crate) mod stdnet; + pub(crate) mod uds; + pub use self::uds::SocketAddr; +} + cfg_os_poll! { mod afd; @@ -33,27 +53,10 @@ cfg_os_poll! { // Macros must be defined before the modules that use them cfg_net! { - /// Helper macro to execute a system call that returns an `io::Result`. - // - // Macro must be defined before any modules that uses them. - macro_rules! syscall { - ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ - let res = unsafe { $fn($($arg, )*) }; - if $err_test(&res, &$err_value) { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } - }}; - } - mod net; - pub(crate) mod stdnet; pub(crate) mod tcp; pub(crate) mod udp; - pub(crate) mod uds; - pub use self::uds::SocketAddr; } cfg_os_ext! { @@ -170,11 +173,3 @@ cfg_os_poll! { } } } - -cfg_not_os_poll! { - cfg_net! { - pub(crate) mod stdnet; - pub(crate) mod uds; - pub use self::uds::SocketAddr; - } -} diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index 55433db26..3c5d2b772 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -51,14 +51,13 @@ cfg_os_poll! { "path must be shorter than SUN_LEN", )); } - for (dst, src) in sockaddr.sun_path.iter_mut().zip(bytes.iter()) { - *dst = *src as u8; - } + + sockaddr.sun_path[..bytes.len()].copy_from_slice(bytes); let offset = path_offset(&sockaddr); let mut socklen = offset + bytes.len(); - match bytes.get(0) { + match bytes.first() { // The struct has already been zeroes so the null byte for pathname // addresses is already there. Some(&0) | None => {} @@ -90,7 +89,6 @@ impl<'a> fmt::Display for AsciiEscaped<'a> { } /// An address associated with a Unix socket -#[derive(Copy, Clone)] pub struct SocketAddr { addr: sockaddr_un, len: c_int, @@ -108,7 +106,13 @@ impl SocketAddr { let mut len = mem::size_of::() as c_int; let result = f(&mut sockaddr as *mut _ as *mut _, &mut len)?; - Ok((result, SocketAddr::from_parts(sockaddr, len))) + Ok(( + result, + SocketAddr { + addr: sockaddr, + len, + }, + )) } pub(crate) fn new(f: F) -> io::Result @@ -118,15 +122,6 @@ impl SocketAddr { SocketAddr::init(f).map(|(_, addr)| addr) } - pub(crate) fn from_parts(addr: sockaddr_un, mut len: c_int) -> SocketAddr { - if len == 0 { - // When there is a datagram from unnamed unix socket - // linux returns zero bytes of address - len = path_offset(&addr) as c_int; // i.e. zero-length address - } - SocketAddr { addr, len } - } - /// Returns true if and only if the address is unnamed. pub fn is_unnamed(&self) -> bool { matches!(self.address(), AddressKind::Unnamed) diff --git a/src/sys/windows/stdnet/listener.rs b/src/sys/windows/stdnet/listener.rs index 4d849d10d..214167276 100644 --- a/src/sys/windows/stdnet/listener.rs +++ b/src/sys/windows/stdnet/listener.rs @@ -7,19 +7,8 @@ use super::{socket::Socket, SocketAddr}; pub(crate) struct UnixListener(Socket); -impl fmt::Debug for UnixListener { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut builder = fmt.debug_struct("UnixListener"); - builder.field("socket", &self.0.as_raw_socket()); - if let Ok(addr) = self.local_addr() { - builder.field("local", &addr); - } - builder.finish() - } -} - impl UnixListener { - pub fn local_addr(&self) -> io::Result { + pub(crate) fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getsockname(self.0.as_raw_socket() as _, addr, len), @@ -28,38 +17,18 @@ impl UnixListener { }) } - pub fn take_error(&self) -> io::Result> { + pub(crate) fn take_error(&self) -> io::Result> { self.0.take_error() } } -impl AsRawSocket for UnixListener { - fn as_raw_socket(&self) -> RawSocket { - self.0.as_raw_socket() - } -} - -impl FromRawSocket for UnixListener { - unsafe fn from_raw_socket(sock: RawSocket) -> Self { - UnixListener(Socket::from_raw_socket(sock)) - } -} - -impl IntoRawSocket for UnixListener { - fn into_raw_socket(self) -> RawSocket { - let ret = self.0.as_raw_socket(); - mem::forget(self); - ret - } -} - cfg_os_poll! { use std::path::Path; use super::{socket_addr, UnixStream}; impl UnixListener { - pub fn bind>(path: P) -> io::Result { + pub(crate) fn bind>(path: P) -> io::Result { let inner = Socket::new()?; let (addr, len) = socket_addr(path.as_ref())?; @@ -67,17 +36,48 @@ cfg_os_poll! { bind(inner.as_raw_socket() as _, &addr as *const _ as *const _, len as _), SOCKET_ERROR )?; - wsa_syscall!(listen(inner.as_raw_socket() as _, 128), SOCKET_ERROR)?; + wsa_syscall!(listen(inner.as_raw_socket() as _, 1024), SOCKET_ERROR)?; Ok(UnixListener(inner)) } - pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + pub(crate) fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { SocketAddr::init(|addr, len| self.0.accept(addr, len)) .map(|(sock, addr)| (UnixStream(sock), addr)) } - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { self.0.set_nonblocking(nonblocking) } } } + +impl fmt::Debug for UnixListener { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixListener"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + builder.finish() + } +} + +impl AsRawSocket for UnixListener { + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixListener { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener(Socket::from_raw_socket(sock)) + } +} + +impl IntoRawSocket for UnixListener { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0.as_raw_socket(); + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index 869b5e60c..12312672c 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -121,32 +121,6 @@ impl Socket { } } -impl Drop for Socket { - fn drop(&mut self) { - let _ = unsafe { closesocket(self.0) }; - } -} - -impl AsRawSocket for Socket { - fn as_raw_socket(&self) -> RawSocket { - self.0 as RawSocket - } -} - -impl FromRawSocket for Socket { - unsafe fn from_raw_socket(sock: RawSocket) -> Self { - Socket(sock as SOCKET) - } -} - -impl IntoRawSocket for Socket { - fn into_raw_socket(self) -> RawSocket { - let ret = self.0 as RawSocket; - mem::forget(self); - ret - } -} - cfg_os_poll! { use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; @@ -195,3 +169,29 @@ cfg_os_poll! { } } } + +impl Drop for Socket { + fn drop(&mut self) { + let _ = unsafe { closesocket(self.0) }; + } +} + +impl AsRawSocket for Socket { + fn as_raw_socket(&self) -> RawSocket { + self.0 as RawSocket + } +} + +impl FromRawSocket for Socket { + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + Socket(sock as SOCKET) + } +} + +impl IntoRawSocket for Socket { + fn into_raw_socket(self) -> RawSocket { + let ret = self.0 as RawSocket; + mem::forget(self); + ret + } +} diff --git a/src/sys/windows/stdnet/stream.rs b/src/sys/windows/stdnet/stream.rs index 9f9548f8b..ce1da2f54 100644 --- a/src/sys/windows/stdnet/stream.rs +++ b/src/sys/windows/stdnet/stream.rs @@ -9,22 +9,8 @@ use super::{socket::Socket, SocketAddr}; pub(crate) struct UnixStream(pub(super) Socket); -impl fmt::Debug for UnixStream { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut builder = fmt.debug_struct("UnixStream"); - builder.field("socket", &self.0.as_raw_socket()); - if let Ok(addr) = self.local_addr() { - builder.field("local", &addr); - } - if let Ok(addr) = self.peer_addr() { - builder.field("peer", &addr); - } - builder.finish() - } -} - impl UnixStream { - pub fn local_addr(&self) -> io::Result { + pub(crate) fn local_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getsockname(self.0.as_raw_socket() as _, addr, len), @@ -33,7 +19,7 @@ impl UnixStream { }) } - pub fn peer_addr(&self) -> io::Result { + pub(crate) fn peer_addr(&self) -> io::Result { SocketAddr::new(|addr, len| { wsa_syscall!( getpeername(self.0.as_raw_socket() as _, addr, len), @@ -42,15 +28,60 @@ impl UnixStream { }) } - pub fn take_error(&self) -> io::Result> { + pub(crate) fn take_error(&self) -> io::Result> { self.0.take_error() } - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + pub(crate) fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.0.shutdown(how) } } +cfg_os_poll! { + use std::path::Path; + use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; + use super::socket_addr; + + impl UnixStream { + pub(crate) fn connect>(path: P) -> io::Result { + let inner = Socket::new()?; + let (addr, len) = socket_addr(path.as_ref())?; + + match wsa_syscall!( + connect( + inner.as_raw_socket() as _, + &addr as *const _ as *const _, + len as i32, + ), + SOCKET_ERROR + ) { + Ok(_) => {} + Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} + Err(e) => return Err(e), + } + Ok(UnixStream(inner)) + } + + pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + } +} + +impl fmt::Debug for UnixStream { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixStream"); + builder.field("socket", &self.0.as_raw_socket()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} + impl io::Read for UnixStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { io::Read::read(&mut &*self, buf) @@ -118,109 +149,3 @@ impl IntoRawSocket for UnixStream { ret } } - -cfg_os_poll! { - use std::path::{Path, PathBuf}; - use windows_sys::Win32::Foundation::STATUS_SUCCESS; - use windows_sys::Win32::Networking::WinSock::WSAEINPROGRESS; - use windows_sys::Win32::Security::Cryptography::{ - BCryptGenRandom, BCRYPT_USE_SYSTEM_PREFERRED_RNG, - }; - - use super::{socket_addr, UnixListener}; - - impl UnixStream { - pub fn connect>(path: P) -> io::Result { - let inner = Socket::new()?; - let (addr, len) = socket_addr(path.as_ref())?; - - match wsa_syscall!( - connect( - inner.as_raw_socket() as _, - &addr as *const _ as *const _, - len as i32, - ), - SOCKET_ERROR - ) { - Ok(_) => {} - Err(ref err) if err.raw_os_error() == Some(WSAEINPROGRESS) => {} - Err(e) => return Err(e), - } - Ok(UnixStream(inner)) - } - - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.0.set_nonblocking(nonblocking) - } - - pub fn pair() -> io::Result<(Self, Self)> { - use std::sync::{Arc, RwLock}; - use std::thread::spawn; - - let file_path = temp_path(10)?; - let a: Arc>>> = Arc::new(RwLock::new(None)); - let ul = UnixListener::bind(&file_path).unwrap(); - let server = { - let a = a.clone(); - spawn(move || { - let mut store = a.write().unwrap(); - let stream0 = ul.accept().map(|s| s.0); - *store = Some(stream0); - }) - }; - let stream1 = UnixStream::connect(&file_path)?; - server - .join() - .map_err(|_| io::Error::from(io::ErrorKind::ConnectionRefused))?; - let stream0 = (*(a.write().unwrap())).take().unwrap()?; - let _ = std::fs::remove_file(&file_path); - Ok((stream0, stream1)) - } - } - - fn sample_ascii_string(len: usize) -> io::Result { - const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ - abcdefghijklmnopqrstuvwxyz\ - 0123456789-_"; - let mut buf: Vec = vec![0; len]; - for chunk in buf.chunks_mut(u32::max_value() as usize) { - syscall!( - BCryptGenRandom( - 0, - chunk.as_mut_ptr(), - chunk.len() as u32, - BCRYPT_USE_SYSTEM_PREFERRED_RNG, - ), - PartialEq::ne, - STATUS_SUCCESS - )?; - } - let result: String = buf - .into_iter() - .map(|r| { - // We pick from 64=2^6 characters so we can use a simple bitshift. - let idx = r >> (8 - 6); - char::from(GEN_ASCII_STR_CHARSET[idx as usize]) - }) - .collect(); - Ok(result) - } - - fn temp_path(len: usize) -> io::Result { - let dir = std::env::temp_dir(); - // Retry a few times in case of collisions - for _ in 0..10 { - let rand_str = sample_ascii_string(len)?; - let filename = format!(".tmp-{rand_str}.socket"); - let path = dir.join(filename); - if !path.exists() { - return Ok(path); - } - } - - Err(io::Error::new( - io::ErrorKind::AlreadyExists, - "too many temporary files exist", - )) - } -} diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index d6a7a25cb..df16542c9 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -12,7 +12,6 @@ pub(crate) fn bind(path: &Path) -> io::Result { } pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> { - listener.set_nonblocking(true)?; listener .accept() .map(|(stream, addr)| (UnixStream::from_std(stream), addr)) diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index 45e4dd720..ef2a66bfb 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -9,13 +9,6 @@ pub(crate) fn connect(path: &Path) -> io::Result { Ok(socket) } -pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { - let (stream0, stream1) = net::UnixStream::pair()?; - stream0.set_nonblocking(true)?; - stream1.set_nonblocking(true)?; - Ok((stream0, stream1)) -} - pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { super::local_addr(socket.as_raw_socket()) } diff --git a/tests/unix_pipe.rs b/tests/unix_pipe.rs index f8e6464c9..a83e3833b 100644 --- a/tests/unix_pipe.rs +++ b/tests/unix_pipe.rs @@ -49,7 +49,7 @@ fn smoke() { ); let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], DATA1); + assert_eq!(&buf[..n], &*DATA1); } #[test] @@ -162,7 +162,7 @@ fn from_child_process_io() { let mut buf = [0; 20]; let n = receiver.read(&mut buf).unwrap(); assert_eq!(n, DATA1.len()); - assert_eq!(&buf[..n], DATA1); + assert_eq!(&buf[..n], &*DATA1); drop(sender); diff --git a/tests/unix_stream.rs b/tests/unix_stream.rs index a5ae8fa1e..42eef6c53 100644 --- a/tests/unix_stream.rs +++ b/tests/unix_stream.rs @@ -29,6 +29,7 @@ const DATA1_LEN: usize = 16; const DATA2_LEN: usize = 14; const DEFAULT_BUF_SIZE: usize = 64; const TOKEN_1: Token = Token(0); +#[cfg(unix)] const TOKEN_2: Token = Token(1); #[test] @@ -97,6 +98,7 @@ fn unix_stream_from_std() { ) } +#[cfg(unix)] #[test] fn unix_stream_pair() { let (mut poll, mut events) = init_with_poll(); From 96e07358997d7bc055408e4131ef6e1658a6d2a7 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 12 Sep 2022 16:27:49 -0700 Subject: [PATCH 30/34] fix lint --- src/sys/shell/uds.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index 48e568b28..5914e8944 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -68,6 +68,7 @@ pub(crate) mod stream { os_required!() } + #[cfg(unix)] pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> { os_required!() } From 0e1b6dfe2f23cacca6985ebdb64153f6154617c0 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 12 Sep 2022 17:37:27 -0700 Subject: [PATCH 31/34] remove explicit SetHandleInformation calls --- src/sys/windows/stdnet/socket.rs | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index 12312672c..55de184d9 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -123,13 +123,12 @@ impl Socket { cfg_os_poll! { use windows_sys::Win32::Networking::WinSock::{INVALID_SOCKET, SOCKADDR}; - use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT}; use super::init; impl Socket { pub fn new() -> io::Result { init(); - let socket = wsa_syscall!( + wsa_syscall!( WSASocketW( WinSock::AF_UNIX.into(), WinSock::SOCK_STREAM.into(), @@ -139,24 +138,11 @@ cfg_os_poll! { WinSock::WSA_FLAG_OVERLAPPED | WinSock::WSA_FLAG_NO_HANDLE_INHERIT, ), INVALID_SOCKET - )?; - Ok(Socket(socket)) + ).map(Socket) } pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { - let socket = wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET)?; - let socket = Socket(socket); - socket.set_no_inherit()?; - Ok(socket) - } - - fn set_no_inherit(&self) -> io::Result<()> { - syscall!( - SetHandleInformation(self.0 as HANDLE, HANDLE_FLAG_INHERIT, 0), - PartialEq::eq, - 0 - )?; - Ok(()) + wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET).map(Socket) } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { From 77155cc9061d848791c3e0afec1bb83fa6e43f18 Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Tue, 13 Sep 2022 14:09:05 -0700 Subject: [PATCH 32/34] abstract socketaddr behind common API in net --- src/net/mod.rs | 3 + src/net/uds/addr.rs | 97 +++++++++++++++++++++++++++++++++ src/net/uds/datagram.rs | 10 +++- src/net/uds/listener.rs | 5 +- src/net/uds/mod.rs | 4 +- src/net/uds/stream.rs | 9 +-- src/sys/mod.rs | 4 +- src/sys/shell/uds.rs | 7 ++- src/sys/unix/mod.rs | 4 +- src/sys/unix/uds/listener.rs | 3 +- src/sys/unix/uds/mod.rs | 2 +- src/sys/unix/uds/socketaddr.rs | 77 +------------------------- src/sys/windows/mod.rs | 2 +- src/sys/windows/stdnet/addr.rs | 51 ++--------------- src/sys/windows/stdnet/mod.rs | 2 +- src/sys/windows/uds/listener.rs | 3 +- src/sys/windows/uds/mod.rs | 2 +- src/sys/windows/uds/stream.rs | 5 +- 18 files changed, 145 insertions(+), 145 deletions(-) create mode 100644 src/net/uds/addr.rs diff --git a/src/net/mod.rs b/src/net/mod.rs index dc5d4b388..41d81a2d4 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -37,5 +37,8 @@ mod uds; #[cfg(not(target_os = "wasi"))] pub use self::uds::{SocketAddr, UnixListener, UnixStream}; +#[cfg(not(target_os = "wasi"))] +pub(crate) use self::uds::AddressKind; + #[cfg(unix)] pub use self::uds::UnixDatagram; diff --git a/src/net/uds/addr.rs b/src/net/uds/addr.rs new file mode 100644 index 000000000..81742a004 --- /dev/null +++ b/src/net/uds/addr.rs @@ -0,0 +1,97 @@ +use std::path::Path; +use std::{ascii, fmt}; +use crate::sys; + +/// An address associated with a `mio` specific Unix socket. +/// +/// This is implemented instead of imported from [`net::SocketAddr`] because +/// there is no way to create a [`net::SocketAddr`]. One must be returned by +/// [`accept`], so this is returned instead. +/// +/// [`net::SocketAddr`]: std::os::unix::net::SocketAddr +/// [`accept`]: #method.accept +pub struct SocketAddr { + inner: sys::SocketAddr +} + +struct AsciiEscaped<'a>(&'a [u8]); + +pub(crate) enum AddressKind<'a> { + Unnamed, + Pathname(&'a Path), + Abstract(&'a [u8]), +} + +impl SocketAddr { + pub(crate) fn new(inner: sys::SocketAddr) -> Self { + SocketAddr { inner } + } + + fn address(&self) -> AddressKind<'_> { + self.inner.address() + } +} + +cfg_os_poll! { + impl SocketAddr { + /// Returns `true` if the address is unnamed. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn is_unnamed(&self) -> bool { + matches!(self.address(), AddressKind::Unnamed) + } + + /// Returns the contents of this address if it is a `pathname` address. + /// + /// Documentation reflected in [`SocketAddr`] + /// + /// [`SocketAddr`]: std::os::unix::net::SocketAddr + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { + Some(path) + } else { + None + } + } + + /// Returns the contents of this address if it is an abstract namespace + /// without the leading null byte. + // Link to std::os::unix::net::SocketAddr pending + // https://github.com/rust-lang/rust/issues/85410. + pub fn as_abstract_namespace(&self) -> Option<&[u8]> { + if let AddressKind::Abstract(path) = self.address() { + Some(path) + } else { + None + } + } + } +} + +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{:?}", self.address()) + } +} + +impl fmt::Debug for AddressKind<'_> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AddressKind::Unnamed => write!(fmt, "(unnamed)"), + AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), + AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), + } + } +} + +impl<'a> fmt::Display for AsciiEscaped<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "\"")?; + for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { + write!(fmt, "{}", byte as char)?; + } + write!(fmt, "\"") + } +} diff --git a/src/net/uds/datagram.rs b/src/net/uds/datagram.rs index e963d6e2f..57114c28b 100644 --- a/src/net/uds/datagram.rs +++ b/src/net/uds/datagram.rs @@ -1,5 +1,6 @@ use crate::io_source::IoSource; use crate::{event, sys, Interest, Registry, Token}; +use crate::net::SocketAddr; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; @@ -54,24 +55,27 @@ impl UnixDatagram { } /// Returns the address of this socket. - pub fn local_addr(&self) -> io::Result { + pub fn local_addr(&self) -> io::Result { sys::uds::datagram::local_addr(&self.inner) + .map(|addr| SocketAddr::new(addr)) } /// Returns the address of this socket's peer. /// /// The `connect` method will connect the socket to a peer. - pub fn peer_addr(&self) -> io::Result { + pub fn peer_addr(&self) -> io::Result { sys::uds::datagram::peer_addr(&self.inner) + .map(|addr| SocketAddr::new(addr)) } /// Receives data from the socket. /// /// On success, returns the number of bytes read and the address from /// whence the data came. - pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, sys::SocketAddr)> { + pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { self.inner .do_io(|inner| sys::uds::datagram::recv_from(inner, buf)) + .map(|(nread, addr)| (nread, SocketAddr::new(addr))) } /// Receives data from the socket. diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 365b39c54..03b02821f 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -51,11 +51,12 @@ impl UnixListener { /// non-blocking mode. pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { self.inner.do_io(sys::uds::listener::accept) + .map(|(stream, addr)| (stream, SocketAddr::new(addr))) } /// Returns the local socket address of this listener. - pub fn local_addr(&self) -> io::Result { - sys::uds::listener::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::listener::local_addr(&self.inner).map(|addr| SocketAddr::new(addr)) } /// Returns the value of the `SO_ERROR` option. diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index a48a713fe..fe8a02ff3 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -10,4 +10,6 @@ pub use self::listener::UnixListener; mod stream; pub use self::stream::UnixStream; -pub use crate::sys::SocketAddr; +mod addr; +pub use self::addr::SocketAddr; +pub(crate) use self::addr::AddressKind; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index 2fd95a8c8..d541867c6 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -1,4 +1,5 @@ use crate::io_source::IoSource; +use crate::net::SocketAddr; use crate::{event, sys, Interest, Registry, Token}; #[cfg(windows)] @@ -67,13 +68,13 @@ impl UnixStream { } /// Returns the socket address of the local half of this connection. - pub fn local_addr(&self) -> io::Result { - sys::uds::stream::local_addr(&self.inner) + pub fn local_addr(&self) -> io::Result { + sys::uds::stream::local_addr(&self.inner).map(|addr| SocketAddr::new(addr)) } /// Returns the socket address of the remote half of this connection. - pub fn peer_addr(&self) -> io::Result { - sys::uds::stream::peer_addr(&self.inner) + pub fn peer_addr(&self) -> io::Result { + sys::uds::stream::peer_addr(&self.inner).map(|addr| SocketAddr::new(addr)) } /// Returns the value of the `SO_ERROR` option. diff --git a/src/sys/mod.rs b/src/sys/mod.rs index ac9365263..13b180c4c 100644 --- a/src/sys/mod.rs +++ b/src/sys/mod.rs @@ -81,7 +81,7 @@ cfg_not_os_poll! { #[cfg(unix)] cfg_net! { - pub use self::unix::SocketAddr; + pub(crate) use self::unix::SocketAddr; } #[cfg(windows)] @@ -91,6 +91,6 @@ cfg_not_os_poll! { #[cfg(windows)] cfg_net! { - pub use self::windows::SocketAddr; + pub(crate) use self::windows::SocketAddr; } } diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index 5914e8944..4ff01790a 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -1,6 +1,6 @@ #[cfg(unix)] pub(crate) mod datagram { - use crate::net::SocketAddr; + use crate::sys::SocketAddr; use std::io; use std::os::unix::net; use std::path::Path; @@ -34,7 +34,8 @@ pub(crate) mod datagram { } pub(crate) mod listener { - use crate::net::{SocketAddr, UnixStream}; + use crate::net::UnixStream; + use crate::sys::SocketAddr; #[cfg(windows)] use crate::sys::windows::stdnet as net; use std::io; @@ -56,7 +57,7 @@ pub(crate) mod listener { } pub(crate) mod stream { - use crate::net::SocketAddr; + use crate::sys::SocketAddr; #[cfg(windows)] use crate::sys::windows::stdnet as net; use std::io; diff --git a/src/sys/unix/mod.rs b/src/sys/unix/mod.rs index 231480a5d..b80bfa7d2 100644 --- a/src/sys/unix/mod.rs +++ b/src/sys/unix/mod.rs @@ -29,7 +29,7 @@ cfg_os_poll! { pub(crate) mod tcp; pub(crate) mod udp; pub(crate) mod uds; - pub use self::uds::SocketAddr; + pub(crate) use self::uds::SocketAddr; } cfg_io_source! { @@ -62,7 +62,7 @@ cfg_os_poll! { cfg_not_os_poll! { cfg_net! { mod uds; - pub use self::uds::SocketAddr; + pub(crate) use self::uds::SocketAddr; } cfg_any_os_ext! { diff --git a/src/sys/unix/uds/listener.rs b/src/sys/unix/uds/listener.rs index 79bd14ee0..46e9a83e3 100644 --- a/src/sys/unix/uds/listener.rs +++ b/src/sys/unix/uds/listener.rs @@ -1,5 +1,6 @@ use super::socket_addr; -use crate::net::{SocketAddr, UnixStream}; +use crate::net::UnixStream; +use super::SocketAddr; use crate::sys::unix::net::new_socket; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net; diff --git a/src/sys/unix/uds/mod.rs b/src/sys/unix/uds/mod.rs index 8e28a9573..d715e611e 100644 --- a/src/sys/unix/uds/mod.rs +++ b/src/sys/unix/uds/mod.rs @@ -1,5 +1,5 @@ mod socketaddr; -pub use self::socketaddr::SocketAddr; +pub(crate) use self::socketaddr::SocketAddr; /// Get the `sun_path` field offset of `sockaddr_un` for the target OS. /// diff --git a/src/sys/unix/uds/socketaddr.rs b/src/sys/unix/uds/socketaddr.rs index 4c7c41161..acdc8a662 100644 --- a/src/sys/unix/uds/socketaddr.rs +++ b/src/sys/unix/uds/socketaddr.rs @@ -1,32 +1,15 @@ use super::path_offset; +use crate::net::AddressKind; use std::ffi::OsStr; use std::os::unix::ffi::OsStrExt; -use std::path::Path; -use std::{ascii, fmt}; -/// An address associated with a `mio` specific Unix socket. -/// -/// This is implemented instead of imported from [`net::SocketAddr`] because -/// there is no way to create a [`net::SocketAddr`]. One must be returned by -/// [`accept`], so this is returned instead. -/// -/// [`net::SocketAddr`]: std::os::unix::net::SocketAddr -/// [`accept`]: #method.accept -pub struct SocketAddr { +pub(crate) struct SocketAddr { sockaddr: libc::sockaddr_un, socklen: libc::socklen_t, } -struct AsciiEscaped<'a>(&'a [u8]); - -enum AddressKind<'a> { - Unnamed, - Pathname(&'a Path), - Abstract(&'a [u8]), -} - impl SocketAddr { - fn address(&self) -> AddressKind<'_> { + pub(crate) fn address(&self) -> AddressKind<'_> { let offset = path_offset(&self.sockaddr); // Don't underflow in `len` below. if (self.socklen as usize) < offset { @@ -72,59 +55,5 @@ cfg_os_poll! { pub(crate) fn from_parts(sockaddr: libc::sockaddr_un, socklen: libc::socklen_t) -> SocketAddr { SocketAddr { sockaddr, socklen } } - - /// Returns `true` if the address is unnamed. - /// - /// Documentation reflected in [`SocketAddr`] - /// - /// [`SocketAddr`]: std::os::unix::net::SocketAddr - pub fn is_unnamed(&self) -> bool { - matches!(self.address(), AddressKind::Unnamed) - } - - /// Returns the contents of this address if it is a `pathname` address. - /// - /// Documentation reflected in [`SocketAddr`] - /// - /// [`SocketAddr`]: std::os::unix::net::SocketAddr - pub fn as_pathname(&self) -> Option<&Path> { - if let AddressKind::Pathname(path) = self.address() { - Some(path) - } else { - None - } - } - - /// Returns the contents of this address if it is an abstract namespace - /// without the leading null byte. - // Link to std::os::unix::net::SocketAddr pending - // https://github.com/rust-lang/rust/issues/85410. - pub fn as_abstract_namespace(&self) -> Option<&[u8]> { - if let AddressKind::Abstract(path) = self.address() { - Some(path) - } else { - None - } - } - } -} - -impl fmt::Debug for SocketAddr { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.address() { - AddressKind::Unnamed => write!(fmt, "(unnamed)"), - AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), - AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), - } - } -} - -impl<'a> fmt::Display for AsciiEscaped<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "\"")?; - for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { - write!(fmt, "{}", byte as char)?; - } - write!(fmt, "\"") } } diff --git a/src/sys/windows/mod.rs b/src/sys/windows/mod.rs index fb596b967..07f7dda6c 100644 --- a/src/sys/windows/mod.rs +++ b/src/sys/windows/mod.rs @@ -30,7 +30,7 @@ macro_rules! wsa_syscall { cfg_net! { pub(crate) mod stdnet; pub(crate) mod uds; - pub use self::uds::SocketAddr; + pub(crate) use self::uds::SocketAddr; } cfg_os_poll! { diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index 3c5d2b772..c864c057e 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -1,9 +1,7 @@ -use std::ascii; -use std::fmt; -use std::io; -use std::mem; +use std::{fmt, io, mem}; use std::os::raw::c_int; use std::path::Path; +use crate::net::AddressKind; use windows_sys::Win32::Networking::WinSock::{sockaddr_un, SOCKADDR}; @@ -68,28 +66,7 @@ cfg_os_poll! { } } -enum AddressKind<'a> { - Unnamed, - Pathname(&'a Path), - // Note: Windows does not support Abstract addresses - // https://github.com/microsoft/WSL/issues/4240#issuecomment-620805115/ - Abstract(&'a [u8]), -} - -struct AsciiEscaped<'a>(&'a [u8]); - -impl<'a> fmt::Display for AsciiEscaped<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "\"")?; - for byte in self.0.iter().cloned().flat_map(ascii::escape_default) { - write!(fmt, "{}", byte as char)?; - } - write!(fmt, "\"") - } -} - -/// An address associated with a Unix socket -pub struct SocketAddr { +pub(crate) struct SocketAddr { addr: sockaddr_un, len: c_int, } @@ -122,21 +99,7 @@ impl SocketAddr { SocketAddr::init(f).map(|(_, addr)| addr) } - /// Returns true if and only if the address is unnamed. - pub fn is_unnamed(&self) -> bool { - matches!(self.address(), AddressKind::Unnamed) - } - - /// Returns the contents of this address if it is a `pathname` address. - pub fn as_pathname(&self) -> Option<&Path> { - if let AddressKind::Pathname(path) = self.address() { - Some(path) - } else { - None - } - } - - fn address(&self) -> AddressKind<'_> { + pub(crate) fn address(&self) -> AddressKind<'_> { let len = self.len as usize - path_offset(&self.addr); // sockaddr_un::sun_path on Windows is a Win32 UTF-8 file system path @@ -156,10 +119,6 @@ impl SocketAddr { impl fmt::Debug for SocketAddr { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.address() { - AddressKind::Unnamed => write!(fmt, "(unnamed)"), - AddressKind::Abstract(name) => write!(fmt, "{} (abstract)", AsciiEscaped(name)), - AddressKind::Pathname(path) => write!(fmt, "{:?} (pathname)", path), - } + write!(fmt, "{:?}", self.address()) } } diff --git a/src/sys/windows/stdnet/mod.rs b/src/sys/windows/stdnet/mod.rs index 4764c7456..0eb5130d4 100644 --- a/src/sys/windows/stdnet/mod.rs +++ b/src/sys/windows/stdnet/mod.rs @@ -4,7 +4,7 @@ mod listener; mod socket; mod stream; -pub use self::addr::SocketAddr; +pub(crate) use self::addr::SocketAddr; pub(crate) use self::listener::UnixListener; pub(crate) use self::stream::UnixStream; diff --git a/src/sys/windows/uds/listener.rs b/src/sys/windows/uds/listener.rs index df16542c9..4ba4395e5 100644 --- a/src/sys/windows/uds/listener.rs +++ b/src/sys/windows/uds/listener.rs @@ -2,7 +2,8 @@ use std::io; use std::os::windows::io::AsRawSocket; use std::path::Path; -use crate::net::{SocketAddr, UnixStream}; +use super::SocketAddr; +use crate::net::UnixStream; use crate::sys::windows::stdnet as net; pub(crate) fn bind(path: &Path) -> io::Result { diff --git a/src/sys/windows/uds/mod.rs b/src/sys/windows/uds/mod.rs index 13569e104..b99c01e42 100644 --- a/src/sys/windows/uds/mod.rs +++ b/src/sys/windows/uds/mod.rs @@ -1,4 +1,4 @@ -pub use super::stdnet::SocketAddr; +pub(crate) use super::stdnet::SocketAddr; cfg_os_poll! { use std::convert::TryInto; diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index ef2a66bfb..c002a075a 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -1,4 +1,5 @@ use crate::sys::windows::stdnet as net; +use super::SocketAddr; use std::io; use std::os::windows::io::AsRawSocket; use std::path::Path; @@ -9,10 +10,10 @@ pub(crate) fn connect(path: &Path) -> io::Result { Ok(socket) } -pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { +pub(crate) fn local_addr(socket: &net::UnixStream) -> io::Result { super::local_addr(socket.as_raw_socket()) } -pub(crate) fn peer_addr(socket: &net::UnixStream) -> io::Result { +pub(crate) fn peer_addr(socket: &net::UnixStream) -> io::Result { super::peer_addr(socket.as_raw_socket()) } From 26a060b235a083336661f4fa2e79be055a05da4b Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Tue, 13 Sep 2022 14:52:15 -0700 Subject: [PATCH 33/34] fix lint --- src/net/uds/addr.rs | 4 ++-- src/net/uds/datagram.rs | 8 +++----- src/net/uds/listener.rs | 5 +++-- src/net/uds/mod.rs | 2 +- src/net/uds/stream.rs | 4 ++-- src/sys/shell/uds.rs | 4 ++-- src/sys/unix/uds/listener.rs | 2 +- src/sys/windows/stdnet/addr.rs | 4 ++-- src/sys/windows/uds/stream.rs | 2 +- 9 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/net/uds/addr.rs b/src/net/uds/addr.rs index 81742a004..9fb4c9c88 100644 --- a/src/net/uds/addr.rs +++ b/src/net/uds/addr.rs @@ -1,6 +1,6 @@ +use crate::sys; use std::path::Path; use std::{ascii, fmt}; -use crate::sys; /// An address associated with a `mio` specific Unix socket. /// @@ -11,7 +11,7 @@ use crate::sys; /// [`net::SocketAddr`]: std::os::unix::net::SocketAddr /// [`accept`]: #method.accept pub struct SocketAddr { - inner: sys::SocketAddr + inner: sys::SocketAddr, } struct AsciiEscaped<'a>(&'a [u8]); diff --git a/src/net/uds/datagram.rs b/src/net/uds/datagram.rs index 57114c28b..7bc1b7b1f 100644 --- a/src/net/uds/datagram.rs +++ b/src/net/uds/datagram.rs @@ -1,6 +1,6 @@ use crate::io_source::IoSource; -use crate::{event, sys, Interest, Registry, Token}; use crate::net::SocketAddr; +use crate::{event, sys, Interest, Registry, Token}; use std::net::Shutdown; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; @@ -56,16 +56,14 @@ impl UnixDatagram { /// Returns the address of this socket. pub fn local_addr(&self) -> io::Result { - sys::uds::datagram::local_addr(&self.inner) - .map(|addr| SocketAddr::new(addr)) + sys::uds::datagram::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the address of this socket's peer. /// /// The `connect` method will connect the socket to a peer. pub fn peer_addr(&self) -> io::Result { - sys::uds::datagram::peer_addr(&self.inner) - .map(|addr| SocketAddr::new(addr)) + sys::uds::datagram::peer_addr(&self.inner).map(SocketAddr::new) } /// Receives data from the socket. diff --git a/src/net/uds/listener.rs b/src/net/uds/listener.rs index 03b02821f..181806202 100644 --- a/src/net/uds/listener.rs +++ b/src/net/uds/listener.rs @@ -50,13 +50,14 @@ impl UnixListener { /// The call is responsible for ensuring that the listening socket is in /// non-blocking mode. pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { - self.inner.do_io(sys::uds::listener::accept) + self.inner + .do_io(sys::uds::listener::accept) .map(|(stream, addr)| (stream, SocketAddr::new(addr))) } /// Returns the local socket address of this listener. pub fn local_addr(&self) -> io::Result { - sys::uds::listener::local_addr(&self.inner).map(|addr| SocketAddr::new(addr)) + sys::uds::listener::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the value of the `SO_ERROR` option. diff --git a/src/net/uds/mod.rs b/src/net/uds/mod.rs index fe8a02ff3..2a12f965e 100644 --- a/src/net/uds/mod.rs +++ b/src/net/uds/mod.rs @@ -11,5 +11,5 @@ mod stream; pub use self::stream::UnixStream; mod addr; -pub use self::addr::SocketAddr; pub(crate) use self::addr::AddressKind; +pub use self::addr::SocketAddr; diff --git a/src/net/uds/stream.rs b/src/net/uds/stream.rs index d541867c6..0a04f035a 100644 --- a/src/net/uds/stream.rs +++ b/src/net/uds/stream.rs @@ -69,12 +69,12 @@ impl UnixStream { /// Returns the socket address of the local half of this connection. pub fn local_addr(&self) -> io::Result { - sys::uds::stream::local_addr(&self.inner).map(|addr| SocketAddr::new(addr)) + sys::uds::stream::local_addr(&self.inner).map(SocketAddr::new) } /// Returns the socket address of the remote half of this connection. pub fn peer_addr(&self) -> io::Result { - sys::uds::stream::peer_addr(&self.inner).map(|addr| SocketAddr::new(addr)) + sys::uds::stream::peer_addr(&self.inner).map(SocketAddr::new) } /// Returns the value of the `SO_ERROR` option. diff --git a/src/sys/shell/uds.rs b/src/sys/shell/uds.rs index 4ff01790a..3aac1bd7a 100644 --- a/src/sys/shell/uds.rs +++ b/src/sys/shell/uds.rs @@ -35,9 +35,9 @@ pub(crate) mod datagram { pub(crate) mod listener { use crate::net::UnixStream; - use crate::sys::SocketAddr; #[cfg(windows)] use crate::sys::windows::stdnet as net; + use crate::sys::SocketAddr; use std::io; #[cfg(unix)] use std::os::unix::net; @@ -57,9 +57,9 @@ pub(crate) mod listener { } pub(crate) mod stream { - use crate::sys::SocketAddr; #[cfg(windows)] use crate::sys::windows::stdnet as net; + use crate::sys::SocketAddr; use std::io; #[cfg(unix)] use std::os::unix::net; diff --git a/src/sys/unix/uds/listener.rs b/src/sys/unix/uds/listener.rs index 46e9a83e3..0b13ab817 100644 --- a/src/sys/unix/uds/listener.rs +++ b/src/sys/unix/uds/listener.rs @@ -1,6 +1,6 @@ use super::socket_addr; -use crate::net::UnixStream; use super::SocketAddr; +use crate::net::UnixStream; use crate::sys::unix::net::new_socket; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net; diff --git a/src/sys/windows/stdnet/addr.rs b/src/sys/windows/stdnet/addr.rs index c864c057e..26b1fddde 100644 --- a/src/sys/windows/stdnet/addr.rs +++ b/src/sys/windows/stdnet/addr.rs @@ -1,7 +1,7 @@ -use std::{fmt, io, mem}; +use crate::net::AddressKind; use std::os::raw::c_int; use std::path::Path; -use crate::net::AddressKind; +use std::{fmt, io, mem}; use windows_sys::Win32::Networking::WinSock::{sockaddr_un, SOCKADDR}; diff --git a/src/sys/windows/uds/stream.rs b/src/sys/windows/uds/stream.rs index c002a075a..b02f32e8f 100644 --- a/src/sys/windows/uds/stream.rs +++ b/src/sys/windows/uds/stream.rs @@ -1,5 +1,5 @@ -use crate::sys::windows::stdnet as net; use super::SocketAddr; +use crate::sys::windows::stdnet as net; use std::io; use std::os::windows::io::AsRawSocket; use std::path::Path; From fd8ddc181c7cd0b83d7ea1174a63a3b474005e9e Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Mon, 3 Oct 2022 17:27:22 -0700 Subject: [PATCH 34/34] add comment clarifying inheritance during calls to accept --- src/sys/windows/stdnet/socket.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sys/windows/stdnet/socket.rs b/src/sys/windows/stdnet/socket.rs index 55de184d9..9212c1e04 100644 --- a/src/sys/windows/stdnet/socket.rs +++ b/src/sys/windows/stdnet/socket.rs @@ -142,6 +142,9 @@ cfg_os_poll! { } pub fn accept(&self, storage: *mut SOCKADDR, len: *mut c_int) -> io::Result { + // WinSock's accept returns a socket with the same properties as the listener. it is + // called on. In particular, the WSA_FLAG_NO_HANDLE_INHERIT will be inherited from the + // listener. wsa_syscall!(accept(self.0, storage, len), INVALID_SOCKET).map(Socket) }