Implement TcpStream::connect_timeout
This breaks the "single syscall rule", but it's really annoying to hand write and is pretty foundational.
This commit is contained in:
parent
ccf401f8f7
commit
8c92da3c51
6 changed files with 193 additions and 2 deletions
|
@ -134,6 +134,24 @@ impl TcpStream {
|
|||
super::each_addr(addr, net_imp::TcpStream::connect).map(TcpStream)
|
||||
}
|
||||
|
||||
/// Opens a TCP connection to a remote host with a timeout.
|
||||
///
|
||||
/// Unlike `connect`, `connect_timeout` takes a single [`SocketAddr`] since
|
||||
/// timeout must be applied to individual addresses.
|
||||
///
|
||||
/// It is an error to pass a zero `Duration` to this function.
|
||||
///
|
||||
/// Unlike other methods on `TcpStream`, this does not correspond to a
|
||||
/// single system call. It instead calls `connect` in nonblocking mode and
|
||||
/// then uses an OS-specific mechanism to await the completion of the
|
||||
/// connection request.
|
||||
///
|
||||
/// [`SocketAddr`]: ../../std/net/enum.SocketAddr.html
|
||||
#[unstable(feature = "tcpstream_connect_timeout", issue = "43709")]
|
||||
pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
|
||||
net_imp::TcpStream::connect_timeout(addr, timeout).map(TcpStream)
|
||||
}
|
||||
|
||||
/// Returns the socket address of the remote peer of this TCP connection.
|
||||
///
|
||||
/// # Examples
|
||||
|
@ -1509,4 +1527,19 @@ mod tests {
|
|||
t!(txdone.send(()));
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_timeout_unroutable() {
|
||||
// this IP is unroutable, so connections should always time out.
|
||||
let addr = "10.255.255.1:80".parse().unwrap();
|
||||
let e = TcpStream::connect_timeout(&addr, Duration::from_millis(250)).unwrap_err();
|
||||
assert_eq!(e.kind(), io::ErrorKind::TimedOut);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_timeout_valid() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
TcpStream::connect_timeout(&addr, Duration::from_secs(2)).unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,6 +32,10 @@ impl TcpStream {
|
|||
Ok(TcpStream(File::open(&Path::new(path.as_str()), &options)?))
|
||||
}
|
||||
|
||||
pub fn connect_timeout(_addr: &SocketAddr, _timeout: Duration) -> Result<()> {
|
||||
Err(Error::new(ErrorKind::Other, "TcpStream::connect_timeout not implemented"))
|
||||
}
|
||||
|
||||
pub fn duplicate(&self) -> Result<TcpStream> {
|
||||
Ok(TcpStream(self.0.dup(&[])?))
|
||||
}
|
||||
|
|
|
@ -17,7 +17,8 @@ use str;
|
|||
use sys::fd::FileDesc;
|
||||
use sys_common::{AsInner, FromInner, IntoInner};
|
||||
use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
|
||||
use time::Duration;
|
||||
use time::{Duration, Instant};
|
||||
use cmp;
|
||||
|
||||
pub use sys::{cvt, cvt_r};
|
||||
pub extern crate libc as netc;
|
||||
|
@ -122,6 +123,70 @@ impl Socket {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
|
||||
self.set_nonblocking(true)?;
|
||||
let r = unsafe {
|
||||
let (addrp, len) = addr.into_inner();
|
||||
cvt(libc::connect(self.0.raw(), addrp, len))
|
||||
};
|
||||
self.set_nonblocking(false)?;
|
||||
|
||||
match r {
|
||||
Ok(_) => return Ok(()),
|
||||
// there's no ErrorKind for EINPROGRESS :(
|
||||
Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
let mut pollfd = libc::pollfd {
|
||||
fd: self.0.raw(),
|
||||
events: libc::POLLOUT,
|
||||
revents: 0,
|
||||
};
|
||||
|
||||
if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidInput,
|
||||
"cannot set a 0 duration timeout"));
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
loop {
|
||||
let elapsed = start.elapsed();
|
||||
if elapsed >= timeout {
|
||||
return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out"));
|
||||
}
|
||||
|
||||
let timeout = timeout - elapsed;
|
||||
let mut timeout = timeout.as_secs()
|
||||
.saturating_mul(1_000)
|
||||
.saturating_add(timeout.subsec_nanos() as u64 / 1_000_000);
|
||||
if timeout == 0 {
|
||||
timeout = 1;
|
||||
}
|
||||
|
||||
let timeout = cmp::min(timeout, c_int::max_value() as u64) as c_int;
|
||||
|
||||
match unsafe { libc::poll(&mut pollfd, 1, timeout) } {
|
||||
-1 => {
|
||||
let err = io::Error::last_os_error();
|
||||
if err.kind() != io::ErrorKind::Interrupted {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
0 => {}
|
||||
_ => {
|
||||
if pollfd.revents & libc::POLLOUT == 0 {
|
||||
if let Some(e) = self.take_error()? {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn accept(&self, storage: *mut sockaddr, len: *mut socklen_t)
|
||||
-> io::Result<Socket> {
|
||||
// Unfortunately the only known way right now to accept a socket and
|
||||
|
|
|
@ -298,6 +298,8 @@ pub const PIPE_TYPE_BYTE: DWORD = 0x00000000;
|
|||
pub const PIPE_REJECT_REMOTE_CLIENTS: DWORD = 0x00000008;
|
||||
pub const PIPE_READMODE_BYTE: DWORD = 0x00000000;
|
||||
|
||||
pub const FD_SETSIZE: usize = 64;
|
||||
|
||||
#[repr(C)]
|
||||
#[cfg(target_arch = "x86")]
|
||||
pub struct WSADATA {
|
||||
|
@ -837,6 +839,26 @@ pub struct CONSOLE_READCONSOLE_CONTROL {
|
|||
}
|
||||
pub type PCONSOLE_READCONSOLE_CONTROL = *mut CONSOLE_READCONSOLE_CONTROL;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Copy)]
|
||||
pub struct fd_set {
|
||||
pub fd_count: c_uint,
|
||||
pub fd_array: [SOCKET; FD_SETSIZE],
|
||||
}
|
||||
|
||||
impl Clone for fd_set {
|
||||
fn clone(&self) -> fd_set {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct timeval {
|
||||
pub tv_sec: c_long,
|
||||
pub tv_usec: c_long,
|
||||
}
|
||||
|
||||
extern "system" {
|
||||
pub fn WSAStartup(wVersionRequested: WORD,
|
||||
lpWSAData: LPWSADATA) -> c_int;
|
||||
|
@ -1125,6 +1147,11 @@ extern "system" {
|
|||
lpOverlapped: LPOVERLAPPED,
|
||||
lpNumberOfBytesTransferred: LPDWORD,
|
||||
bWait: BOOL) -> BOOL;
|
||||
pub fn select(nfds: c_int,
|
||||
readfds: *mut fd_set,
|
||||
writefds: *mut fd_set,
|
||||
exceptfds: *mut fd_set,
|
||||
timeout: *const timeval) -> c_int;
|
||||
}
|
||||
|
||||
// Functions that aren't available on Windows XP, but we still use them and just
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
use cmp;
|
||||
use io::{self, Read};
|
||||
use libc::{c_int, c_void, c_ulong};
|
||||
use libc::{c_int, c_void, c_ulong, c_long};
|
||||
use mem;
|
||||
use net::{SocketAddr, Shutdown};
|
||||
use ptr;
|
||||
|
@ -115,6 +115,60 @@ impl Socket {
|
|||
Ok(socket)
|
||||
}
|
||||
|
||||
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
|
||||
self.set_nonblocking(true)?;
|
||||
let r = unsafe {
|
||||
let (addrp, len) = addr.into_inner();
|
||||
cvt(c::connect(self.0, addrp, len))
|
||||
};
|
||||
self.set_nonblocking(false)?;
|
||||
|
||||
match r {
|
||||
Ok(_) => return Ok(()),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidInput,
|
||||
"cannot set a 0 duration timeout"));
|
||||
}
|
||||
|
||||
let mut timeout = c::timeval {
|
||||
tv_sec: timeout.as_secs() as c_long,
|
||||
tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
|
||||
};
|
||||
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
|
||||
timeout.tv_usec = 1;
|
||||
}
|
||||
|
||||
let fds = unsafe {
|
||||
let mut fds = mem::zeroed::<c::fd_set>();
|
||||
fds.fd_count = 1;
|
||||
fds.fd_array[0] = self.0;
|
||||
fds
|
||||
};
|
||||
|
||||
let mut writefds = fds;
|
||||
let mut errorfds = fds;
|
||||
|
||||
let n = unsafe {
|
||||
cvt(c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout))?
|
||||
};
|
||||
|
||||
match n {
|
||||
0 => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
|
||||
_ => {
|
||||
if writefds.fd_count != 1 {
|
||||
if let Some(e) = self.take_error()? {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn accept(&self, storage: *mut c::SOCKADDR,
|
||||
len: *mut c_int) -> io::Result<Socket> {
|
||||
let socket = unsafe {
|
||||
|
|
|
@ -215,6 +215,14 @@ impl TcpStream {
|
|||
Ok(TcpStream { inner: sock })
|
||||
}
|
||||
|
||||
pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
|
||||
init();
|
||||
|
||||
let sock = Socket::new(addr, c::SOCK_STREAM)?;
|
||||
sock.connect_timeout(addr, timeout)?;
|
||||
Ok(TcpStream { inner: sock })
|
||||
}
|
||||
|
||||
pub fn socket(&self) -> &Socket { &self.inner }
|
||||
|
||||
pub fn into_socket(self) -> Socket { self.inner }
|
||||
|
|
Loading…
Reference in a new issue