proxy: detect TCP socket hang ups from client or server (#463)
We previously `join`ed on piping data from both sides, meaning that the future didn't complete until **both** sides had disconnected. Even if the client disconnected, it was possible the server never knew, and we "leaked" this future. To fix this, the `join` is replaced with a `Duplex` future, which pipes from both ends into the other, while also detecting when one side shuts down. When a side does shutdown, a write shutdown is forwarded to the other side, to allow draining to occur for deployments that half-close sockets. Closes #434
This commit is contained in:
parent
95057b067a
commit
0effefa5d7
|
@ -3,15 +3,14 @@ use futures::*;
|
|||
use std;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use tokio_core;
|
||||
use tokio_core::net::{TcpListener, TcpStreamNew};
|
||||
use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream};
|
||||
use tokio_core::reactor::Handle;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use config::Addr;
|
||||
use transport::GetOriginalDst;
|
||||
|
||||
pub type PlaintextSocket = tokio_core::net::TcpStream;
|
||||
pub type PlaintextSocket = TcpStream;
|
||||
|
||||
pub struct BoundPort {
|
||||
inner: std::net::TcpListener,
|
||||
|
@ -165,10 +164,20 @@ impl io::Write for Connection {
|
|||
|
||||
impl AsyncWrite for Connection {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
use std::net::Shutdown;
|
||||
use self::Connection::*;
|
||||
|
||||
match *self {
|
||||
Plain(ref mut t) => t.shutdown(),
|
||||
Plain(ref mut t) => {
|
||||
try_ready!(AsyncWrite::shutdown(t));
|
||||
// TCP shutdown the write side.
|
||||
//
|
||||
// If we're shutting down, then we definitely won't write
|
||||
// anymore. So, we should tell the remote about this. This
|
||||
// is relied upon in our TCP proxy, to start shutting down
|
||||
// the pipe if one side closes.
|
||||
TcpStream::shutdown(t, Shutdown::Write).map(Async::Ready)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::{future, Future};
|
||||
use bytes::{Buf, BufMut};
|
||||
use futures::{future, Async, Future, Poll};
|
||||
use tokio_connect::Connect;
|
||||
use tokio_core::reactor::Handle;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use tokio_io::io::copy;
|
||||
|
||||
use conduit_proxy_controller_grpc::common;
|
||||
use ctx::transport::{Client as ClientCtx, Server as ServerCtx};
|
||||
|
@ -71,14 +72,188 @@ impl Proxy {
|
|||
let fut = connect.connect()
|
||||
.map_err(|e| debug!("tcp connect error: {:?}", e))
|
||||
.and_then(move |tcp_out| {
|
||||
let (in_r, in_w) = tcp_in.split();
|
||||
let (out_r, out_w) = tcp_out.split();
|
||||
|
||||
copy(in_r, out_w)
|
||||
.join(copy(out_r, in_w))
|
||||
.map(|_| ())
|
||||
Duplex::new(tcp_in, tcp_out)
|
||||
.map_err(|e| debug!("tcp error: {}", e))
|
||||
});
|
||||
Box::new(fut)
|
||||
}
|
||||
}
|
||||
|
||||
/// A future piping data bi-directionally to In and Out.
|
||||
struct Duplex<In, Out> {
|
||||
half_in: HalfDuplex<In>,
|
||||
half_out: HalfDuplex<Out>,
|
||||
}
|
||||
|
||||
struct HalfDuplex<T> {
|
||||
// None means socket met eof, and bytes have been drained into other half.
|
||||
buf: Option<CopyBuf>,
|
||||
is_shutdown: bool,
|
||||
io: T,
|
||||
}
|
||||
|
||||
/// A buffer used to copy bytes from one IO to another.
|
||||
///
|
||||
/// Keeps read and write positions.
|
||||
struct CopyBuf {
|
||||
// TODO:
|
||||
// In linkerd-tcp, a shared buffer is used to start, and an allocation is
|
||||
// only made if NotReady is found trying to flush the buffer. We could
|
||||
// consider making the same optimization here.
|
||||
buf: Box<[u8]>,
|
||||
read_pos: usize,
|
||||
write_pos: usize,
|
||||
}
|
||||
|
||||
impl<In, Out> Duplex<In, Out>
|
||||
where
|
||||
In: AsyncRead + AsyncWrite,
|
||||
Out: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn new(in_io: In, out_io: Out) -> Self {
|
||||
Duplex {
|
||||
half_in: HalfDuplex::new(in_io),
|
||||
half_out: HalfDuplex::new(out_io),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<In, Out> Future for Duplex<In, Out>
|
||||
where
|
||||
In: AsyncRead + AsyncWrite,
|
||||
Out: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = ();
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
// This purposefully ignores the Async part, since we don't want to
|
||||
// return early if the first half isn't ready, but the other half
|
||||
// could make progress.
|
||||
self.half_in.copy_into(&mut self.half_out)?;
|
||||
self.half_out.copy_into(&mut self.half_in)?;
|
||||
|
||||
if self.half_in.is_done() && self.half_out.is_done() {
|
||||
Ok(Async::Ready(()))
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> HalfDuplex<T>
|
||||
where
|
||||
T: AsyncRead,
|
||||
{
|
||||
fn new(io: T) -> Self {
|
||||
Self {
|
||||
buf: Some(CopyBuf::new()),
|
||||
is_shutdown: false,
|
||||
io,
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_into<U>(&mut self, dst: &mut HalfDuplex<U>) -> Poll<(), io::Error>
|
||||
where
|
||||
U: AsyncWrite,
|
||||
{
|
||||
loop {
|
||||
try_ready!(self.read());
|
||||
try_ready!(self.write_into(dst));
|
||||
|
||||
if self.buf.is_none() && !dst.is_shutdown {
|
||||
try_ready!(dst.io.shutdown());
|
||||
dst.is_shutdown = true;
|
||||
|
||||
return Ok(Async::Ready(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read(&mut self) -> Poll<(), io::Error> {
|
||||
let mut is_eof = false;
|
||||
if let Some(ref mut buf) = self.buf {
|
||||
if !buf.has_remaining() {
|
||||
buf.reset();
|
||||
let n = try_ready!(self.io.read_buf(buf));
|
||||
is_eof = n == 0;
|
||||
}
|
||||
}
|
||||
|
||||
if is_eof {
|
||||
self.buf.take();
|
||||
}
|
||||
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
|
||||
fn write_into<U>(&mut self, dst: &mut HalfDuplex<U>) -> Poll<(), io::Error>
|
||||
where
|
||||
U: AsyncWrite,
|
||||
{
|
||||
if let Some(ref mut buf) = self.buf {
|
||||
while buf.has_remaining() {
|
||||
let n = try_ready!(dst.io.write_buf(buf));
|
||||
if n == 0 {
|
||||
return Err(write_zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
|
||||
fn is_done(&self) -> bool {
|
||||
self.is_shutdown
|
||||
}
|
||||
}
|
||||
|
||||
fn write_zero() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::WriteZero, "write zero bytes")
|
||||
}
|
||||
|
||||
impl CopyBuf {
|
||||
fn new() -> Self {
|
||||
CopyBuf {
|
||||
buf: Box::new([0; 4096]),
|
||||
read_pos: 0,
|
||||
write_pos: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
debug_assert_eq!(self.read_pos, self.write_pos);
|
||||
self.read_pos = 0;
|
||||
self.write_pos = 0;
|
||||
}
|
||||
}
|
||||
|
||||
impl Buf for CopyBuf {
|
||||
fn remaining(&self) -> usize {
|
||||
self.write_pos - self.read_pos
|
||||
}
|
||||
|
||||
fn bytes(&self) -> &[u8] {
|
||||
&self.buf[self.read_pos..self.write_pos]
|
||||
}
|
||||
|
||||
fn advance(&mut self, cnt: usize) {
|
||||
assert!(self.write_pos >= self.read_pos + cnt);
|
||||
self.read_pos += cnt;
|
||||
}
|
||||
}
|
||||
|
||||
impl BufMut for CopyBuf {
|
||||
fn remaining_mut(&self) -> usize {
|
||||
self.buf.len() - self.write_pos
|
||||
}
|
||||
|
||||
unsafe fn bytes_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.buf[self.write_pos..]
|
||||
}
|
||||
|
||||
unsafe fn advance_mut(&mut self, cnt: usize) {
|
||||
assert!(self.buf.len() >= self.write_pos + cnt);
|
||||
self.write_pos += cnt;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,8 +26,20 @@ pub struct TcpClient {
|
|||
tx: TcpSender,
|
||||
}
|
||||
|
||||
type Handler = Box<CallBox + Send>;
|
||||
|
||||
trait CallBox: 'static {
|
||||
fn call_box(self: Box<Self>, sock: TcpStream) -> Box<Future<Item=(), Error=()>>;
|
||||
}
|
||||
|
||||
impl<F: FnOnce(TcpStream) -> Box<Future<Item=(), Error=()>> + Send + 'static> CallBox for F {
|
||||
fn call_box(self: Box<Self>, sock: TcpStream) -> Box<Future<Item=(), Error=()>> {
|
||||
(*self)(sock)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TcpServer {
|
||||
accepts: VecDeque<Box<Fn(Vec<u8>) -> Vec<u8> + Send>>,
|
||||
accepts: VecDeque<Handler>,
|
||||
}
|
||||
|
||||
pub struct TcpConn {
|
||||
|
@ -50,10 +62,29 @@ impl TcpClient {
|
|||
impl TcpServer {
|
||||
pub fn accept<F, U>(mut self, cb: F) -> Self
|
||||
where
|
||||
F: Fn(Vec<u8>) -> U + Send + 'static,
|
||||
F: FnOnce(Vec<u8>) -> U + Send + 'static,
|
||||
U: Into<Vec<u8>>,
|
||||
{
|
||||
self.accepts.push_back(Box::new(move |v| cb(v).into()));
|
||||
self.accept_fut(move |sock| {
|
||||
tokio_io::io::read(sock, vec![0; 1024])
|
||||
.and_then(move |(sock, mut vec, n)| {
|
||||
vec.truncate(n);
|
||||
let write = cb(vec).into();
|
||||
tokio_io::io::write_all(sock, write)
|
||||
})
|
||||
.map(|_| ())
|
||||
.map_err(|e| panic!("tcp server error: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn accept_fut<F, U>(mut self, cb: F) -> Self
|
||||
where
|
||||
F: FnOnce(TcpStream) -> U + Send + 'static,
|
||||
U: IntoFuture<Item=(), Error=()> + 'static,
|
||||
{
|
||||
self.accepts.push_back(Box::new(move |tcp| -> Box<Future<Item=(), Error=()>> {
|
||||
Box::new(cb(tcp).into_future())
|
||||
}));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -166,15 +197,7 @@ fn run_server(tcp: TcpServer) -> server::Listening {
|
|||
let work = bind.incoming().for_each(move |(sock, _)| {
|
||||
let cb = accepts.pop_front().expect("no more accepts");
|
||||
|
||||
let fut = tokio_io::io::read(sock, vec![0; 1024])
|
||||
.and_then(move |(sock, mut vec, n)| {
|
||||
vec.truncate(n);
|
||||
let write = cb(vec);
|
||||
tokio_io::io::write_all(sock, write)
|
||||
})
|
||||
.map(|_| ())
|
||||
.map_err(|e| panic!("tcp server error: {}", e));
|
||||
|
||||
let fut = cb.call_box(sock);
|
||||
reactor.spawn(fut);
|
||||
Ok(())
|
||||
});
|
||||
|
|
|
@ -249,6 +249,55 @@ fn tcp_with_no_orig_dst() {
|
|||
assert_eq!(read, b"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcp_connections_close_if_client_closes() {
|
||||
use std::sync::mpsc;
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let msg1 = "custom tcp hello";
|
||||
let msg2 = "custom tcp bye";
|
||||
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
let srv = server::tcp()
|
||||
.accept_fut(move |sock| {
|
||||
tokio_io::io::read(sock, vec![0; 1024])
|
||||
.and_then(move |(sock, vec, n)| {
|
||||
assert_eq!(&vec[..n], msg1.as_bytes());
|
||||
|
||||
tokio_io::io::write_all(sock, msg2.as_bytes())
|
||||
}).and_then(|(sock, _)| {
|
||||
// lets read again, but we should get eof
|
||||
tokio_io::io::read(sock, [0; 16])
|
||||
})
|
||||
.map(move |(_sock, _vec, n)| {
|
||||
assert_eq!(n, 0);
|
||||
tx.send(()).unwrap();
|
||||
})
|
||||
.map_err(|e| panic!("tcp server error: {}", e))
|
||||
})
|
||||
.run();
|
||||
let ctrl = controller::new().run();
|
||||
let proxy = proxy::new()
|
||||
.controller(ctrl)
|
||||
.inbound(srv)
|
||||
.run();
|
||||
|
||||
let client = client::tcp(proxy.inbound);
|
||||
|
||||
let tcp_client = client.connect();
|
||||
tcp_client.write(msg1);
|
||||
assert_eq!(tcp_client.read(), msg2.as_bytes());
|
||||
|
||||
drop(tcp_client);
|
||||
|
||||
// rx will be fulfilled when our tcp accept_fut sees
|
||||
// a socket disconnect, which is what we are testing for.
|
||||
// the timeout here is just to prevent this test from hanging
|
||||
rx.recv_timeout(Duration::from_secs(5)).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http11_upgrade_not_supported() {
|
||||
let _ = env_logger::try_init();
|
||||
|
|
Loading…
Reference in New Issue