diff --git a/proxy/src/connection.rs b/proxy/src/connection.rs
index 01d461b23..9d6ebfbd3 100644
--- a/proxy/src/connection.rs
+++ b/proxy/src/connection.rs
@@ -3,15 +3,14 @@ use futures::*;
 use std;
 use std::io;
 use std::net::SocketAddr;
-use tokio_core;
-use tokio_core::net::{TcpListener, TcpStreamNew};
+use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream};
 use tokio_core::reactor::Handle;
 use tokio_io::{AsyncRead, AsyncWrite};
 
 use config::Addr;
 use transport::GetOriginalDst;
 
-pub type PlaintextSocket = tokio_core::net::TcpStream;
+pub type PlaintextSocket = TcpStream;
 
 pub struct BoundPort {
     inner: std::net::TcpListener,
@@ -165,10 +164,20 @@ impl io::Write for Connection {
 
 impl AsyncWrite for Connection {
     fn shutdown(&mut self) -> Poll<(), io::Error> {
+        use std::net::Shutdown;
         use self::Connection::*;
 
         match *self {
-            Plain(ref mut t) => t.shutdown(),
+            Plain(ref mut t) => {
+                try_ready!(AsyncWrite::shutdown(t));
+                // TCP shutdown the write side.
+                //
+                // If we're shutting down, then we definitely won't write
+                // anymore. So, we should tell the remote about this. This
+                // is relied upon in our TCP proxy, to start shutting down
+                // the pipe if one side closes.
+                TcpStream::shutdown(t, Shutdown::Write).map(Async::Ready)
+            },
         }
     }
 
diff --git a/proxy/src/transparency/tcp.rs b/proxy/src/transparency/tcp.rs
index 016c83cbd..3ab74b871 100644
--- a/proxy/src/transparency/tcp.rs
+++ b/proxy/src/transparency/tcp.rs
@@ -1,11 +1,12 @@
+use std::io;
 use std::sync::Arc;
 use std::time::Duration;
 
-use futures::{future, Future};
+use bytes::{Buf, BufMut};
+use futures::{future, Async, Future, Poll};
 use tokio_connect::Connect;
 use tokio_core::reactor::Handle;
 use tokio_io::{AsyncRead, AsyncWrite};
-use tokio_io::io::copy;
 
 use conduit_proxy_controller_grpc::common;
 use ctx::transport::{Client as ClientCtx, Server as ServerCtx};
@@ -71,14 +72,209 @@ impl Proxy {
         let fut = connect.connect()
             .map_err(|e| debug!("tcp connect error: {:?}", e))
             .and_then(move |tcp_out| {
-                let (in_r, in_w) = tcp_in.split();
-                let (out_r, out_w) = tcp_out.split();
-
-                copy(in_r, out_w)
-                    .join(copy(out_r, in_w))
-                    .map(|_| ())
+                Duplex::new(tcp_in, tcp_out)
                     .map_err(|e| debug!("tcp error: {}", e))
             });
         Box::new(fut)
     }
 }
+
+/// A future piping data bi-directionally to In and Out.
+struct Duplex<In, Out> {
+    half_in: HalfDuplex<In>,
+    half_out: HalfDuplex<Out>,
+}
+
+struct HalfDuplex<T> {
+    // None means socket met eof, and bytes have been drained into other half.
+    buf: Option<CopyBuf>,
+    is_shutdown: bool,
+    io: T,
+}
+
+/// A buffer used to copy bytes from one IO to another.
+///
+/// Keeps read and write positions.
+struct CopyBuf {
+    // TODO:
+    // In linkerd-tcp, a shared buffer is used to start, and an allocation is
+    // only made if NotReady is found trying to flush the buffer. We could
+    // consider making the same optimization here.
+    buf: Box<[u8]>,
+    read_pos: usize,
+    write_pos: usize,
+}
+
+impl<In, Out> Duplex<In, Out>
+where
+    In: AsyncRead + AsyncWrite,
+    Out: AsyncRead + AsyncWrite,
+{
+    fn new(in_io: In, out_io: Out) -> Self {
+        Duplex {
+            half_in: HalfDuplex::new(in_io),
+            half_out: HalfDuplex::new(out_io),
+        }
+    }
+}
+
+impl<In, Out> Future for Duplex<In, Out>
+where
+    In: AsyncRead + AsyncWrite,
+    Out: AsyncRead + AsyncWrite,
+{
+    type Item = ();
+    type Error = io::Error;
+
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        // This purposefully ignores the Async part, since we don't want to
+        // return early if the first half isn't ready, but the other half
+        // could make progress.
+        self.half_in.copy_into(&mut self.half_out)?;
+        self.half_out.copy_into(&mut self.half_in)?;
+
+        if self.half_in.is_done() && self.half_out.is_done() {
+            Ok(Async::Ready(()))
+        } else {
+            Ok(Async::NotReady)
+        }
+    }
+}
+
+impl<T> HalfDuplex<T>
+where
+    T: AsyncRead,
+{
+    fn new(io: T) -> Self {
+        Self {
+            buf: Some(CopyBuf::new()),
+            is_shutdown: false,
+            io,
+        }
+    }
+
+    /// Copies bytes from this half's IO into `dst` until this half reads
+    /// EOF, then shuts `dst` down and marks it as such.
+    ///
+    /// Returns `NotReady` while a read, write, or shutdown would block, and
+    /// `Ready` once `dst` has been shut down.
+    fn copy_into<U>(&mut self, dst: &mut HalfDuplex<U>) -> Poll<(), io::Error>
+    where
+        U: AsyncWrite,
+    {
+        // `Duplex::poll` ignores the `Async` part of our return value, so we
+        // can be polled again after finishing, while the other direction is
+        // still copying. Once `dst` is shut down there is nothing left to do;
+        // without this check the loop below would spin forever, because
+        // `read` and `write_into` are always ready when `buf` is `None`.
+        if dst.is_shutdown {
+            return Ok(Async::Ready(()));
+        }
+
+        loop {
+            try_ready!(self.read());
+            try_ready!(self.write_into(dst));
+
+            if self.buf.is_none() {
+                try_ready!(dst.io.shutdown());
+                dst.is_shutdown = true;
+
+                return Ok(Async::Ready(()));
+            }
+        }
+    }
+
+    /// Refills `buf` from this half's IO once all buffered bytes are drained.
+    ///
+    /// A zero-byte read is treated as EOF and drops `buf`.
+    fn read(&mut self) -> Poll<(), io::Error> {
+        let mut is_eof = false;
+        if let Some(ref mut buf) = self.buf {
+            if !buf.has_remaining() {
+                buf.reset();
+                let n = try_ready!(self.io.read_buf(buf));
+                is_eof = n == 0;
+            }
+        }
+
+        if is_eof {
+            self.buf.take();
+        }
+
+        Ok(Async::Ready(()))
+    }
+
+    /// Writes every buffered byte to `dst`, failing with `WriteZero` if `dst`
+    /// accepts no bytes.
+    fn write_into<U>(&mut self, dst: &mut HalfDuplex<U>) -> Poll<(), io::Error>
+    where
+        U: AsyncWrite,
+    {
+        if let Some(ref mut buf) = self.buf {
+            while buf.has_remaining() {
+                let n = try_ready!(dst.io.write_buf(buf));
+                if n == 0 {
+                    return Err(write_zero());
+                }
+            }
+        }
+
+        Ok(Async::Ready(()))
+    }
+
+    /// True once the opposite half's `copy_into` has shut this half's IO down.
+    fn is_done(&self) -> bool {
+        self.is_shutdown
+    }
+}
+
+/// Error for a destination that accepted zero bytes from a non-empty buffer.
+fn write_zero() -> io::Error {
+    io::Error::new(io::ErrorKind::WriteZero, "write zero bytes")
+}
+
+impl CopyBuf {
+    fn new() -> Self {
+        CopyBuf {
+            buf: Box::new([0; 4096]),
+            read_pos: 0,
+            write_pos: 0,
+        }
+    }
+
+    fn reset(&mut self) {
+        debug_assert_eq!(self.read_pos, self.write_pos);
+        self.read_pos = 0;
+        self.write_pos = 0;
+    }
+}
+
+impl Buf for CopyBuf {
+    fn remaining(&self) -> usize {
+        self.write_pos - self.read_pos
+    }
+
+    fn bytes(&self) -> &[u8] {
+        &self.buf[self.read_pos..self.write_pos]
+    }
+
+    fn advance(&mut self, cnt: usize) {
+        assert!(self.write_pos >= self.read_pos + cnt);
+        self.read_pos += cnt;
+    }
+}
+
+impl BufMut for CopyBuf {
+    fn remaining_mut(&self) -> usize {
+        self.buf.len() - self.write_pos
+    }
+
+    unsafe fn bytes_mut(&mut self) -> &mut [u8] {
+        &mut self.buf[self.write_pos..]
+    }
+
+    unsafe fn advance_mut(&mut self, cnt: usize) {
+        assert!(self.buf.len() >= self.write_pos + cnt);
+        self.write_pos += cnt;
+    }
+}
diff --git a/proxy/tests/support/tcp.rs b/proxy/tests/support/tcp.rs
index 87256b4b0..bb4388d72 100644
--- a/proxy/tests/support/tcp.rs
+++ b/proxy/tests/support/tcp.rs
@@ -26,8 +26,20 @@ pub struct TcpClient {
     tx: TcpSender,
 }
 
+type Handler = Box<CallBox + Send>;
+
+trait CallBox: 'static {
+    fn call_box(self: Box<Self>, sock: TcpStream) -> Box<Future<Item=(), Error=()>>;
+}
+
+impl<F: FnOnce(TcpStream) -> Box<Future<Item=(), Error=()>> + Send + 'static> CallBox for F {
+    fn call_box(self: Box<Self>, sock: TcpStream) -> Box<Future<Item=(), Error=()>> {
+        (*self)(sock)
+    }
+}
+
 pub struct TcpServer {
-    accepts: VecDeque<Box<Fn(Vec<u8>) -> Vec<u8> + Send>>,
+    accepts: VecDeque<Handler>,
 }
 
 pub struct TcpConn {
@@ -50,10 +62,29 @@ impl TcpClient {
 impl TcpServer {
     pub fn accept<F, U>(mut self, cb: F) -> Self
     where
-        F: Fn(Vec<u8>) -> U + Send + 'static,
+        F: FnOnce(Vec<u8>) -> U + Send + 'static,
         U: Into<Vec<u8>>,
     {
-        self.accepts.push_back(Box::new(move |v| cb(v).into()));
+        self.accept_fut(move |sock| {
+            tokio_io::io::read(sock, vec![0; 1024])
+                .and_then(move |(sock, mut vec, n)| {
+                    vec.truncate(n);
+                    let write = cb(vec).into();
+                    tokio_io::io::write_all(sock, write)
+                })
+                .map(|_| ())
+                .map_err(|e| panic!("tcp server error: {}", e))
+        })
+    }
+
+    pub fn accept_fut<F, U>(mut self, cb: F) -> Self
+    where
+        F: FnOnce(TcpStream) -> U + Send + 'static,
+        U: IntoFuture<Item=(), Error=()> + 'static,
+    {
+        self.accepts.push_back(Box::new(move |tcp| -> Box<Future<Item=(), Error=()>> {
+            Box::new(cb(tcp).into_future())
+        }));
         self
     }
 
@@ -166,15 +197,7 @@ fn run_server(tcp: TcpServer) -> server::Listening {
 
         let work = bind.incoming().for_each(move |(sock, _)| {
             let cb = accepts.pop_front().expect("no more accepts");
-            let fut = tokio_io::io::read(sock, vec![0; 1024])
-                .and_then(move |(sock, mut vec, n)| {
-                    vec.truncate(n);
-                    let write = cb(vec);
-                    tokio_io::io::write_all(sock, write)
-                })
-                .map(|_| ())
-                .map_err(|e| panic!("tcp server error: {}", e));
-
+            let fut = cb.call_box(sock);
             reactor.spawn(fut);
             Ok(())
         });
diff --git a/proxy/tests/transparency.rs b/proxy/tests/transparency.rs
index 17e287142..373a5d335 100644
--- a/proxy/tests/transparency.rs
+++ b/proxy/tests/transparency.rs
@@ -249,6 +249,55 @@ fn tcp_with_no_orig_dst() {
     assert_eq!(read, b"");
 }
 
+#[test]
+fn tcp_connections_close_if_client_closes() {
+    use std::sync::mpsc;
+
+    let _ = env_logger::try_init();
+
+    let msg1 = "custom tcp hello";
+    let msg2 = "custom tcp bye";
+
+    let (tx, rx) = mpsc::channel();
+
+    let srv = server::tcp()
+        .accept_fut(move |sock| {
+            tokio_io::io::read(sock, vec![0; 1024])
+                .and_then(move |(sock, vec, n)| {
+                    assert_eq!(&vec[..n], msg1.as_bytes());
+
+                    tokio_io::io::write_all(sock, msg2.as_bytes())
+                }).and_then(|(sock, _)| {
+                    // let's read again, but we should get eof
+                    tokio_io::io::read(sock, [0; 16])
+                })
+                .map(move |(_sock, _vec, n)| {
+                    assert_eq!(n, 0);
+                    tx.send(()).unwrap();
+                })
+                .map_err(|e| panic!("tcp server error: {}", e))
+        })
+        .run();
+    let ctrl = controller::new().run();
+    let proxy = proxy::new()
+        .controller(ctrl)
+        .inbound(srv)
+        .run();
+
+    let client = client::tcp(proxy.inbound);
+
+    let tcp_client = client.connect();
+    tcp_client.write(msg1);
+    assert_eq!(tcp_client.read(), msg2.as_bytes());
+
+    drop(tcp_client);
+
+    // rx will be fulfilled when our tcp accept_fut sees
+    // a socket disconnect, which is what we are testing for.
+    // the timeout here is just to prevent this test from hanging
+    rx.recv_timeout(Duration::from_secs(5)).unwrap();
+}
+
 #[test]
 fn http11_upgrade_not_supported() {
     let _ = env_logger::try_init();