Changing gRPC Java inbound flow control model

The goal is to mirror the token-based approach used by the Reactive
Streams API.
This commit is contained in:
nmittler 2015-01-16 11:54:24 -08:00
parent 52f4220395
commit de3a13164f
31 changed files with 379 additions and 648 deletions

View File

@ -31,9 +31,6 @@
package com.google.net.stubby; package com.google.net.stubby;
import com.google.common.util.concurrent.ListenableFuture;
import javax.annotation.Nullable;
/** /**
* Low-level methods for communicating with a remote server during a single RPC. Unlike normal RPCs, * Low-level methods for communicating with a remote server during a single RPC. Unlike normal RPCs,
@ -69,14 +66,13 @@ public abstract class Call<RequestT, ResponseT> {
* This method is always called, if no headers were received then an empty {@link Metadata} * This method is always called, if no headers were received then an empty {@link Metadata}
* is passed. * is passed.
*/ */
public abstract ListenableFuture<Void> onHeaders(Metadata.Headers headers); public abstract void onHeaders(Metadata.Headers headers);
/** /**
* A response payload has been received. For streaming calls, there may be zero payload * A response payload has been received. For streaming calls, there may be zero payload
* messages. * messages.
*/ */
@Nullable public abstract void onPayload(T payload);
public abstract ListenableFuture<Void> onPayload(T payload);
/** /**
* The Call has been closed. No further sending or receiving can occur. If {@code status} is * The Call has been closed. No further sending or receiving can occur. If {@code status} is
@ -97,6 +93,22 @@ public abstract class Call<RequestT, ResponseT> {
// TODO(lryan): Might be better to put into Channel#newCall, might reduce decoration burden // TODO(lryan): Might be better to put into Channel#newCall, might reduce decoration burden
public abstract void start(Listener<ResponseT> responseListener, Metadata.Headers headers); public abstract void start(Listener<ResponseT> responseListener, Metadata.Headers headers);
/**
* Requests up to the given number of messages from the call to be delivered to
* {@link Listener#onPayload(Object)}. No additional messages will be delivered.
*
* <p>Message delivery is guaranteed to be sequential in the order received. In addition, the
* listener methods will not be accessed concurrently. While it is not guaranteed that the same
* thread will always be used, it is guaranteed that only a single thread will access the listener
* at a time.
*
* <p>If it is desired to bypass inbound flow control, a very large number of messages can be
* specified (e.g. {@link Integer#MAX_VALUE}).
*
* @param numMessages the requested number of messages to be delivered to the listener.
*/
public abstract void request(int numMessages);
/** /**
* Prevent any further processing for this Call. No further messages may be sent or will be * Prevent any further processing for this Call. No further messages may be sent or will be
* received. The server is informed of cancellations, but may not stop processing the call. * received. The server is informed of cancellations, but may not stop processing the call.

View File

@ -32,13 +32,10 @@
package com.google.net.stubby; package com.google.net.stubby;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.FutureCallback; import com.google.common.base.Throwables;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.Service.Listener; import com.google.common.util.concurrent.Service.Listener;
import com.google.common.util.concurrent.Service.State; import com.google.common.util.concurrent.Service.State;
import com.google.common.util.concurrent.SettableFuture;
import com.google.net.stubby.transport.ClientStream; import com.google.net.stubby.transport.ClientStream;
import com.google.net.stubby.transport.ClientStreamListener; import com.google.net.stubby.transport.ClientStreamListener;
import com.google.net.stubby.transport.ClientTransport; import com.google.net.stubby.transport.ClientTransport;
@ -48,7 +45,6 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.logging.Level; import java.util.logging.Level;
@ -68,6 +64,7 @@ public final class ChannelImpl implements Channel {
@Override public void flush() {} @Override public void flush() {}
@Override public void cancel() {} @Override public void cancel() {}
@Override public void halfClose() {} @Override public void halfClose() {}
@Override public void request(int numMessages) {}
} }
private final ClientTransportFactory transportFactory; private final ClientTransportFactory transportFactory;
@ -270,6 +267,11 @@ public final class ChannelImpl implements Channel {
} }
} }
@Override
public void request(int numMessages) {
stream.request(numMessages);
}
@Override @Override
public void cancel() { public void cancel() {
// Cancel is called in exception handling cases, so it may be the case that the // Cancel is called in exception handling cases, so it may be the case that the
@ -311,60 +313,53 @@ public final class ChannelImpl implements Channel {
private class ClientStreamListenerImpl implements ClientStreamListener { private class ClientStreamListenerImpl implements ClientStreamListener {
private final Listener<RespT> observer; private final Listener<RespT> observer;
private boolean closed;
public ClientStreamListenerImpl(Listener<RespT> observer) { public ClientStreamListenerImpl(Listener<RespT> observer) {
Preconditions.checkNotNull(observer); Preconditions.checkNotNull(observer);
this.observer = observer; this.observer = observer;
} }
private ListenableFuture<Void> dispatchCallable( @Override
final Callable<ListenableFuture<Void>> callable) { public void headersRead(final Metadata.Headers headers) {
final SettableFuture<Void> ours = SettableFuture.create();
callExecutor.execute(new Runnable() { callExecutor.execute(new Runnable() {
@Override @Override
public void run() { public void run() {
try { try {
ListenableFuture<Void> theirs = callable.call(); if (closed) {
if (theirs == null) { return;
ours.set(null);
} else {
Futures.addCallback(theirs, new FutureCallback<Void>() {
@Override
public void onSuccess(Void result) {
ours.set(null);
} }
@Override
public void onFailure(Throwable t) { observer.onHeaders(headers);
ours.setException(t); } catch (Throwable t) {
cancel();
throw Throwables.propagate(t);
} }
}, MoreExecutors.directExecutor()); }
});
}
@Override
public void messageRead(final InputStream message, final int length) {
callExecutor.execute(new Runnable() {
@Override
public void run() {
try {
if (closed) {
return;
}
try {
observer.onPayload(method.parseResponse(message));
} finally {
message.close();
} }
} catch (Throwable t) { } catch (Throwable t) {
ours.setException(t); cancel();
throw Throwables.propagate(t);
} }
} }
}); });
return ours;
}
@Override
public ListenableFuture<Void> headersRead(final Metadata.Headers headers) {
return dispatchCallable(new Callable<ListenableFuture<Void>>() {
@Override
public ListenableFuture<Void> call() throws Exception {
return observer.onHeaders(headers);
}
});
}
@Override
public ListenableFuture<Void> messageRead(final InputStream message, final int length) {
return dispatchCallable(new Callable<ListenableFuture<Void>>() {
@Override
public ListenableFuture<Void> call() {
return observer.onPayload(method.parseResponse(message));
}
});
} }
@Override @Override
@ -372,6 +367,7 @@ public final class ChannelImpl implements Channel {
callExecutor.execute(new Runnable() { callExecutor.execute(new Runnable() {
@Override @Override
public void run() { public void run() {
closed = true;
observer.onClose(status, trailers); observer.onClose(status, trailers);
} }
}); });

View File

@ -33,7 +33,6 @@ package com.google.net.stubby;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator; import java.util.Iterator;
@ -121,6 +120,11 @@ public class ClientInterceptors {
this.delegate.start(responseListener, headers); this.delegate.start(responseListener, headers);
} }
@Override
public void request(int numMessages) {
this.delegate.request(numMessages);
}
@Override @Override
public void cancel() { public void cancel() {
this.delegate.cancel(); this.delegate.cancel();
@ -150,13 +154,13 @@ public class ClientInterceptors {
} }
@Override @Override
public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { public void onHeaders(Metadata.Headers headers) {
return delegate.onHeaders(headers); delegate.onHeaders(headers);
} }
@Override @Override
public ListenableFuture<Void> onPayload(T payload) { public void onPayload(T payload) {
return delegate.onPayload(payload); delegate.onPayload(payload);
} }
@Override @Override

View File

@ -31,9 +31,6 @@
package com.google.net.stubby; package com.google.net.stubby;
import com.google.common.util.concurrent.ListenableFuture;
import javax.annotation.Nullable;
/** /**
* Low-level method for communicating with a remote client during a single RPC. Unlike normal RPCs, * Low-level method for communicating with a remote client during a single RPC. Unlike normal RPCs,
@ -67,8 +64,7 @@ public abstract class ServerCall<ResponseT> {
* A request payload has been received. For streaming calls, there may be zero payload * A request payload has been received. For streaming calls, there may be zero payload
* messages. * messages.
*/ */
@Nullable public abstract void onPayload(RequestT payload);
public abstract ListenableFuture<Void> onPayload(RequestT payload);
/** /**
* The client completed all message sending. However, the call may still be cancelled. * The client completed all message sending. However, the call may still be cancelled.
@ -93,6 +89,14 @@ public abstract class ServerCall<ResponseT> {
public abstract void onComplete(); public abstract void onComplete();
} }
/**
* Requests up to the given number of messages from the call to be delivered to
* {@link Listener#onPayload(Object)}. No additional messages will be delivered.
*
* @param numMessages the requested number of messages to be delivered to the listener.
*/
public abstract void request(int numMessages);
/** /**
* Send response header metadata prior to sending a response payload. This method may * Send response header metadata prior to sending a response payload. This method may
* only be called once and cannot be called after calls to {@code Stream#sendPayload} * only be called once and cannot be called after calls to {@code Stream#sendPayload}

View File

@ -34,26 +34,20 @@ package com.google.net.stubby;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import com.google.common.util.concurrent.AbstractService; import com.google.common.util.concurrent.AbstractService;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.Service; import com.google.common.util.concurrent.Service;
import com.google.common.util.concurrent.SettableFuture;
import com.google.net.stubby.transport.ServerListener; import com.google.net.stubby.transport.ServerListener;
import com.google.net.stubby.transport.ServerStream; import com.google.net.stubby.transport.ServerStream;
import com.google.net.stubby.transport.ServerStreamListener; import com.google.net.stubby.transport.ServerStreamListener;
import com.google.net.stubby.transport.ServerTransportListener; import com.google.net.stubby.transport.ServerTransportListener;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** /**
* Default implementation of {@link Server}, for creation by transports. * Default implementation of {@link Server}, for creation by transports.
* *
@ -299,9 +293,12 @@ public class ServerImpl extends AbstractService implements Server {
private static class NoopListener implements ServerStreamListener { private static class NoopListener implements ServerStreamListener {
@Override @Override
@Nullable public void messageRead(InputStream value, int length) {
public ListenableFuture<Void> messageRead(InputStream value, int length) { try {
return null; value.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
} }
@Override @Override
@ -349,12 +346,16 @@ public class ServerImpl extends AbstractService implements Server {
} }
@Override @Override
@Nullable public void messageRead(final InputStream message, final int length) {
public ListenableFuture<Void> messageRead(final InputStream message, final int length) { callExecutor.execute(new Runnable() {
return dispatchCallable(new Callable<ListenableFuture<Void>>() {
@Override @Override
public ListenableFuture<Void> call() { public void run() {
return getListener().messageRead(message, length); try {
getListener().messageRead(message, length);
} catch (Throwable t) {
internalClose(Status.fromThrowable(t), new Metadata.Trailers());
throw Throwables.propagate(t);
}
} }
}); });
} }
@ -383,36 +384,6 @@ public class ServerImpl extends AbstractService implements Server {
} }
}); });
} }
private ListenableFuture<Void> dispatchCallable(
final Callable<ListenableFuture<Void>> callable) {
final SettableFuture<Void> ours = SettableFuture.create();
callExecutor.execute(new Runnable() {
@Override
public void run() {
try {
ListenableFuture<Void> theirs = callable.call();
if (theirs == null) {
ours.set(null);
} else {
Futures.addCallback(theirs, new FutureCallback<Void>() {
@Override
public void onSuccess(Void result) {
ours.set(null);
}
@Override
public void onFailure(Throwable t) {
ours.setException(t);
}
}, MoreExecutors.directExecutor());
}
} catch (Throwable t) {
ours.setException(t);
}
}
});
return ours;
}
} }
private class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> { private class ServerCallImpl<ReqT, RespT> extends ServerCall<RespT> {
@ -425,6 +396,11 @@ public class ServerImpl extends AbstractService implements Server {
this.methodDef = methodDef; this.methodDef = methodDef;
} }
@Override
public void request(int numMessages) {
stream.request(numMessages);
}
@Override @Override
public void sendHeaders(Metadata.Headers headers) { public void sendHeaders(Metadata.Headers headers) {
stream.writeHeaders(headers); stream.writeHeaders(headers);
@ -468,13 +444,28 @@ public class ServerImpl extends AbstractService implements Server {
} }
@Override @Override
@Nullable public void messageRead(final InputStream message, int length) {
public ListenableFuture<Void> messageRead(final InputStream message, int length) { if (cancelled) {
return listener.onPayload(methodDef.parseRequest(message)); return;
}
try {
listener.onPayload(methodDef.parseRequest(message));
} finally {
try {
message.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
} }
@Override @Override
public void halfClosed() { public void halfClosed() {
if (cancelled) {
return;
}
listener.onHalfClose(); listener.onHalfClose();
} }

View File

@ -144,6 +144,11 @@ public class ServerInterceptors {
this.delegate = delegate; this.delegate = delegate;
} }
@Override
public void request(int numMessages) {
delegate.request(numMessages);
}
@Override @Override
public void sendHeaders(Metadata.Headers headers) { public void sendHeaders(Metadata.Headers headers) {
delegate.sendHeaders(headers); delegate.sendHeaders(headers);

View File

@ -33,14 +33,11 @@ package com.google.net.stubby.transport;
import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.net.stubby.Metadata; import com.google.net.stubby.Metadata;
import com.google.net.stubby.Status; import com.google.net.stubby.Status;
import java.io.InputStream; import java.io.InputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -53,7 +50,6 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
implements ClientStream { implements ClientStream {
private static final Logger log = Logger.getLogger(AbstractClientStream.class.getName()); private static final Logger log = Logger.getLogger(AbstractClientStream.class.getName());
private static final ListenableFuture<Void> COMPLETED_FUTURE = Futures.immediateFuture(null);
private final ClientStreamListener listener; private final ClientStreamListener listener;
private boolean listenerClosed; private boolean listenerClosed;
@ -65,17 +61,15 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
private Runnable closeListenerTask; private Runnable closeListenerTask;
protected AbstractClientStream(ClientStreamListener listener, Executor deframerExecutor) { protected AbstractClientStream(ClientStreamListener listener) {
super(deframerExecutor);
this.listener = Preconditions.checkNotNull(listener); this.listener = Preconditions.checkNotNull(listener);
} }
@Override @Override
protected ListenableFuture<Void> receiveMessage(InputStream is, int length) { protected void receiveMessage(InputStream is, int length) {
if (listenerClosed) { if (!listenerClosed) {
return COMPLETED_FUTURE; listener.messageRead(is, length);
} }
return listener.messageRead(is, length);
} }
@Override @Override
@ -114,7 +108,7 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
new Object[]{id(), headers}); new Object[]{id(), headers});
} }
inboundPhase(Phase.MESSAGE); inboundPhase(Phase.MESSAGE);
delayDeframer(listener.headersRead(headers)); listener.headersRead(headers);
} }
/** /**
@ -208,7 +202,7 @@ public abstract class AbstractClientStream<IdT> extends AbstractStream<IdT>
closeListenerTask = null; closeListenerTask = null;
// Determine if the deframer is stalled (i.e. currently has no complete messages to deliver). // Determine if the deframer is stalled (i.e. currently has no complete messages to deliver).
boolean deliveryStalled = !deframer.isDeliveryOutstanding(); boolean deliveryStalled = deframer.isStalled();
if (stopDelivery || deliveryStalled) { if (stopDelivery || deliveryStalled) {
// Close the listener immediately. // Close the listener immediately.

View File

@ -32,13 +32,11 @@
package com.google.net.stubby.transport; package com.google.net.stubby.transport;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.net.stubby.Metadata; import com.google.net.stubby.Metadata;
import com.google.net.stubby.Status; import com.google.net.stubby.Status;
import java.io.InputStream; import java.io.InputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.concurrent.Executor;
import java.util.logging.Level; import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
@ -63,8 +61,7 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
/** Saved trailers from close() that need to be sent once the framer has sent all messages. */ /** Saved trailers from close() that need to be sent once the framer has sent all messages. */
private Metadata.Trailers stashedTrailers; private Metadata.Trailers stashedTrailers;
protected AbstractServerStream(IdT id, Executor deframerExecutor) { protected AbstractServerStream(IdT id) {
super(deframerExecutor);
id(id); id(id);
} }
@ -73,9 +70,9 @@ public abstract class AbstractServerStream<IdT> extends AbstractStream<IdT>
} }
@Override @Override
protected ListenableFuture<Void> receiveMessage(InputStream is, int length) { protected void receiveMessage(InputStream is, int length) {
inboundPhase(Phase.MESSAGE); inboundPhase(Phase.MESSAGE);
return listener.messageRead(is, length); listener.messageRead(is, length);
} }
@Override @Override

