proxy: detect TCP socket hang-ups from client or server (#463)

We previously `join`ed on piping data from both sides, meaning
that the future didn't complete until **both** sides had disconnected.
Even if the client disconnected, it was possible the server never knew,
and we "leaked" this future.

To fix this, the `join` is replaced with a `Duplex` future, which pipes
data from each side into the other while also detecting when either side
shuts down. When a side does shut down, a write shutdown is forwarded to
the other side, allowing draining to occur for deployments that half-close
sockets.
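
The key step, schematically (a simplified sketch of `HalfDuplex::copy_into` as added below):

    // Once this side hits EOF (its buffer is drained and set to None),
    // forward a write shutdown to the peer so it can finish draining;
    // on a plain TcpStream this ends up calling shutdown(Shutdown::Write).
    if self.buf.is_none() && !dst.is_shutdown {
        try_ready!(dst.io.shutdown());
        dst.is_shutdown = true;
    }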

Closes #434
Sean McArthur 2018-03-02 10:14:54 -08:00 committed by GitHub
parent 95057b067a
commit 0effefa5d7
4 changed files with 280 additions and 24 deletions

View File

@@ -3,15 +3,14 @@ use futures::*;
 use std;
 use std::io;
 use std::net::SocketAddr;

 use tokio_core;
-use tokio_core::net::{TcpListener, TcpStreamNew};
+use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream};
 use tokio_core::reactor::Handle;
 use tokio_io::{AsyncRead, AsyncWrite};

 use config::Addr;
 use transport::GetOriginalDst;

-pub type PlaintextSocket = tokio_core::net::TcpStream;
+pub type PlaintextSocket = TcpStream;

 pub struct BoundPort {
     inner: std::net::TcpListener,
@@ -165,10 +164,20 @@ impl io::Write for Connection {
 impl AsyncWrite for Connection {
     fn shutdown(&mut self) -> Poll<(), io::Error> {
         use std::net::Shutdown;
         use self::Connection::*;

         match *self {
-            Plain(ref mut t) => t.shutdown(),
+            Plain(ref mut t) => {
+                try_ready!(AsyncWrite::shutdown(t));
+                // TCP shutdown the write side.
+                //
+                // If we're shutting down, then we definitely won't write
+                // anymore. So, we should tell the remote about this. This
+                // is relied upon in our TCP proxy, to start shutting down
+                // the pipe if one side closes.
+                TcpStream::shutdown(t, Shutdown::Write).map(Async::Ready)
+            },
         }
     }

View File

@@ -1,11 +1,12 @@
 use std::io;
 use std::sync::Arc;
 use std::time::Duration;

-use futures::{future, Future};
+use bytes::{Buf, BufMut};
+use futures::{future, Async, Future, Poll};
 use tokio_connect::Connect;
 use tokio_core::reactor::Handle;
 use tokio_io::{AsyncRead, AsyncWrite};
-use tokio_io::io::copy;

 use conduit_proxy_controller_grpc::common;
 use ctx::transport::{Client as ClientCtx, Server as ServerCtx};
@@ -71,14 +72,188 @@ impl Proxy {
         let fut = connect.connect()
             .map_err(|e| debug!("tcp connect error: {:?}", e))
             .and_then(move |tcp_out| {
-                let (in_r, in_w) = tcp_in.split();
-                let (out_r, out_w) = tcp_out.split();
-                copy(in_r, out_w)
-                    .join(copy(out_r, in_w))
-                    .map(|_| ())
+                Duplex::new(tcp_in, tcp_out)
                     .map_err(|e| debug!("tcp error: {}", e))
             });

         Box::new(fut)
     }
 }

+/// A future piping data bi-directionally to In and Out.
+struct Duplex<In, Out> {
+    half_in: HalfDuplex<In>,
+    half_out: HalfDuplex<Out>,
+}
+
+struct HalfDuplex<T> {
+    // None means socket met eof, and bytes have been drained into other half.
+    buf: Option<CopyBuf>,
+    is_shutdown: bool,
+    io: T,
+}
+
+/// A buffer used to copy bytes from one IO to another.
+///
+/// Keeps read and write positions.
+struct CopyBuf {
+    // TODO:
+    // In linkerd-tcp, a shared buffer is used to start, and an allocation is
+    // only made if NotReady is found trying to flush the buffer. We could
+    // consider making the same optimization here.
+    buf: Box<[u8]>,
+    read_pos: usize,
+    write_pos: usize,
+}
+
+impl<In, Out> Duplex<In, Out>
+where
+    In: AsyncRead + AsyncWrite,
+    Out: AsyncRead + AsyncWrite,
+{
+    fn new(in_io: In, out_io: Out) -> Self {
+        Duplex {
+            half_in: HalfDuplex::new(in_io),
+            half_out: HalfDuplex::new(out_io),
+        }
+    }
+}
+
+impl<In, Out> Future for Duplex<In, Out>
+where
+    In: AsyncRead + AsyncWrite,
+    Out: AsyncRead + AsyncWrite,
+{
+    type Item = ();
+    type Error = io::Error;
+
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        // This purposefully ignores the Async part, since we don't want to
+        // return early if the first half isn't ready, but the other half
+        // could make progress.
+        self.half_in.copy_into(&mut self.half_out)?;
+        self.half_out.copy_into(&mut self.half_in)?;
+        if self.half_in.is_done() && self.half_out.is_done() {
+            Ok(Async::Ready(()))
+        } else {
+            Ok(Async::NotReady)
+        }
+    }
+}
+
+impl<T> HalfDuplex<T>
+where
+    T: AsyncRead,
+{
+    fn new(io: T) -> Self {
+        Self {
+            buf: Some(CopyBuf::new()),
+            is_shutdown: false,
+            io,
+        }
+    }
+
+    fn copy_into<U>(&mut self, dst: &mut HalfDuplex<U>) -> Poll<(), io::Error>
+    where
+        U: AsyncWrite,
+    {
+        loop {
+            try_ready!(self.read());
+            try_ready!(self.write_into(dst));
+            if self.buf.is_none() && !dst.is_shutdown {
+                try_ready!(dst.io.shutdown());
+                dst.is_shutdown = true;
+
+                return Ok(Async::Ready(()));
+            }
+        }
+    }
+
+    fn read(&mut self) -> Poll<(), io::Error> {
+        let mut is_eof = false;
+        if let Some(ref mut buf) = self.buf {
+            if !buf.has_remaining() {
+                buf.reset();
+                let n = try_ready!(self.io.read_buf(buf));
+                is_eof = n == 0;
+            }
+        }
+
+        if is_eof {
+            self.buf.take();
+        }
+
+        Ok(Async::Ready(()))
+    }
+
+    fn write_into<U>(&mut self, dst: &mut HalfDuplex<U>) -> Poll<(), io::Error>
+    where
+        U: AsyncWrite,
+    {
+        if let Some(ref mut buf) = self.buf {
+            while buf.has_remaining() {
+                let n = try_ready!(dst.io.write_buf(buf));
+                if n == 0 {
+                    return Err(write_zero());
+                }
+            }
+        }
+
+        Ok(Async::Ready(()))
+    }
+
+    fn is_done(&self) -> bool {
+        self.is_shutdown
+    }
+}
+
+fn write_zero() -> io::Error {
+    io::Error::new(io::ErrorKind::WriteZero, "write zero bytes")
+}
+
+impl CopyBuf {
+    fn new() -> Self {
+        CopyBuf {
+            buf: Box::new([0; 4096]),
+            read_pos: 0,
+            write_pos: 0,
+        }
+    }
+
+    fn reset(&mut self) {
+        debug_assert_eq!(self.read_pos, self.write_pos);
+        self.read_pos = 0;
+        self.write_pos = 0;
+    }
+}
+
+impl Buf for CopyBuf {
+    fn remaining(&self) -> usize {
+        self.write_pos - self.read_pos
+    }
+
+    fn bytes(&self) -> &[u8] {
+        &self.buf[self.read_pos..self.write_pos]
+    }
+
+    fn advance(&mut self, cnt: usize) {
+        assert!(self.write_pos >= self.read_pos + cnt);
+        self.read_pos += cnt;
+    }
+}
+
+impl BufMut for CopyBuf {
+    fn remaining_mut(&self) -> usize {
+        self.buf.len() - self.write_pos
+    }
+
+    unsafe fn bytes_mut(&mut self) -> &mut [u8] {
+        &mut self.buf[self.write_pos..]
+    }
+
+    unsafe fn advance_mut(&mut self, cnt: usize) {
+        assert!(self.buf.len() >= self.write_pos + cnt);
+        self.write_pos += cnt;
+    }
+}

View File

@@ -26,8 +26,20 @@ pub struct TcpClient {
     tx: TcpSender,
 }

+type Handler = Box<CallBox + Send>;
+
+trait CallBox: 'static {
+    fn call_box(self: Box<Self>, sock: TcpStream) -> Box<Future<Item=(), Error=()>>;
+}
+
+impl<F: FnOnce(TcpStream) -> Box<Future<Item=(), Error=()>> + Send + 'static> CallBox for F {
+    fn call_box(self: Box<Self>, sock: TcpStream) -> Box<Future<Item=(), Error=()>> {
+        (*self)(sock)
+    }
+}
+
 pub struct TcpServer {
-    accepts: VecDeque<Box<Fn(Vec<u8>) -> Vec<u8> + Send>>,
+    accepts: VecDeque<Handler>,
 }

 pub struct TcpConn {
@@ -50,10 +62,29 @@ impl TcpClient {
 impl TcpServer {
     pub fn accept<F, U>(mut self, cb: F) -> Self
     where
-        F: Fn(Vec<u8>) -> U + Send + 'static,
+        F: FnOnce(Vec<u8>) -> U + Send + 'static,
         U: Into<Vec<u8>>,
     {
-        self.accepts.push_back(Box::new(move |v| cb(v).into()));
+        self.accept_fut(move |sock| {
+            tokio_io::io::read(sock, vec![0; 1024])
+                .and_then(move |(sock, mut vec, n)| {
+                    vec.truncate(n);
+                    let write = cb(vec).into();
+                    tokio_io::io::write_all(sock, write)
+                })
+                .map(|_| ())
+                .map_err(|e| panic!("tcp server error: {}", e))
+        })
+    }
+
+    pub fn accept_fut<F, U>(mut self, cb: F) -> Self
+    where
+        F: FnOnce(TcpStream) -> U + Send + 'static,
+        U: IntoFuture<Item=(), Error=()> + 'static,
+    {
+        self.accepts.push_back(Box::new(move |tcp| -> Box<Future<Item=(), Error=()>> {
+            Box::new(cb(tcp).into_future())
+        }));
         self
     }
@@ -166,15 +197,7 @@ fn run_server(tcp: TcpServer) -> server::Listening {
     let work = bind.incoming().for_each(move |(sock, _)| {
         let cb = accepts.pop_front().expect("no more accepts");
-        let fut = tokio_io::io::read(sock, vec![0; 1024])
-            .and_then(move |(sock, mut vec, n)| {
-                vec.truncate(n);
-                let write = cb(vec);
-                tokio_io::io::write_all(sock, write)
-            })
-            .map(|_| ())
-            .map_err(|e| panic!("tcp server error: {}", e));
+        let fut = cb.call_box(sock);
         reactor.spawn(fut);
         Ok(())
     });

View File

@@ -249,6 +249,55 @@ fn tcp_with_no_orig_dst() {
     assert_eq!(read, b"");
 }

+#[test]
+fn tcp_connections_close_if_client_closes() {
+    use std::sync::mpsc;
+
+    let _ = env_logger::try_init();
+
+    let msg1 = "custom tcp hello";
+    let msg2 = "custom tcp bye";
+
+    let (tx, rx) = mpsc::channel();
+
+    let srv = server::tcp()
+        .accept_fut(move |sock| {
+            tokio_io::io::read(sock, vec![0; 1024])
+                .and_then(move |(sock, vec, n)| {
+                    assert_eq!(&vec[..n], msg1.as_bytes());
+                    tokio_io::io::write_all(sock, msg2.as_bytes())
+                }).and_then(|(sock, _)| {
+                    // let's read again, but we should get EOF
+                    tokio_io::io::read(sock, [0; 16])
+                })
+                .map(move |(_sock, _vec, n)| {
+                    assert_eq!(n, 0);
+                    tx.send(()).unwrap();
+                })
+                .map_err(|e| panic!("tcp server error: {}", e))
+        })
+        .run();
+    let ctrl = controller::new().run();
+    let proxy = proxy::new()
+        .controller(ctrl)
+        .inbound(srv)
+        .run();
+
+    let client = client::tcp(proxy.inbound);
+
+    let tcp_client = client.connect();
+
+    tcp_client.write(msg1);
+    assert_eq!(tcp_client.read(), msg2.as_bytes());
+
+    drop(tcp_client);
+
+    // rx will be fulfilled when our tcp accept_fut sees
+    // a socket disconnect, which is what we are testing for.
+    // the timeout here is just to prevent this test from hanging
+    rx.recv_timeout(Duration::from_secs(5)).unwrap();
+}
+
 #[test]
 fn http11_upgrade_not_supported() {
     let _ = env_logger::try_init();