diff --git a/src/libstd/net/tcp.rs b/src/libstd/net/tcp.rs index 7be1fc9cd8c..fdeca8bc5ca 100644 --- a/src/libstd/net/tcp.rs +++ b/src/libstd/net/tcp.rs @@ -134,6 +134,24 @@ impl TcpStream { super::each_addr(addr, net_imp::TcpStream::connect).map(TcpStream) } + /// Opens a TCP connection to a remote host with a timeout. + /// + /// Unlike `connect`, `connect_timeout` takes a single [`SocketAddr`] since + /// timeout must be applied to individual addresses. + /// + /// It is an error to pass a zero `Duration` to this function. + /// + /// Unlike other methods on `TcpStream`, this does not correspond to a + /// single system call. It instead calls `connect` in nonblocking mode and + /// then uses an OS-specific mechanism to await the completion of the + /// connection request. + /// + /// [`SocketAddr`]: ../../std/net/enum.SocketAddr.html + #[unstable(feature = "tcpstream_connect_timeout", issue = "43709")] + pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result { + net_imp::TcpStream::connect_timeout(addr, timeout).map(TcpStream) + } + /// Returns the socket address of the remote peer of this TCP connection. /// /// # Examples @@ -1509,4 +1527,19 @@ mod tests { t!(txdone.send(())); }) } + + #[test] + fn connect_timeout_unroutable() { + // this IP is unroutable, so connections should always time out. + let addr = "10.255.255.1:80".parse().unwrap(); + let e = TcpStream::connect_timeout(&addr, Duration::from_millis(250)).unwrap_err(); + assert_eq!(e.kind(), io::ErrorKind::TimedOut); + } + + #[test] + fn connect_timeout_valid() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + TcpStream::connect_timeout(&addr, Duration::from_secs(2)).unwrap(); + } } diff --git a/src/libstd/sys/redox/net/tcp.rs b/src/libstd/sys/redox/net/tcp.rs index 17673f0bd60..5d1067e4039 100644 --- a/src/libstd/sys/redox/net/tcp.rs +++ b/src/libstd/sys/redox/net/tcp.rs @@ -32,6 +32,10 @@ impl TcpStream { Ok(TcpStream(File::open(&Path::new(path.as_str()), &options)?)) } + pub fn connect_timeout(_addr: &SocketAddr, _timeout: Duration) -> Result<()> { + Err(Error::new(ErrorKind::Other, "TcpStream::connect_timeout not implemented")) + } + pub fn duplicate(&self) -> Result { Ok(TcpStream(self.0.dup(&[])?)) } diff --git a/src/libstd/sys/unix/net.rs b/src/libstd/sys/unix/net.rs index 8fb361a78e2..668b2f92aba 100644 --- a/src/libstd/sys/unix/net.rs +++ b/src/libstd/sys/unix/net.rs @@ -17,7 +17,8 @@ use str; use sys::fd::FileDesc; use sys_common::{AsInner, FromInner, IntoInner}; use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr}; -use time::Duration; +use time::{Duration, Instant}; +use cmp; pub use sys::{cvt, cvt_r}; pub extern crate libc as netc; @@ -122,6 +123,70 @@ impl Socket { } } + pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { + self.set_nonblocking(true)?; + let r = unsafe { + let (addrp, len) = addr.into_inner(); + cvt(libc::connect(self.0.raw(), addrp, len)) + }; + self.set_nonblocking(false)?; + + match r { + Ok(_) => return Ok(()), + // there's no ErrorKind for EINPROGRESS :( + Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {} + Err(e) => return Err(e), + } + + let mut pollfd = libc::pollfd { + fd: self.0.raw(), + events: libc::POLLOUT, + revents: 0, + }; + + if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + "cannot set a 0 duration timeout")); + } + + let start = Instant::now(); + + loop { + let elapsed = start.elapsed(); + if elapsed >= timeout { + return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")); + } + + let timeout = timeout - elapsed; + let mut timeout = timeout.as_secs() + .saturating_mul(1_000) + .saturating_add(timeout.subsec_nanos() as u64 / 1_000_000); + if timeout == 0 { + timeout = 1; + } + + let timeout = cmp::min(timeout, c_int::max_value() as u64) as c_int; + + match unsafe { libc::poll(&mut pollfd, 1, timeout) } { + -1 => { + let err = io::Error::last_os_error(); + if err.kind() != io::ErrorKind::Interrupted { + return Err(err); + } + } + 0 => {} + _ => { + if pollfd.revents & libc::POLLOUT == 0 { + if let Some(e) = self.take_error()? { + return Err(e); + } + } + return Ok(()); + } + } + } + } + pub fn accept(&self, storage: *mut sockaddr, len: *mut socklen_t) -> io::Result { // Unfortunately the only known way right now to accept a socket and diff --git a/src/libstd/sys/windows/c.rs b/src/libstd/sys/windows/c.rs index 1646f8cce72..4785cefd6b4 100644 --- a/src/libstd/sys/windows/c.rs +++ b/src/libstd/sys/windows/c.rs @@ -298,6 +298,8 @@ pub const PIPE_TYPE_BYTE: DWORD = 0x00000000; pub const PIPE_REJECT_REMOTE_CLIENTS: DWORD = 0x00000008; pub const PIPE_READMODE_BYTE: DWORD = 0x00000000; +pub const FD_SETSIZE: usize = 64; + #[repr(C)] #[cfg(target_arch = "x86")] pub struct WSADATA { @@ -837,6 +839,26 @@ pub struct CONSOLE_READCONSOLE_CONTROL { } pub type PCONSOLE_READCONSOLE_CONTROL = *mut CONSOLE_READCONSOLE_CONTROL; +#[repr(C)] +#[derive(Copy)] +pub struct fd_set { + pub fd_count: c_uint, + pub fd_array: [SOCKET; FD_SETSIZE], +} + +impl Clone for fd_set { + fn clone(&self) -> fd_set { + *self + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct timeval { + pub tv_sec: c_long, + pub tv_usec: c_long, +} + extern "system" { pub fn WSAStartup(wVersionRequested: WORD, lpWSAData: LPWSADATA) -> c_int; @@ -1125,6 +1147,11 @@ extern "system" { lpOverlapped: LPOVERLAPPED, lpNumberOfBytesTransferred: LPDWORD, bWait: BOOL) -> BOOL; + pub fn select(nfds: c_int, + readfds: *mut fd_set, + writefds: *mut fd_set, + exceptfds: *mut fd_set, + timeout: *const timeval) -> c_int; } // Functions that aren't available on Windows XP, but we still use them and just diff --git a/src/libstd/sys/windows/net.rs b/src/libstd/sys/windows/net.rs index f2a2793425d..cd8acff6b0c 100644 --- a/src/libstd/sys/windows/net.rs +++ b/src/libstd/sys/windows/net.rs @@ -12,7 +12,7 @@ use cmp; use io::{self, Read}; -use libc::{c_int, c_void, c_ulong}; +use libc::{c_int, c_void, c_ulong, c_long}; use mem; use net::{SocketAddr, Shutdown}; use ptr; @@ -115,6 +115,60 @@ impl Socket { Ok(socket) } + pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { + self.set_nonblocking(true)?; + let r = unsafe { + let (addrp, len) = addr.into_inner(); + cvt(c::connect(self.0, addrp, len)) + }; + self.set_nonblocking(false)?; + + match r { + Ok(_) => return Ok(()), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Err(e), + } + + if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 { + return Err(io::Error::new(io::ErrorKind::InvalidInput, + "cannot set a 0 duration timeout")); + } + + let mut timeout = c::timeval { + tv_sec: timeout.as_secs() as c_long, + tv_usec: (timeout.subsec_nanos() / 1000) as c_long, + }; + if timeout.tv_sec == 0 && timeout.tv_usec == 0 { + timeout.tv_usec = 1; + } + + let fds = unsafe { + let mut fds = mem::zeroed::(); + fds.fd_count = 1; + fds.fd_array[0] = self.0; + fds + }; + + let mut writefds = fds; + let mut errorfds = fds; + + let n = unsafe { + cvt(c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout))? + }; + + match n { + 0 => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")), + _ => { + if writefds.fd_count != 1 { + if let Some(e) = self.take_error()? { + return Err(e); + } + } + Ok(()) + } + } + } + pub fn accept(&self, storage: *mut c::SOCKADDR, len: *mut c_int) -> io::Result { let socket = unsafe { diff --git a/src/libstd/sys_common/net.rs b/src/libstd/sys_common/net.rs index 809b728379d..5775dd4f1fc 100644 --- a/src/libstd/sys_common/net.rs +++ b/src/libstd/sys_common/net.rs @@ -215,6 +215,14 @@ impl TcpStream { Ok(TcpStream { inner: sock }) } + pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result { + init(); + + let sock = Socket::new(addr, c::SOCK_STREAM)?; + sock.connect_timeout(addr, timeout)?; + Ok(TcpStream { inner: sock }) + } + pub fn socket(&self) -> &Socket { &self.inner } pub fn into_socket(self) -> Socket { self.inner }