View File

@ -34,15 +34,9 @@ package com.google.net.stubby.transport;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects; import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.InputStream; import java.io.InputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.concurrent.Executor;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -59,15 +53,6 @@ public abstract class AbstractStream<IdT> implements Stream {
private volatile IdT id; private volatile IdT id;
private final MessageFramer framer; private final MessageFramer framer;
private final FutureCallback<Object> deframerErrorCallback = new FutureCallback<Object>() {
@Override
public void onSuccess(Object result) {}
@Override
public void onFailure(Throwable t) {
deframeFailed(t);
}
};
final MessageDeframer deframer; final MessageDeframer deframer;
@ -81,7 +66,7 @@ public abstract class AbstractStream<IdT> implements Stream {
*/ */
private Phase outboundPhase = Phase.HEADERS; private Phase outboundPhase = Phase.HEADERS;
AbstractStream(Executor deframerExecutor) { AbstractStream() {
MessageDeframer.Listener inboundMessageHandler = new MessageDeframer.Listener() { MessageDeframer.Listener inboundMessageHandler = new MessageDeframer.Listener() {
@Override @Override
public void bytesRead(int numBytes) { public void bytesRead(int numBytes) {
@ -89,14 +74,8 @@ public abstract class AbstractStream<IdT> implements Stream {
} }
@Override @Override
public ListenableFuture<Void> messageRead(InputStream input, final int length) { public void messageRead(InputStream input, final int length) {
ListenableFuture<Void> future = null; receiveMessage(input, length);
try {
future = receiveMessage(input, length);
return future;
} finally {
closeWhenDone(future, input);
}
} }
@Override @Override
@ -117,7 +96,7 @@ public abstract class AbstractStream<IdT> implements Stream {
}; };
framer = new MessageFramer(outboundFrameHandler, 4096); framer = new MessageFramer(outboundFrameHandler, 4096);
this.deframer = new MessageDeframer(inboundMessageHandler, deframerExecutor); this.deframer = new MessageDeframer(inboundMessageHandler);
} }
/** /**
@ -194,7 +173,7 @@ public abstract class AbstractStream<IdT> implements Stream {
protected abstract void internalSendFrame(ByteBuffer frame, boolean endOfStream); protected abstract void internalSendFrame(ByteBuffer frame, boolean endOfStream);
/** A message was deframed. */ /** A message was deframed. */
protected abstract ListenableFuture<Void> receiveMessage(InputStream is, int length); protected abstract void receiveMessage(InputStream is, int length);
/** Deframer has no pending deliveries. */ /** Deframer has no pending deliveries. */
protected abstract void inboundDeliveryPaused(); protected abstract void inboundDeliveryPaused();
@ -215,23 +194,25 @@ public abstract class AbstractStream<IdT> implements Stream {
/** /**
* Called to parse a received frame and attempt delivery of any completed * Called to parse a received frame and attempt delivery of any completed
* messages. * messages. Must be called from the transport thread.
*/ */
protected final void deframe(Buffer frame, boolean endOfStream) { protected final void deframe(Buffer frame, boolean endOfStream) {
ListenableFuture<?> future; try {
future = deframer.deframe(frame, endOfStream); deframer.deframe(frame, endOfStream);
if (future != null) { } catch (Throwable t) {
Futures.addCallback(future, deframerErrorCallback); deframeFailed(t);
} }
} }
/** /**
* Delays delivery from the deframer until the given future completes. * Called to request the given number of messages from the deframer. Must be called
* from the transport thread.
*/ */
protected final void delayDeframer(ListenableFuture<?> future) { protected final void requestMessagesFromDeframer(int numMessages) {
ListenableFuture<?> deliveryFuture = deframer.delayProcessing(future); try {
if (deliveryFuture != null) { deframer.request(numMessages);
Futures.addCallback(deliveryFuture, deframerErrorCallback); } catch (Throwable t) {
deframeFailed(t);
} }
} }
@ -271,26 +252,6 @@ public abstract class AbstractStream<IdT> implements Stream {
return nextPhase; return nextPhase;
} }
/**
* If the given future is provided, closes the {@link InputStream} when it completes. Otherwise
* the {@link InputStream} is closed immediately.
*/
private static void closeWhenDone(@Nullable ListenableFuture<Void> future,
final InputStream input) {
if (future == null) {
Closeables.closeQuietly(input);
return;
}
// Close the buffer when the future completes.
future.addListener(new Runnable() {
@Override
public void run() {
Closeables.closeQuietly(input);
}
}, MoreExecutors.directExecutor());
}
/** /**
* Can the stream receive data from its remote peer. * Can the stream receive data from its remote peer.
*/ */

View File

@ -49,5 +49,4 @@ public interface ClientStream extends Stream {
* the remote end-point is closed. * the remote end-point is closed.
*/ */
void halfClose(); void halfClose();
} }

View File

@ -31,12 +31,9 @@
package com.google.net.stubby.transport; package com.google.net.stubby.transport;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.net.stubby.Metadata; import com.google.net.stubby.Metadata;
import com.google.net.stubby.Status; import com.google.net.stubby.Status;
import javax.annotation.Nullable;
/** An observer of client-side stream events. */ /** An observer of client-side stream events. */
public interface ClientStreamListener extends StreamListener { public interface ClientStreamListener extends StreamListener {
/** /**
@ -48,11 +45,8 @@ public interface ClientStreamListener extends StreamListener {
* <p>This method should return quickly, as the same thread may be used to process other streams. * <p>This method should return quickly, as the same thread may be used to process other streams.
* *
* @param headers the fully buffered received headers. * @param headers the fully buffered received headers.
* @return a processing completion future, or {@code null} to indicate that processing of the
* headers is immediately complete.
*/ */
@Nullable void headersRead(Metadata.Headers headers);
ListenableFuture<Void> headersRead(Metadata.Headers headers);
/** /**
* Called when the stream is fully closed. {@link * Called when the stream is fully closed. {@link

View File

@ -37,7 +37,6 @@ import com.google.net.stubby.Metadata;
import com.google.net.stubby.Status; import com.google.net.stubby.Status;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.concurrent.Executor;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -70,8 +69,8 @@ public abstract class Http2ClientStream extends AbstractClientStream<Integer> {
private Charset errorCharset = Charsets.UTF_8; private Charset errorCharset = Charsets.UTF_8;
private boolean contentTypeChecked; private boolean contentTypeChecked;
protected Http2ClientStream(ClientStreamListener listener, Executor deframerExecutor) { protected Http2ClientStream(ClientStreamListener listener) {
super(listener, deframerExecutor); super(listener);
} }
protected void transportHeadersReceived(Metadata.Headers headers) { protected void transportHeadersReceived(Metadata.Headers headers) {

View File

@ -32,19 +32,13 @@
package com.google.net.stubby.transport; package com.google.net.stubby.transport;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.net.stubby.Status; import com.google.net.stubby.Status;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.concurrent.Executor;
import java.util.zip.GZIPInputStream; import java.util.zip.GZIPInputStream;
import javax.annotation.concurrent.NotThreadSafe; import javax.annotation.concurrent.NotThreadSafe;
@ -52,8 +46,8 @@ import javax.annotation.concurrent.NotThreadSafe;
/** /**
* Deframer for GRPC frames. * Deframer for GRPC frames.
* *
* <p>This class is not thread-safe. All calls to this class must be made in the context of the * <p>This class is not thread-safe. All calls to public methods should be made in the transport
* executor provided during creation. That executor must not allow concurrent execution of tasks. * thread.
*/ */
@NotThreadSafe @NotThreadSafe
public class MessageDeframer implements Closeable { public class MessageDeframer implements Closeable {
@ -82,11 +76,8 @@ public class MessageDeframer implements Closeable {
* *
* @param is stream containing the message. * @param is stream containing the message.
* @param length the length in bytes of the message. * @param length the length in bytes of the message.
* @return a future indicating when the application has completed processing the message. The
* next delivery will not occur until this future completes. If {@code null}, it is assumed that
* the application has completed processing the message upon returning from the method call.
*/ */
ListenableFuture<Void> messageRead(InputStream is, int length); void messageRead(InputStream is, int length);
/** /**
* Called when end-of-stream has not yet been reached but there are no complete messages * Called when end-of-stream has not yet been reached but there are no complete messages
@ -105,65 +96,67 @@ public class MessageDeframer implements Closeable {
} }
private final Listener listener; private final Listener listener;
private final Executor executor;
private final Compression compression; private final Compression compression;
private State state = State.HEADER; private State state = State.HEADER;
private int requiredLength = HEADER_LENGTH; private int requiredLength = HEADER_LENGTH;
private boolean compressedFlag; private boolean compressedFlag;
private boolean endOfStream; private boolean endOfStream;
private SettableFuture<?> deliveryOutstanding;
private CompositeBuffer nextFrame; private CompositeBuffer nextFrame;
private CompositeBuffer unprocessed = new CompositeBuffer(); private CompositeBuffer unprocessed = new CompositeBuffer();
private long pendingDeliveries;
private boolean deliveryStalled = true;
/** /**
* Create a deframer. All calls to this class must be made in the context of the provided * Create a deframer. Compression will not be supported.
* executor, which also must not allow concurrent processing of Runnables. Compression will not be
* supported.
* *
* @param listener listener for deframer events. * @param listener listener for deframer events.
* @param executor used for internal event processing
*/ */
public MessageDeframer(Listener listener, Executor executor) { public MessageDeframer(Listener listener) {
this(listener, executor, Compression.NONE); this(listener, Compression.NONE);
} }
/** /**
* Create a deframer. All calls to this class must be made in the context of the provided * Create a deframer.
* executor, which also must not allow concurrent processing of Runnables.
* *
* @param listener listener for deframer events. * @param listener listener for deframer events.
* @param executor used for internal event processing
* @param compression the compression used if a compressed frame is encountered, with NONE meaning * @param compression the compression used if a compressed frame is encountered, with NONE meaning
* unsupported * unsupported
*/ */
public MessageDeframer(Listener listener, Executor executor, Compression compression) { public MessageDeframer(Listener listener, Compression compression) {
this.listener = Preconditions.checkNotNull(listener, "sink"); this.listener = Preconditions.checkNotNull(listener, "sink");
this.executor = Preconditions.checkNotNull(executor, "executor");
this.compression = Preconditions.checkNotNull(compression, "compression"); this.compression = Preconditions.checkNotNull(compression, "compression");
} }
/** /**
* Adds the given data to this deframer and attempts delivery to the sink. * Requests up to the given number of messages from the call to be delivered to
* {@link Listener#messageRead(InputStream, int)}. No additional messages will be delivered.
* *
* <p>If returned future is not {@code null}, then it completes when no more deliveries are * @param numMessages the requested number of messages to be delivered to the listener.
* occuring. Delivering completes if all available deframing input is consumed or if delivery
* resulted in an exception, in which case this method may throw the exception or the returned
* future will fail with the throwable. The future is guaranteed to complete within the executor
* provided during construction.
*/ */
public ListenableFuture<?> deframe(Buffer data, boolean endOfStream) { public void request(int numMessages) {
Preconditions.checkArgument(numMessages > 0, "numMessages must be > 0");
pendingDeliveries += numMessages;
deliver();
}
/**
* Adds the given data to this deframer and attempts delivery to the sink.
*/
public void deframe(Buffer data, boolean endOfStream) {
Preconditions.checkNotNull(data, "data"); Preconditions.checkNotNull(data, "data");
Preconditions.checkState(!this.endOfStream, "Past end of stream"); Preconditions.checkState(!this.endOfStream, "Past end of stream");
unprocessed.addBuffer(data); unprocessed.addBuffer(data);
// Indicate that all of the data for this stream has been received. // Indicate that all of the data for this stream has been received.
this.endOfStream = endOfStream; this.endOfStream = endOfStream;
deliver();
if (isDeliveryOutstanding()) {
// Only allow one outstanding delivery at a time.
return null;
} }
return deliver();
/**
* Indicates whether delivery is currently stalled, pending receipt of more data.
*/
public boolean isStalled() {
return deliveryStalled;
} }
@Override @Override
@ -175,83 +168,23 @@ public class MessageDeframer implements Closeable {
} }
/** /**
* Indicates whether or not there is currently a delivery outstanding to the application. * Reads and delivers as many messages to the sink as possible.
*/ */
public final boolean isDeliveryOutstanding() { private void deliver() {
return deliveryOutstanding != null;
}
/**
* Consider {@code future} to be a message currently being processed. Messages will not be
* delivered until the future completes. The returned future behaves as if it was returned by
* {@link #deframe(Buffer, boolean)}.
*
* @throws IllegalStateException if a message is already being processed
*/
public ListenableFuture<?> delayProcessing(ListenableFuture<?> future) {
Preconditions.checkState(!isDeliveryOutstanding(), "Only one delay allowed concurrently");
if (future == null) {
return null;
}
return delayProcessingInternal(future);
}
/**
* May only be called when a delivery is known not to be outstanding. If deliveryOutstanding is
* non-null, then it will be re-used and this method will return {@code null}.
*/
private ListenableFuture<?> delayProcessingInternal(ListenableFuture<?> future) {
Preconditions.checkNotNull(future, "future");
// Return a separate future so that our callback is guaranteed to complete before any
// listeners on the returned future.
ListenableFuture<?> returnFuture = null;
if (!isDeliveryOutstanding()) {
returnFuture = deliveryOutstanding = SettableFuture.create();
}
Futures.addCallback(future, new FutureCallback<Object>() {
@Override
public void onFailure(Throwable t) {
SettableFuture<?> previousOutstanding = deliveryOutstanding;
deliveryOutstanding = null;
previousOutstanding.setException(t);
}
@Override
public void onSuccess(Object result) {
try {
deliver();
} catch (Throwable t) {
if (!isDeliveryOutstanding()) {
throw Throwables.propagate(t);
} else {
onFailure(t);
}
}
}
}, executor);
return returnFuture;
}
/**
* Reads and delivers as many messages to the sink as possible. May only be called when a delivery
* is known not to be outstanding.
*/
private ListenableFuture<?> deliver() {
// Process the uncompressed bytes. // Process the uncompressed bytes.
while (readRequiredBytes()) { boolean stalled = false;
while (pendingDeliveries > 0 && !(stalled = !readRequiredBytes())) {
switch (state) { switch (state) {
case HEADER: case HEADER:
processHeader(); processHeader();
break; break;
case BODY: case BODY:
// Read the body and deliver the message to the sink. // Read the body and deliver the message.
ListenableFuture<?> processingFuture = processBody(); processBody();
if (processingFuture != null) {
// A future was returned for the completion of processing the delivered
// message. Once it's done, try to deliver the next message.
return delayProcessingInternal(processingFuture);
}
// Since we've delivered a message, decrement the number of pending
// deliveries remaining.
pendingDeliveries--;
break; break;
default: default:
throw new AssertionError("Invalid state: " + state); throw new AssertionError("Invalid state: " + state);
@ -259,25 +192,29 @@ public class MessageDeframer implements Closeable {
} }
if (endOfStream) { if (endOfStream) {
if (nextFrame.readableBytes() != 0) { if (!isDataAvailable()) {
throw Status.INTERNAL listener.endOfStream();
.withDescription("Encountered end-of-stream mid-frame") } else if (stalled) {
// We've received the entire stream and have data available but we don't have
// enough to read the next frame ... this is bad.
throw Status.INTERNAL.withDescription("Encountered end-of-stream mid-frame")
.asRuntimeException(); .asRuntimeException();
} }
listener.endOfStream();
} }
// All available messages have processed. // Never indicate that we're stalled if we've received all the data for the stream.
if (isDeliveryOutstanding()) { stalled &= !endOfStream;
SettableFuture<?> previousOutstanding = deliveryOutstanding;
deliveryOutstanding = null; // If we're transitioning to the stalled state, notify the listener.
previousOutstanding.set(null); boolean previouslyStalled = deliveryStalled;
if (!endOfStream) { deliveryStalled = stalled;
// Notify that delivery is currently paused. if (stalled && !previouslyStalled) {
listener.deliveryStalled(); listener.deliveryStalled();
} }
} }
return null;
private boolean isDataAvailable() {
return unprocessed.readableBytes() > 0 || (nextFrame != null && nextFrame.readableBytes() > 0);
} }
/** /**
@ -335,35 +272,32 @@ public class MessageDeframer implements Closeable {
* Processes the body of the GRPC compression frame. A single compression frame may contain * Processes the body of the GRPC compression frame. A single compression frame may contain
* several GRPC messages within it. * several GRPC messages within it.
*/ */
private ListenableFuture<?> processBody() { private void processBody() {
ListenableFuture<?> future;
if (compressedFlag) { if (compressedFlag) {
if (compression == Compression.NONE) { if (compression == Compression.NONE) {
throw Status.INTERNAL throw Status.INTERNAL.withDescription(
.withDescription("Can't decode compressed frame as compression not configured.") "Can't decode compressed frame as compression not configured.").asRuntimeException();
.asRuntimeException();
} else if (compression == Compression.GZIP) { } else if (compression == Compression.GZIP) {
// Fully drain frame. // Fully drain frame.
byte[] bytes; byte[] bytes;
try { try {
bytes = ByteStreams.toByteArray( bytes =
new GZIPInputStream(Buffers.openStream(nextFrame, false))); ByteStreams.toByteArray(new GZIPInputStream(Buffers.openStream(nextFrame, false)));
} catch (IOException ex) { } catch (IOException ex) {
throw new RuntimeException(ex); throw new RuntimeException(ex);
} }
future = listener.messageRead(new ByteArrayInputStream(bytes), bytes.length); listener.messageRead(new ByteArrayInputStream(bytes), bytes.length);
} else { } else {
throw new AssertionError("Unknown compression type"); throw new AssertionError("Unknown compression type");
} }
} else { } else {
// Don't close the frame, since the sink is now responsible for the life-cycle. // Don't close the frame, since the sink is now responsible for the life-cycle.
future = listener.messageRead(Buffers.openStream(nextFrame, true), nextFrame.readableBytes()); listener.messageRead(Buffers.openStream(nextFrame, true), nextFrame.readableBytes());
nextFrame = null; nextFrame = null;
} }
// Done with this frame, begin processing the next header. // Done with this frame, begin processing the next header.
state = State.HEADER; state = State.HEADER;
requiredLength = HEADER_LENGTH; requiredLength = HEADER_LENGTH;
return future;
} }
} }

View File

@ -41,6 +41,15 @@ import javax.annotation.Nullable;
* <p>An implementation doesn't need to be thread-safe. * <p>An implementation doesn't need to be thread-safe.
*/ */
public interface Stream { public interface Stream {
/**
* Requests up to the given number of messages from the call to be delivered to
* {@link StreamListener#messageRead(java.io.InputStream, int)}. No additional messages will be
* delivered.
*
* @param numMessages the requested number of messages to be delivered to the listener.
*/
void request(int numMessages);
/** /**
* Writes a message payload to the remote end-point. The bytes from the stream are immediate read * Writes a message payload to the remote end-point. The bytes from the stream are immediate read
* by the Transport. This method will always return immediately and will not wait for the write to * by the Transport. This method will always return immediately and will not wait for the write to

View File

@ -31,12 +31,8 @@
package com.google.net.stubby.transport; package com.google.net.stubby.transport;
import com.google.common.util.concurrent.ListenableFuture;
import java.io.InputStream; import java.io.InputStream;
import javax.annotation.Nullable;
/** /**
* An observer of {@link Stream} events. It is guaranteed to only have one concurrent callback at a * An observer of {@link Stream} events. It is guaranteed to only have one concurrent callback at a
* time. * time.
@ -46,21 +42,12 @@ public interface StreamListener {
* Called upon receiving a message from the remote end-point. The {@link InputStream} is * Called upon receiving a message from the remote end-point. The {@link InputStream} is
* non-blocking and contains the entire message. * non-blocking and contains the entire message.
* *
* <p>The method optionally returns a future that can be observed by flow control to determine * <p>The provided {@code message} {@link InputStream} must be closed by the listener.
* when the message has been processed by the application. If {@code null} is returned, processing
* of this message is assumed to be complete upon returning from this method.
*
* <p>The {@code message} {@link InputStream} will be closed when the returned future completes.
* If no future is returned, the stream will be closed immediately after returning from this
* method.
* *
* <p>This method should return quickly, as the same thread may be used to process other streams. * <p>This method should return quickly, as the same thread may be used to process other streams.
* *
* @param message the bytes of the message. * @param message the bytes of the message.
* @param length the length of the message {@link InputStream}. * @param length the length of the message {@link InputStream}.
* @return a processing completion future, or {@code null} to indicate that processing of the
* message is immediately complete.
*/ */
@Nullable void messageRead(InputStream message, int length);
ListenableFuture<Void> messageRead(InputStream message, int length);
} }

View File

@ -130,6 +130,7 @@ public class ClientInterceptorsTest {
public void ordered() { public void ordered() {
final List<String> order = new ArrayList<String>(); final List<String> order = new ArrayList<String>();
channel = new Channel() { channel = new Channel() {
@SuppressWarnings("unchecked")
@Override @Override
public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) { public <ReqT, RespT> Call<ReqT, RespT> newCall(MethodDescriptor<ReqT, RespT> method) {
order.add("channel"); order.add("channel");
@ -199,9 +200,9 @@ public class ClientInterceptorsTest {
public void start(Call.Listener<RespT> responseListener, Metadata.Headers headers) { public void start(Call.Listener<RespT> responseListener, Metadata.Headers headers) {
super.start(new ForwardingListener<RespT>(responseListener) { super.start(new ForwardingListener<RespT>(responseListener) {
@Override @Override
public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { public void onHeaders(Metadata.Headers headers) {
examinedHeaders.add(headers); examinedHeaders.add(headers);
return super.onHeaders(headers); super.onHeaders(headers);
} }
}, headers); }, headers);
} }

View File

@ -34,21 +34,18 @@ package com.google.net.stubby;
import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isNull; import static org.mockito.Matchers.isNull;
import static org.mockito.Matchers.notNull; import static org.mockito.Matchers.notNull;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.AbstractService; import com.google.common.util.concurrent.AbstractService;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.Service; import com.google.common.util.concurrent.Service;
import com.google.common.util.concurrent.SettableFuture;
import com.google.net.stubby.transport.ServerStream; import com.google.net.stubby.transport.ServerStream;
import com.google.net.stubby.transport.ServerStreamListener; import com.google.net.stubby.transport.ServerStreamListener;
import com.google.net.stubby.transport.ServerTransportListener; import com.google.net.stubby.transport.ServerTransportListener;
@ -71,8 +68,6 @@ import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
/** Unit tests for {@link ServerImpl}. */ /** Unit tests for {@link ServerImpl}. */
@ -86,7 +81,9 @@ public class ServerImplTest {
private Service transportServer = new NoopService(); private Service transportServer = new NoopService();
private ServerImpl server = new ServerImpl(executor, registry) private ServerImpl server = new ServerImpl(executor, registry)
.setTransportServer(transportServer); .setTransportServer(transportServer);
private ServerStream stream = Mockito.mock(ServerStream.class);
@Mock
private ServerStream stream;
@Mock @Mock
private ServerCall.Listener<String> callListener; private ServerCall.Listener<String> callListener;
@ -238,9 +235,8 @@ public class ServerImplTest {
assertNotNull(call); assertNotNull(call);
String order = "Lots of pizza, please"; String order = "Lots of pizza, please";
ListenableFuture<Void> future = streamListener.messageRead(STRING_MARSHALLER.stream(order), 1); streamListener.messageRead(STRING_MARSHALLER.stream(order), 1);
future.get(); verify(callListener, timeout(2000)).onPayload(order);
verify(callListener).onPayload(order);
call.sendPayload(314); call.sendPayload(314);
ArgumentCaptor<InputStream> inputCaptor = ArgumentCaptor.forClass(InputStream.class); ArgumentCaptor<InputStream> inputCaptor = ArgumentCaptor.forClass(InputStream.class);
@ -297,48 +293,6 @@ public class ServerImplTest {
verifyNoMoreInteractions(stream); verifyNoMoreInteractions(stream);
} }
@Test
public void futureStatusIsPropagatedToTransport() throws Exception {
final AtomicReference<ServerCall<Integer>> callReference
= new AtomicReference<ServerCall<Integer>>();
registry.addService(ServerServiceDefinition.builder("Waiter")
.addMethod("serve", STRING_MARSHALLER, INTEGER_MARSHALLER,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(String fullMethodName,
ServerCall<Integer> call, Metadata.Headers headers) {
callReference.set(call);
return callListener;
}
}).build());
ServerTransportListener transportListener = newTransport(server);
ServerStreamListener streamListener
= transportListener.streamCreated(stream, "/Waiter/serve", new Metadata.Headers());
assertNotNull(streamListener);
executeBarrier(executor).await();
ServerCall<Integer> call = callReference.get();
assertNotNull(call);
String delay = "No, I've not looked over the menu yet";
SettableFuture<Void> appFuture = SettableFuture.create();
when(callListener.onPayload(delay)).thenReturn(appFuture);
ListenableFuture<Void> future = streamListener.messageRead(STRING_MARSHALLER.stream(delay), 1);
executeBarrier(executor).await();
verify(callListener).onPayload(delay);
try {
future.get(0, TimeUnit.SECONDS);
fail();
} catch (TimeoutException ex) {
// Expected.
}
appFuture.set(null);
// Shouldn't throw.
future.get(0, TimeUnit.SECONDS);
}
private static ServerTransportListener newTransport(ServerImpl server) { private static ServerTransportListener newTransport(ServerImpl server) {
Service transport = new NoopService(); Service transport = new NoopService();
transport.startAsync(); transport.startAsync();

View File

@ -32,25 +32,16 @@
package com.google.net.stubby.transport; package com.google.net.stubby.transport;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.primitives.Bytes; import com.google.common.primitives.Bytes;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import com.google.net.stubby.transport.MessageDeframer.Listener; import com.google.net.stubby.transport.MessageDeframer.Listener;
import org.junit.Test; import org.junit.Test;
@ -62,7 +53,6 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.zip.GZIPOutputStream; import java.util.zip.GZIPOutputStream;
/** /**
@ -71,13 +61,13 @@ import java.util.zip.GZIPOutputStream;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class MessageDeframerTest { public class MessageDeframerTest {
private Listener listener = mock(Listener.class); private Listener listener = mock(Listener.class);
private MessageDeframer deframer = private MessageDeframer deframer = new MessageDeframer(listener);
new MessageDeframer(listener, MoreExecutors.directExecutor());
private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class); private ArgumentCaptor<InputStream> messages = ArgumentCaptor.forClass(InputStream.class);
@Test @Test
public void simplePayload() { public void simplePayload() {
assertNull(deframer.deframe(buffer(new byte[]{0, 0, 0, 0, 2, 3, 14}), false)); deframer.request(1);
deframer.deframe(buffer(new byte[]{0, 0, 0, 0, 2, 3, 14}), false);
verify(listener).messageRead(messages.capture(), eq(2)); verify(listener).messageRead(messages.capture(), eq(2));
assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(messages)); assertEquals(Bytes.asList(new byte[]{3, 14}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
@ -86,8 +76,8 @@ public class MessageDeframerTest {
@Test @Test
public void smallCombinedPayloads() { public void smallCombinedPayloads() {
assertNull( deframer.request(2);
deframer.deframe(buffer(new byte[]{0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 2, 14, 15}), false)); deframer.deframe(buffer(new byte[]{0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 2, 14, 15}), false);
verify(listener).messageRead(messages.capture(), eq(1)); verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages)); assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener).messageRead(messages.capture(), eq(2)); verify(listener).messageRead(messages.capture(), eq(2));
@ -98,7 +88,8 @@ public class MessageDeframerTest {
@Test @Test
public void endOfStreamWithPayloadShouldNotifyEndOfStream() { public void endOfStreamWithPayloadShouldNotifyEndOfStream() {
assertNull(deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3}), true)); deframer.request(1);
deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3}), true);
verify(listener).messageRead(messages.capture(), eq(1)); verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages)); assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener).endOfStream(); verify(listener).endOfStream();
@ -108,17 +99,18 @@ public class MessageDeframerTest {
@Test @Test
public void endOfStreamShouldNotifyEndOfStream() { public void endOfStreamShouldNotifyEndOfStream() {
assertNull(deframer.deframe(buffer(new byte[0]), true)); deframer.deframe(buffer(new byte[0]), true);
verify(listener).endOfStream(); verify(listener).endOfStream();
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
} }
@Test @Test
public void payloadSplitBetweenBuffers() { public void payloadSplitBetweenBuffers() {
assertNull(deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 7, 3, 14, 1, 5, 9}), false)); deframer.request(1);
deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 7, 3, 14, 1, 5, 9}), false);
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
assertNull(deframer.deframe(buffer(new byte[] {2, 6}), false)); deframer.deframe(buffer(new byte[] {2, 6}), false);
verify(listener).messageRead(messages.capture(), eq(7)); verify(listener).messageRead(messages.capture(), eq(7));
assertEquals(Bytes.asList(new byte[] {3, 14, 1, 5, 9, 2, 6}), bytes(messages)); assertEquals(Bytes.asList(new byte[] {3, 14, 1, 5, 9, 2, 6}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
@ -127,10 +119,12 @@ public class MessageDeframerTest {
@Test @Test
public void frameHeaderSplitBetweenBuffers() { public void frameHeaderSplitBetweenBuffers() {
assertNull(deframer.deframe(buffer(new byte[] {0, 0}), false)); deframer.request(1);
deframer.deframe(buffer(new byte[] {0, 0}), false);
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
assertNull(deframer.deframe(buffer(new byte[] {0, 0, 1, 3}), false)); deframer.deframe(buffer(new byte[] {0, 0, 1, 3}), false);
verify(listener).messageRead(messages.capture(), eq(1)); verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages)); assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
@ -139,7 +133,8 @@ public class MessageDeframerTest {
@Test @Test
public void emptyPayload() { public void emptyPayload() {
assertNull(deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 0}), false)); deframer.request(1);
deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 0}), false);
verify(listener).messageRead(messages.capture(), eq(0)); verify(listener).messageRead(messages.capture(), eq(0));
assertEquals(Bytes.asList(), bytes(messages)); assertEquals(Bytes.asList(), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
@ -148,8 +143,9 @@ public class MessageDeframerTest {
@Test @Test
public void largerFrameSize() { public void largerFrameSize() {
assertNull(deframer.deframe( deframer.request(1);
Buffers.wrap(Bytes.concat(new byte[] {0, 0, 0, 3, (byte) 232}, new byte[1000])), false)); deframer.deframe(
Buffers.wrap(Bytes.concat(new byte[] {0, 0, 0, 3, (byte) 232}, new byte[1000])), false);
verify(listener).messageRead(messages.capture(), eq(1000)); verify(listener).messageRead(messages.capture(), eq(1000));
assertEquals(Bytes.asList(new byte[1000]), bytes(messages)); assertEquals(Bytes.asList(new byte[1000]), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
@ -157,110 +153,23 @@ public class MessageDeframerTest {
} }
@Test @Test
public void payloadCallbackShouldWaitForFutureCompletion() { public void endOfStreamCallbackShouldWaitForMessageDelivery() {
SettableFuture<Void> messageFuture = SettableFuture.create(); deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3}), true);
when(listener.messageRead(any(InputStream.class), eq(1))).thenReturn(messageFuture); verifyNoMoreInteractions(listener);
// Deframe a block with 2 messages.
ListenableFuture<?> deframeFuture deframer.request(1);
= deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 2, 14, 15}), false);
assertNotNull(deframeFuture);
verify(listener).messageRead(messages.capture(), eq(1)); verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages)); assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
SettableFuture<Void> messageFuture2 = SettableFuture.create();
when(listener.messageRead(any(InputStream.class), eq(2))).thenReturn(messageFuture2);
messageFuture.set(null);
assertFalse(deframeFuture.isDone());
verify(listener).messageRead(messages.capture(), eq(2));
assertEquals(Bytes.asList(new byte[] {14, 15}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
messageFuture2.set(null);
assertTrue(deframeFuture.isDone());
verify(listener, atLeastOnce()).bytesRead(anyInt());
verify(listener).deliveryStalled();
verifyNoMoreInteractions(listener);
}
@Test
public void endOfStreamCallbackShouldWaitForFutureCompletion() {
SettableFuture<Void> messageFuture = SettableFuture.create();
when(listener.messageRead(any(InputStream.class), eq(1))).thenReturn(messageFuture);
ListenableFuture<?> deframeFuture
= deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3}), true);
assertNotNull(deframeFuture);
verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
messageFuture.set(null);
assertTrue(deframeFuture.isDone());
verify(listener).endOfStream(); verify(listener).endOfStream();
verify(listener, atLeastOnce()).bytesRead(anyInt()); verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener); verifyNoMoreInteractions(listener);
} }
@Test
public void futureShouldPropagateThrownException() throws InterruptedException {
SettableFuture<Void> messageFuture = SettableFuture.create();
when(listener.messageRead(any(InputStream.class), eq(1))).thenReturn(messageFuture);
ListenableFuture<?> deframeFuture
= deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 2, 14, 15}), false);
assertNotNull(deframeFuture);
verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[]{3}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
RuntimeException thrownEx = new RuntimeException();
when(listener.messageRead(any(InputStream.class), eq(2))).thenThrow(thrownEx);
messageFuture.set(null);
verify(listener).messageRead(messages.capture(), eq(2));
assertTrue(deframeFuture.isDone());
try {
deframeFuture.get();
fail("Should have throws ExecutionException");
} catch (ExecutionException ex) {
assertEquals(thrownEx, ex.getCause());
}
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
}
@Test
public void futureFailureShouldStopAndPropagateFailure() throws InterruptedException {
SettableFuture<Void> messageFuture = SettableFuture.create();
when(listener.messageRead(any(InputStream.class), eq(1))).thenReturn(messageFuture);
ListenableFuture<?> deframeFuture
= deframer.deframe(buffer(new byte[] {0, 0, 0, 0, 1, 3, 0, 0, 0, 0, 2, 14, 15}), false);
assertNotNull(deframeFuture);
verify(listener).messageRead(messages.capture(), eq(1));
assertEquals(Bytes.asList(new byte[] {3}), bytes(messages));
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
RuntimeException thrownEx = new RuntimeException();
messageFuture.setException(thrownEx);
assertTrue(deframeFuture.isDone());
try {
deframeFuture.get();
fail("Should have throws ExecutionException");
} catch (ExecutionException ex) {
assertEquals(thrownEx, ex.getCause());
}
verify(listener, atLeastOnce()).bytesRead(anyInt());
verifyNoMoreInteractions(listener);
}
@Test @Test
public void compressed() { public void compressed() {
deframer = new MessageDeframer( deframer = new MessageDeframer(listener, MessageDeframer.Compression.GZIP);
listener, MoreExecutors.directExecutor(), MessageDeframer.Compression.GZIP); deframer.request(1);
byte[] payload = compress(new byte[1000]); byte[] payload = compress(new byte[1000]);
assertTrue(payload.length < 100); assertTrue(payload.length < 100);
byte[] header = new byte[] {1, 0, 0, 0, (byte) payload.length}; byte[] header = new byte[] {1, 0, 0, 0, (byte) payload.length};

View File

@ -39,7 +39,6 @@ import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import com.google.common.util.concurrent.Uninterruptibles; import com.google.common.util.concurrent.Uninterruptibles;
import com.google.net.stubby.AbstractServerBuilder; import com.google.net.stubby.AbstractServerBuilder;
@ -75,12 +74,12 @@ import org.mockito.ArgumentCaptor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
/** /**
@ -426,22 +425,18 @@ public abstract class AbstractTransportTest {
// Start the call and prepare capture of results. // Start the call and prepare capture of results.
final List<StreamingOutputCallResponse> results = final List<StreamingOutputCallResponse> results =
Collections.synchronizedList(new ArrayList<StreamingOutputCallResponse>()); Collections.synchronizedList(new ArrayList<StreamingOutputCallResponse>());
final List<SettableFuture<Void>> processedFutures =
Collections.synchronizedList(new LinkedList<SettableFuture<Void>>());
final SettableFuture<Void> completionFuture = SettableFuture.create(); final SettableFuture<Void> completionFuture = SettableFuture.create();
final AtomicInteger count = new AtomicInteger();
call.start(new Call.Listener<StreamingOutputCallResponse>() { call.start(new Call.Listener<StreamingOutputCallResponse>() {
@Override @Override
public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { public void onHeaders(Metadata.Headers headers) {
return null;
} }
@Override @Override
public ListenableFuture<Void> onPayload(final StreamingOutputCallResponse payload) { public void onPayload(final StreamingOutputCallResponse payload) {
SettableFuture<Void> processedFuture = SettableFuture.create();
results.add(payload); results.add(payload);
processedFutures.add(processedFuture); count.incrementAndGet();
return processedFuture;
} }
@Override @Override
@ -460,17 +455,9 @@ public abstract class AbstractTransportTest {
// Slowly set completion on all of the futures. // Slowly set completion on all of the futures.
int expectedResults = responseSizes.size(); int expectedResults = responseSizes.size();
int count = 0; while (count.get() < expectedResults) {
while (count < expectedResults) { // Allow one more inbound message to be delivered to the application.
if (!processedFutures.isEmpty()) { call.request(1);
assertEquals(1, processedFutures.size());
assertEquals(count + 1, results.size());
count++;
// Remove and set the first future to allow receipt of additional messages
// from flow control.
processedFutures.remove(0).set(null);
}
// Sleep a bit. // Sleep a bit.
Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS); Uninterruptibles.sleepUninterruptibly(1, TimeUnit.SECONDS);

View File

@ -51,11 +51,21 @@ class NettyClientStream extends Http2ClientStream {
private final NettyClientHandler handler; private final NettyClientHandler handler;
NettyClientStream(ClientStreamListener listener, Channel channel, NettyClientHandler handler) { NettyClientStream(ClientStreamListener listener, Channel channel, NettyClientHandler handler) {
super(listener, channel.eventLoop()); super(listener);
this.channel = checkNotNull(channel, "channel"); this.channel = checkNotNull(channel, "channel");
this.handler = checkNotNull(handler, "handler"); this.handler = checkNotNull(handler, "handler");
} }
@Override
public void request(final int numMessages) {
channel.eventLoop().execute(new Runnable() {
@Override
public void run() {
requestMessagesFromDeframer(numMessages);
}
});
}
void transportHeadersReceived(Http2Headers headers, boolean endOfStream) { void transportHeadersReceived(Http2Headers headers, boolean endOfStream) {
if (endOfStream) { if (endOfStream) {
transportTrailersReceived(Utils.convertTrailers(headers)); transportTrailersReceived(Utils.convertTrailers(headers));

View File

@ -51,7 +51,7 @@ class NettyServerStream extends AbstractServerStream<Integer> {
private final NettyServerHandler handler; private final NettyServerHandler handler;
NettyServerStream(Channel channel, int id, NettyServerHandler handler) { NettyServerStream(Channel channel, int id, NettyServerHandler handler) {
super(id, channel.eventLoop()); super(id);
this.channel = checkNotNull(channel, "channel"); this.channel = checkNotNull(channel, "channel");
this.handler = checkNotNull(handler, "handler"); this.handler = checkNotNull(handler, "handler");
} }
@ -60,6 +60,16 @@ class NettyServerStream extends AbstractServerStream<Integer> {
super.inboundDataReceived(new NettyBuffer(frame.retain()), endOfStream); super.inboundDataReceived(new NettyBuffer(frame.retain()), endOfStream);
} }
@Override
public void request(final int numMessages) {
channel.eventLoop().execute(new Runnable() {
@Override
public void run() {
requestMessagesFromDeframer(numMessages);
}
});
}
@Override @Override
protected void inboundDeliveryPaused() { protected void inboundDeliveryPaused() {
// Do nothing. // Do nothing.

View File

@ -45,9 +45,7 @@ import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.same;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.SettableFuture;
import com.google.net.stubby.Metadata; import com.google.net.stubby.Metadata;
import com.google.net.stubby.Status; import com.google.net.stubby.Status;
import com.google.net.stubby.transport.AbstractStream; import com.google.net.stubby.transport.AbstractStream;
@ -64,7 +62,6 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import java.io.InputStream; import java.io.InputStream;
@ -202,9 +199,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
verify(listener, never()).closed(any(Status.class), any(Metadata.Trailers.class)); verify(listener, never()).closed(any(Status.class), any(Metadata.Trailers.class));
// We are now waiting for 100 bytes of error context on the stream, cancel has not yet been sent // We are now waiting for 100 bytes of error context on the stream, cancel has not yet been sent
Mockito.verify(channel, never()).writeAndFlush(any(CancelStreamCommand.class)); verify(channel, never()).writeAndFlush(any(CancelStreamCommand.class));
stream().transportDataReceived(Unpooled.buffer(100).writeZero(100), false); stream().transportDataReceived(Unpooled.buffer(100).writeZero(100), false);
Mockito.verify(channel, never()).writeAndFlush(any(CancelStreamCommand.class)); verify(channel, never()).writeAndFlush(any(CancelStreamCommand.class));
stream().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false); stream().transportDataReceived(Unpooled.buffer(1000).writeZero(1000), false);
// Now verify that cancel is sent and an error is reported to the listener // Now verify that cancel is sent and an error is reported to the listener
@ -226,10 +223,6 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
@Test @Test
public void deframedDataAfterCancelShouldBeIgnored() throws Exception { public void deframedDataAfterCancelShouldBeIgnored() throws Exception {
// Mock the listener to return this future when a message is read.
final SettableFuture<Void> future = SettableFuture.create();
when(listener.messageRead(any(InputStream.class), anyInt())).thenReturn(future);
stream().id(1); stream().id(1);
// Receive headers first so that it's a valid GRPC response. // Receive headers first so that it's a valid GRPC response.
stream().transportHeadersReceived(grpcResponseHeaders(), false); stream().transportHeadersReceived(grpcResponseHeaders(), false);
@ -238,6 +231,9 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
stream().transportDataReceived(simpleGrpcFrame(), false); stream().transportDataReceived(simpleGrpcFrame(), false);
stream().transportDataReceived(simpleGrpcFrame(), false); stream().transportDataReceived(simpleGrpcFrame(), false);
// Only allow the first to be delivered.
stream().request(1);
// Receive error trailers. The server status will not be processed until after all of the // Receive error trailers. The server status will not be processed until after all of the
// data frames have been processed. Since cancellation will interrupt message delivery, // data frames have been processed. Since cancellation will interrupt message delivery,
// this status will never be processed and the listener will instead only see the // this status will never be processed and the listener will instead only see the
@ -251,9 +247,8 @@ public class NettyClientStreamTest extends NettyStreamTestBase {
Metadata.Trailers trailers = Utils.convertTrailers(grpcResponseTrailers(Status.CANCELLED)); Metadata.Trailers trailers = Utils.convertTrailers(grpcResponseTrailers(Status.CANCELLED));
stream().transportReportStatus(Status.CANCELLED, true, trailers); stream().transportReportStatus(Status.CANCELLED, true, trailers);
// Now complete the future to trigger the deframer to fire the next message to the // Now allow the delivery of the second.
// stream. stream().request(1);
future.set(null);
// Verify that the listener was only notified of the first message, not the second. // Verify that the listener was only notified of the first message, not the second.
verify(listener).messageRead(any(InputStream.class), anyInt()); verify(listener).messageRead(any(InputStream.class), anyInt());

View File

@ -44,6 +44,7 @@ import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -163,6 +164,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
private void inboundDataShouldForwardToStreamListener(boolean endStream) throws Exception { private void inboundDataShouldForwardToStreamListener(boolean endStream) throws Exception {
createStream(); createStream();
stream.request(1);
// Create a data frame and then trigger the handler to read it. // Create a data frame and then trigger the handler to read it.
ByteBuf frame = dataFrame(STREAM_ID, endStream); ByteBuf frame = dataFrame(STREAM_ID, endStream);
@ -180,6 +182,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
@Test @Test
public void clientHalfCloseShouldForwardToStreamListener() throws Exception { public void clientHalfCloseShouldForwardToStreamListener() throws Exception {
createStream(); createStream();
stream.request(1);
handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true)); handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true));
ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class); ArgumentCaptor<InputStream> captor = ArgumentCaptor.forClass(InputStream.class);
@ -202,11 +205,12 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
@Test @Test
public void streamErrorShouldNotCloseChannel() throws Exception { public void streamErrorShouldNotCloseChannel() throws Exception {
createStream(); createStream();
stream.request(1);
// When a DATA frame is read, throw an exception. It will be converted into an // When a DATA frame is read, throw an exception. It will be converted into an
// Http2StreamException. // Http2StreamException.
RuntimeException e = new RuntimeException("Fake Exception"); RuntimeException e = new RuntimeException("Fake Exception");
when(streamListener.messageRead(any(InputStream.class), anyInt())).thenThrow(e); doThrow(e).when(streamListener).messageRead(any(InputStream.class), anyInt());
// Read a DATA frame to trigger the exception. // Read a DATA frame to trigger the exception.
handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true)); handler.channelRead(ctx, emptyGrpcFrame(STREAM_ID, true));
@ -217,7 +221,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
// Verify the stream was closed. // Verify the stream was closed.
ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
verify(streamListener).closed(captor.capture()); verify(streamListener).closed(captor.capture());
assertEquals(e, captor.getValue().asException().getCause().getCause()); assertEquals(e, captor.getValue().asException().getCause());
assertEquals(Code.INTERNAL, captor.getValue().getCode()); assertEquals(Code.INTERNAL, captor.getValue().getCode());
} }
@ -225,7 +229,7 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase {
public void connectionErrorShouldCloseChannel() throws Exception { public void connectionErrorShouldCloseChannel() throws Exception {
createStream(); createStream();
// Read a DATA frame to trigger the exception. // Read a bad frame to trigger the exception.
handler.channelRead(ctx, badFrame()); handler.channelRead(ctx, badFrame());
// Verify the expected GO_AWAY frame was written. // Verify the expected GO_AWAY frame was written.

View File

@ -32,7 +32,6 @@
package com.google.net.stubby.transport.netty; package com.google.net.stubby.transport.netty;
import static com.google.net.stubby.transport.netty.NettyTestUtil.messageFrame; import static com.google.net.stubby.transport.netty.NettyTestUtil.messageFrame;
import static com.google.net.stubby.transport.netty.NettyTestUtil.statusFrame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;

View File

@ -35,14 +35,12 @@ import static com.google.net.stubby.transport.netty.NettyTestUtil.messageFrame;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.SettableFuture;
import com.google.net.stubby.transport.AbstractStream; import com.google.net.stubby.transport.AbstractStream;
import com.google.net.stubby.transport.StreamListener; import com.google.net.stubby.transport.StreamListener;
@ -95,8 +93,6 @@ public abstract class NettyStreamTestBase {
@Mock @Mock
protected ChannelPromise promise; protected ChannelPromise promise;
protected SettableFuture<Void> processingFuture;
protected InputStream input; protected InputStream input;
protected AbstractStream<Integer> stream; protected AbstractStream<Integer> stream;
@ -114,9 +110,6 @@ public abstract class NettyStreamTestBase {
when(pipeline.firstContext()).thenReturn(ctx); when(pipeline.firstContext()).thenReturn(ctx);
when(eventLoop.inEventLoop()).thenReturn(true); when(eventLoop.inEventLoop()).thenReturn(true);
processingFuture = SettableFuture.create();
when(listener().messageRead(any(InputStream.class), anyInt())).thenReturn(processingFuture);
doAnswer(new Answer<Void>() { doAnswer(new Answer<Void>() {
@Override @Override
public Void answer(InvocationOnMock invocation) throws Throwable { public Void answer(InvocationOnMock invocation) throws Throwable {
@ -132,6 +125,8 @@ public abstract class NettyStreamTestBase {
@Test @Test
public void inboundMessageShouldCallListener() throws Exception { public void inboundMessageShouldCallListener() throws Exception {
stream.request(1);
if (stream instanceof NettyServerStream) { if (stream instanceof NettyServerStream) {
((NettyServerStream) stream).inboundDataReceived(messageFrame(MESSAGE), false); ((NettyServerStream) stream).inboundDataReceived(messageFrame(MESSAGE), false);
} else { } else {
@ -142,10 +137,6 @@ public abstract class NettyStreamTestBase {
// Verify that inbound flow control window update has been disabled for the stream. // Verify that inbound flow control window update has been disabled for the stream.
assertEquals(MESSAGE, NettyTestUtil.toString(captor.getValue())); assertEquals(MESSAGE, NettyTestUtil.toString(captor.getValue()));
// Verify that inbound flow control window update has been re-enabled for the stream after
// the future completes.
processingFuture.set(null);
} }
protected abstract AbstractStream<Integer> createStream(); protected abstract AbstractStream<Integer> createStream();

View File

@ -42,7 +42,6 @@ import com.squareup.okhttp.internal.spdy.Header;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor;
import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.GuardedBy;
@ -57,28 +56,11 @@ class OkHttpClientStream extends Http2ClientStream {
/** /**
* Construct a new client stream. * Construct a new client stream.
*/ */
static OkHttpClientStream newStream(final Executor executor, ClientStreamListener listener, static OkHttpClientStream newStream(ClientStreamListener listener,
AsyncFrameWriter frameWriter, AsyncFrameWriter frameWriter,
OkHttpClientTransport transport, OkHttpClientTransport transport,
OutboundFlowController outboundFlow) { OutboundFlowController outboundFlow) {
// Create a lock object that can be used by both the executor and methods in the stream return new OkHttpClientStream(listener, frameWriter, transport, outboundFlow);
// to ensure consistent locking behavior.
final Object executorLock = new Object();
Executor synchronizingExecutor = new Executor() {
@Override
public void execute(final Runnable command) {
executor.execute(new Runnable() {
@Override
public void run() {
synchronized (executorLock) {
command.run();
}
}
});
}
};
return new OkHttpClientStream(synchronizingExecutor, listener, frameWriter, transport,
executorLock, outboundFlow);
} }
@GuardedBy("executorLock") @GuardedBy("executorLock")
@ -88,25 +70,28 @@ class OkHttpClientStream extends Http2ClientStream {
private final AsyncFrameWriter frameWriter; private final AsyncFrameWriter frameWriter;
private final OutboundFlowController outboundFlow; private final OutboundFlowController outboundFlow;
private final OkHttpClientTransport transport; private final OkHttpClientTransport transport;
// Lock used to synchronize with work done on the executor. private final Object lock = new Object();
private final Object executorLock;
private Object outboundFlowState; private Object outboundFlowState;
private OkHttpClientStream(final Executor executor, private OkHttpClientStream(ClientStreamListener listener,
final ClientStreamListener listener,
AsyncFrameWriter frameWriter, AsyncFrameWriter frameWriter,
OkHttpClientTransport transport, OkHttpClientTransport transport,
Object executorLock,
OutboundFlowController outboundFlow) { OutboundFlowController outboundFlow) {
super(listener, executor); super(listener);
this.frameWriter = frameWriter; this.frameWriter = frameWriter;
this.transport = transport; this.transport = transport;
this.executorLock = executorLock;
this.outboundFlow = outboundFlow; this.outboundFlow = outboundFlow;
} }
@Override
public void request(final int numMessages) {
synchronized (lock) {
requestMessagesFromDeframer(numMessages);
}
}
public void transportHeadersReceived(List<Header> headers, boolean endOfStream) { public void transportHeadersReceived(List<Header> headers, boolean endOfStream) {
synchronized (executorLock) { synchronized (lock) {
if (endOfStream) { if (endOfStream) {
transportTrailersReceived(Utils.convertTrailers(headers)); transportTrailersReceived(Utils.convertTrailers(headers));
} else { } else {
@ -120,7 +105,7 @@ class OkHttpClientStream extends Http2ClientStream {
* the future listeners (executed by synchronizedExecutor) will not be executed in the same time. * the future listeners (executed by synchronizedExecutor) will not be executed in the same time.
*/ */
public void transportDataReceived(okio.Buffer frame, boolean endOfStream) { public void transportDataReceived(okio.Buffer frame, boolean endOfStream) {
synchronized (executorLock) { synchronized (lock) {
long length = frame.size(); long length = frame.size();
window -= length; window -= length;
super.transportDataReceived(new OkHttpBuffer(frame), endOfStream); super.transportDataReceived(new OkHttpBuffer(frame), endOfStream);
@ -143,7 +128,7 @@ class OkHttpClientStream extends Http2ClientStream {
@Override @Override
protected void returnProcessedBytes(int processedBytes) { protected void returnProcessedBytes(int processedBytes) {
synchronized (executorLock) { synchronized (lock) {
processedWindow -= processedBytes; processedWindow -= processedBytes;
if (processedWindow <= WINDOW_UPDATE_THRESHOLD) { if (processedWindow <= WINDOW_UPDATE_THRESHOLD) {
int delta = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE - processedWindow; int delta = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE - processedWindow;
@ -157,7 +142,7 @@ class OkHttpClientStream extends Http2ClientStream {
@Override @Override
public void transportReportStatus(Status newStatus, boolean stopDelivery, public void transportReportStatus(Status newStatus, boolean stopDelivery,
Metadata.Trailers trailers) { Metadata.Trailers trailers) {
synchronized (executorLock) { synchronized (lock) {
super.transportReportStatus(newStatus, stopDelivery, trailers); super.transportReportStatus(newStatus, stopDelivery, trailers);
} }
} }

View File

@ -166,7 +166,7 @@ public class OkHttpClientTransport extends AbstractClientTransport {
protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method, protected ClientStream newStreamInternal(MethodDescriptor<?, ?> method,
Metadata.Headers headers, Metadata.Headers headers,
ClientStreamListener listener) { ClientStreamListener listener) {
OkHttpClientStream clientStream = OkHttpClientStream.newStream(executor, listener, OkHttpClientStream clientStream = OkHttpClientStream.newStream(listener,
frameWriter, this, outboundFlow); frameWriter, this, outboundFlow);
if (goAway) { if (goAway) {
clientStream.transportReportStatus(goAwayStatus, false, new Metadata.Trailers()); clientStream.transportReportStatus(goAwayStatus, false, new Metadata.Trailers());

View File

@ -45,7 +45,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.Service; import com.google.common.util.concurrent.Service;
import com.google.common.util.concurrent.Service.State; import com.google.common.util.concurrent.Service.State;
import com.google.net.stubby.Metadata; import com.google.net.stubby.Metadata;
@ -137,8 +136,8 @@ public class OkHttpClientTransportTest {
public void nextFrameThrowIOException() throws Exception { public void nextFrameThrowIOException() throws Exception {
MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener1 = new MockStreamListener();
MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener();
clientTransport.newStream(method, new Metadata.Headers(), listener1); clientTransport.newStream(method, new Metadata.Headers(), listener1).request(1);
clientTransport.newStream(method, new Metadata.Headers(), listener2); clientTransport.newStream(method, new Metadata.Headers(), listener2).request(1);
assertEquals(2, streams.size()); assertEquals(2, streams.size());
assertTrue(streams.containsKey(3)); assertTrue(streams.containsKey(3));
assertTrue(streams.containsKey(5)); assertTrue(streams.containsKey(5));
@ -158,7 +157,7 @@ public class OkHttpClientTransportTest {
final int numMessages = 10; final int numMessages = 10;
final String message = "Hello Client"; final String message = "Hello Client";
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, new Metadata.Headers(), listener); clientTransport.newStream(method, new Metadata.Headers(), listener).request(numMessages);
assertTrue(streams.containsKey(3)); assertTrue(streams.containsKey(3));
frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); frameHandler.headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
assertNotNull(listener.headers); assertNotNull(listener.headers);
@ -179,7 +178,7 @@ public class OkHttpClientTransportTest {
@Test @Test
public void invalidInboundHeadersCancelStream() throws Exception { public void invalidInboundHeadersCancelStream() throws Exception {
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, new Metadata.Headers(), listener); clientTransport.newStream(method, new Metadata.Headers(), listener).request(1);
assertTrue(streams.containsKey(3)); assertTrue(streams.containsKey(3));
// Empty headers block without correct content type or status // Empty headers block without correct content type or status
frameHandler.headers(false, false, 3, 0, new ArrayList<Header>(), frameHandler.headers(false, false, 3, 0, new ArrayList<Header>(),
@ -246,8 +245,8 @@ public class OkHttpClientTransportTest {
public void windowUpdate() throws Exception { public void windowUpdate() throws Exception {
MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener1 = new MockStreamListener();
MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener();
clientTransport.newStream(method,new Metadata.Headers(), listener1); clientTransport.newStream(method,new Metadata.Headers(), listener1).request(2);
clientTransport.newStream(method,new Metadata.Headers(), listener2); clientTransport.newStream(method,new Metadata.Headers(), listener2).request(2);
assertEquals(2, streams.size()); assertEquals(2, streams.size());
OkHttpClientStream stream1 = streams.get(3); OkHttpClientStream stream1 = streams.get(3);
OkHttpClientStream stream2 = streams.get(5); OkHttpClientStream stream2 = streams.get(5);
@ -299,7 +298,7 @@ public class OkHttpClientTransportTest {
@Test @Test
public void windowUpdateWithInboundFlowControl() throws Exception { public void windowUpdateWithInboundFlowControl() throws Exception {
MockStreamListener listener = new MockStreamListener(); MockStreamListener listener = new MockStreamListener();
clientTransport.newStream(method, new Metadata.Headers(), listener); clientTransport.newStream(method, new Metadata.Headers(), listener).request(1);
OkHttpClientStream stream = streams.get(3); OkHttpClientStream stream = streams.get(3);
int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 2 + 1; int messageLength = OkHttpClientTransport.DEFAULT_INITIAL_WINDOW_SIZE / 2 + 1;
@ -342,8 +341,8 @@ public class OkHttpClientTransportTest {
// start 2 streams. // start 2 streams.
MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener1 = new MockStreamListener();
MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener();
clientTransport.newStream(method,new Metadata.Headers(), listener1); clientTransport.newStream(method,new Metadata.Headers(), listener1).request(1);
clientTransport.newStream(method,new Metadata.Headers(), listener2); clientTransport.newStream(method,new Metadata.Headers(), listener2).request(1);
assertEquals(2, streams.size()); assertEquals(2, streams.size());
// Receive goAway, max good id is 3. // Receive goAway, max good id is 3.
@ -494,18 +493,16 @@ public class OkHttpClientTransportTest {
} }
@Override @Override
public ListenableFuture<Void> headersRead(Metadata.Headers headers) { public void headersRead(Metadata.Headers headers) {
this.headers = headers; this.headers = headers;
return null;
} }
@Override @Override
public ListenableFuture<Void> messageRead(InputStream message, int length) { public void messageRead(InputStream message, int length) {
String msg = getContent(message); String msg = getContent(message);
if (msg != null) { if (msg != null) {
messages.add(msg); messages.add(msg);
} }
return null;
} }
@Override @Override
@ -522,13 +519,18 @@ public class OkHttpClientTransportTest {
} }
static String getContent(InputStream message) { static String getContent(InputStream message) {
BufferedReader br = BufferedReader br = new BufferedReader(new InputStreamReader(message, UTF_8));
new BufferedReader(new InputStreamReader(message, UTF_8));
try { try {
// Only one line message is used in this test. // Only one line message is used in this test.
return br.readLine(); return br.readLine();
} catch (IOException e) { } catch (IOException e) {
return null; return null;
} finally {
try {
message.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
} }
} }
} }

View File

@ -160,7 +160,7 @@ public class Calls {
ReqT param, ReqT param,
StreamObserver<RespT> responseObserver) { StreamObserver<RespT> responseObserver) {
asyncServerStreamingCall(call, param, asyncServerStreamingCall(call, param,
new StreamObserverToCallListenerAdapter<RespT>(responseObserver)); new StreamObserverToCallListenerAdapter<RespT>(call, responseObserver));
} }
private static <ReqT, RespT> void asyncServerStreamingCall( private static <ReqT, RespT> void asyncServerStreamingCall(
@ -168,6 +168,7 @@ public class Calls {
ReqT param, ReqT param,
Call.Listener<RespT> responseListener) { Call.Listener<RespT> responseListener) {
call.start(responseListener, new Metadata.Headers()); call.start(responseListener, new Metadata.Headers());
call.request(1);
try { try {
call.sendPayload(param); call.sendPayload(param);
call.halfClose(); call.halfClose();
@ -217,10 +218,11 @@ public class Calls {
* Execute a duplex-streaming call. * Execute a duplex-streaming call.
* @return request stream observer. * @return request stream observer.
*/ */
public static <ReqT, RespT> StreamObserver<ReqT> duplexStreamingCall( public static <ReqT, RespT> StreamObserver<ReqT> duplexStreamingCall(Call<ReqT, RespT> call,
Call<ReqT, RespT> call, StreamObserver<RespT> responseObserver) { StreamObserver<RespT> responseObserver) {
call.start(new StreamObserverToCallListenerAdapter<RespT>(responseObserver), call.start(new StreamObserverToCallListenerAdapter<RespT>(call, responseObserver),
new Metadata.Headers()); new Metadata.Headers());
call.request(1);
return new CallToStreamObserverAdapter<ReqT>(call); return new CallToStreamObserverAdapter<ReqT>(call);
} }
@ -248,22 +250,25 @@ public class Calls {
} }
} }
private static class StreamObserverToCallListenerAdapter<T> extends Call.Listener<T> { private static class StreamObserverToCallListenerAdapter<RespT> extends Call.Listener<RespT> {
private final StreamObserver<T> observer; private final Call<?, RespT> call;
private final StreamObserver<RespT> observer;
public StreamObserverToCallListenerAdapter(StreamObserver<T> observer) { public StreamObserverToCallListenerAdapter(Call<?, RespT> call, StreamObserver<RespT> observer) {
this.call = call;
this.observer = observer; this.observer = observer;
} }
@Override @Override
public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { public void onHeaders(Metadata.Headers headers) {
return null;
} }
@Override @Override
public ListenableFuture<Void> onPayload(T payload) { public void onPayload(RespT payload) {
observer.onValue(payload); observer.onValue(payload);
return null;
// Request delivery of the next inbound message.
call.request(1);
} }
@Override @Override
@ -288,18 +293,16 @@ public class Calls {
} }
@Override @Override
public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { public void onHeaders(Metadata.Headers headers) {
return null;
} }
@Override @Override
public ListenableFuture<Void> onPayload(RespT value) { public void onPayload(RespT value) {
if (this.value != null) { if (this.value != null) {
throw Status.INTERNAL.withDescription("More than one value received for unary call") throw Status.INTERNAL.withDescription("More than one value received for unary call")
.asRuntimeException(); .asRuntimeException();
} }
this.value = value; this.value = value;
return null;
} }
@Override @Override
@ -357,11 +360,13 @@ public class Calls {
if (!hasNext()) { if (!hasNext()) {
throw new NoSuchElementException(); throw new NoSuchElementException();
} }
try {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Payload<T> tmp = (Payload<T>) last; T tmp = (T) last;
return tmp;
} finally {
last = null; last = null;
tmp.processed.set(null); }
return tmp.value;
} }
@Override @Override
@ -373,16 +378,13 @@ public class Calls {
private boolean done = false; private boolean done = false;
@Override @Override
public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { public void onHeaders(Metadata.Headers headers) {
return null;
} }
@Override @Override
public ListenableFuture<Void> onPayload(T value) { public void onPayload(T value) {
Preconditions.checkState(!done, "Call already closed"); Preconditions.checkState(!done, "Call already closed");
SettableFuture<Void> future = SettableFuture.create(); buffer.add(value);
buffer.add(new Payload<T>(value, future));
return future;
} }
@Override @Override
@ -397,14 +399,4 @@ public class Calls {
} }
} }
} }
private static class Payload<T> {
public final T value;
public final SettableFuture<Void> processed;
public Payload(T value, SettableFuture<Void> processed) {
this.value = value;
this.processed = processed;
}
}
} }

View File

@ -31,7 +31,6 @@
package com.google.net.stubby.stub; package com.google.net.stubby.stub;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.net.stubby.Call; import com.google.net.stubby.Call;
import com.google.net.stubby.Channel; import com.google.net.stubby.Channel;
import com.google.net.stubby.ClientInterceptor; import com.google.net.stubby.ClientInterceptor;
@ -122,9 +121,9 @@ public class MetadataUtils {
trailersCapture.set(null); trailersCapture.set(null);
super.start(new ForwardingListener<RespT>(responseListener) { super.start(new ForwardingListener<RespT>(responseListener) {
@Override @Override
public ListenableFuture<Void> onHeaders(Metadata.Headers headers) { public void onHeaders(Metadata.Headers headers) {
headersCapture.set(headers); headersCapture.set(headers);
return super.onHeaders(headers); super.onHeaders(headers);
} }
@Override @Override

View File

@ -31,7 +31,6 @@
package com.google.net.stubby.stub; package com.google.net.stubby.stub;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.net.stubby.Metadata; import com.google.net.stubby.Metadata;
import com.google.net.stubby.ServerCall; import com.google.net.stubby.ServerCall;
import com.google.net.stubby.ServerCallHandler; import com.google.net.stubby.ServerCallHandler;
@ -59,21 +58,24 @@ public class ServerCalls {
public ServerCall.Listener<ReqT> startCall( public ServerCall.Listener<ReqT> startCall(
String fullMethodName, final ServerCall<RespT> call, Metadata.Headers headers) { String fullMethodName, final ServerCall<RespT> call, Metadata.Headers headers) {
final ResponseObserver<RespT> responseObserver = new ResponseObserver<RespT>(call); final ResponseObserver<RespT> responseObserver = new ResponseObserver<RespT>(call);
call.request(1);
return new EmptyServerCallListener<ReqT>() { return new EmptyServerCallListener<ReqT>() {
ReqT request; ReqT request;
@Override @Override
public ListenableFuture<Void> onPayload(ReqT request) { public void onPayload(ReqT request) {
if (this.request == null) { if (this.request == null) {
// We delay calling method.invoke() until onHalfClose(), because application may call // We delay calling method.invoke() until onHalfClose(), because application may call
// close(OK) inside invoke(), while close(OK) is not allowed before onHalfClose(). // close(OK) inside invoke(), while close(OK) is not allowed before onHalfClose().
this.request = request; this.request = request;
// Request delivery of the next inbound message.
call.request(1);
} else { } else {
call.close( call.close(
Status.INVALID_ARGUMENT.withDescription( Status.INVALID_ARGUMENT.withDescription(
"More than one request payloads for unary call or server streaming call"), "More than one request payloads for unary call or server streaming call"),
new Metadata.Trailers()); new Metadata.Trailers());
} }
return null;
} }
@Override @Override
@ -99,17 +101,20 @@ public class ServerCalls {
final StreamingRequestMethod<ReqT, RespT> method) { final StreamingRequestMethod<ReqT, RespT> method) {
return new ServerCallHandler<ReqT, RespT>() { return new ServerCallHandler<ReqT, RespT>() {
@Override @Override
public ServerCall.Listener<ReqT> startCall(String fullMethodName, ServerCall<RespT> call, public ServerCall.Listener<ReqT> startCall(String fullMethodName,
Metadata.Headers headers) { final ServerCall<RespT> call, Metadata.Headers headers) {
call.request(1);
final ResponseObserver<RespT> responseObserver = new ResponseObserver<RespT>(call); final ResponseObserver<RespT> responseObserver = new ResponseObserver<RespT>(call);
final StreamObserver<ReqT> requestObserver = method.invoke(responseObserver); final StreamObserver<ReqT> requestObserver = method.invoke(responseObserver);
return new EmptyServerCallListener<ReqT>() { return new EmptyServerCallListener<ReqT>() {
boolean halfClosed = false; boolean halfClosed = false;
@Override @Override
public ListenableFuture<Void> onPayload(ReqT request) { public void onPayload(ReqT request) {
requestObserver.onValue(request); requestObserver.onValue(request);
return null;
// Request delivery of the next inbound message.
call.request(1);
} }
@Override @Override
@ -158,6 +163,9 @@ public class ServerCalls {
throw Status.CANCELLED.asRuntimeException(); throw Status.CANCELLED.asRuntimeException();
} }
call.sendPayload(response); call.sendPayload(response);
// Request delivery of the next inbound message.
call.request(1);
} }
@Override @Override
@ -177,8 +185,7 @@ public class ServerCalls {
private static class EmptyServerCallListener<ReqT> extends ServerCall.Listener<ReqT> { private static class EmptyServerCallListener<ReqT> extends ServerCall.Listener<ReqT> {
@Override @Override
public ListenableFuture<Void> onPayload(ReqT request) { public void onPayload(ReqT request) {
return null;
} }
@Override @Override