diff --git a/api/src/main/java/io/grpc/ServerBuilder.java b/api/src/main/java/io/grpc/ServerBuilder.java index 731227e501..e5f0ae6270 100644 --- a/api/src/main/java/io/grpc/ServerBuilder.java +++ b/api/src/main/java/io/grpc/ServerBuilder.java @@ -246,7 +246,7 @@ public abstract class ServerBuilder> { /** * Sets the time without read activity before sending a keepalive ping. An unreasonably small * value might be increased, and {@code Long.MAX_VALUE} nano seconds or an unreasonably large - * value will disable keepalive. The typical default is infinite when supported. + * value will disable keepalive. The typical default is two hours when supported. * * @throws IllegalArgumentException if time is not positive * @throws UnsupportedOperationException if unsupported diff --git a/core/src/main/java/io/grpc/internal/AbstractServerStream.java b/core/src/main/java/io/grpc/internal/AbstractServerStream.java index 3513ec9346..f94c3e54e2 100644 --- a/core/src/main/java/io/grpc/internal/AbstractServerStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractServerStream.java @@ -50,7 +50,7 @@ public abstract class AbstractServerStream extends AbstractStream * @param flush {@code true} if more data may not be arriving soon * @param numMessages the number of messages this frame represents */ - void writeFrame(@Nullable WritableBuffer frame, boolean flush, int numMessages); + void writeFrame(WritableBuffer frame, boolean flush, int numMessages); /** * Sends trailers to the remote end point. This call implies end of stream. @@ -108,7 +108,14 @@ public abstract class AbstractServerStream extends AbstractStream WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) { // Since endOfStream is triggered by the sending of trailers, avoid flush here and just flush // after the trailers. - abstractServerStreamSink().writeFrame(frame, endOfStream ? false : flush, numMessages); + if (frame == null) { + assert endOfStream; + return; + } + if (endOfStream) { + flush = false; + } + abstractServerStreamSink().writeFrame(frame, flush, numMessages); } @Override diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 91fceb4d9f..605fabc0cc 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -799,6 +799,12 @@ public final class GrpcUtil { } } + /** Reads {@code in} until end of stream. */ + public static void exhaust(InputStream in) throws IOException { + byte[] buf = new byte[256]; + while (in.read(buf) != -1) {} + } + /** * Checks whether the given item exists in the iterable. This is copied from Guava Collect's * {@code Iterables.contains()} because Guava Collect is not Android-friendly thus core can't diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index d70d8bc6b8..a1c00d7dca 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -408,12 +408,13 @@ public abstract class AbstractTransportTest { } assumeTrue("transport is not using InetSocketAddress", port != -1); server.shutdown(); + assertTrue(serverListener.waitForShutdown(TIMEOUT_MS, TimeUnit.MILLISECONDS)); server = newServer(port, Arrays.asList(serverStreamTracerFactory)); boolean success; Thread.currentThread().interrupt(); try { - server.start(serverListener); + server.start(serverListener = new MockServerListener()); success = true; } catch (Exception ex) { success = false; diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index f552b937a0..902b54f376 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -52,6 +52,7 @@ import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; @@ -854,6 +855,9 @@ class NettyServerHandler extends AbstractNettyHandler { keepAliveManager.onDataReceived(); } NettyServerHandler.this.onHeadersRead(ctx, streamId, headers); + if (endStream) { + NettyServerHandler.this.onDataRead(streamId, Unpooled.EMPTY_BUFFER, 0, endStream); + } } @Override diff --git a/netty/src/main/java/io/grpc/netty/NettyServerStream.java b/netty/src/main/java/io/grpc/netty/NettyServerStream.java index 3850a6a291..6ab391b260 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerStream.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerStream.java @@ -108,10 +108,6 @@ class NettyServerStream extends AbstractServerStream { private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) { Preconditions.checkArgument(numMessages >= 0); - if (frame == null) { - writeQueue.scheduleFlush(); - return; - } ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch(); final int numBytes = bytebuf.readableBytes(); // Add the bytes to outbound flow control. diff --git a/okhttp/build.gradle b/okhttp/build.gradle index fffdf8d76d..1634410dae 100644 --- a/okhttp/build.gradle +++ b/okhttp/build.gradle @@ -21,7 +21,7 @@ dependencies { testImplementation project(':grpc-core').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, project(':grpc-testing'), - project(':grpc-netty'), + libraries.netty.codec.http2, libraries.okhttp signature "org.codehaus.mojo.signature:java17:1.0@signature" signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" diff --git a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java index bf90843efe..eb1e1b4b0a 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java +++ b/okhttp/src/main/java/io/grpc/okhttp/AsyncSink.java @@ -164,6 +164,13 @@ final class AsyncSink implements Sink { serializingExecutor.execute(new Runnable() { @Override public void run() { + try { + if (buffer.size() > 0) { + sink.write(buffer, buffer.size()); + } + } catch (IOException e) { + transportExceptionHandler.onException(e); + } buffer.close(); try { if (sink != null) { diff --git a/okhttp/src/main/java/io/grpc/okhttp/HandshakerSocketFactory.java b/okhttp/src/main/java/io/grpc/okhttp/HandshakerSocketFactory.java new file mode 100644 index 0000000000..a6cf8db9b4 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/HandshakerSocketFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import io.grpc.Attributes; +import io.grpc.InternalChannelz; +import java.io.IOException; +import java.net.Socket; + +/** Handshakes new connections. */ +interface HandshakerSocketFactory { + /** When the returned socket is closed, {@code socket} must be closed. */ + HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException; + + static final class HandshakeResult { + public final Socket socket; + public final Attributes attributes; + public final InternalChannelz.Security securityInfo; + + public HandshakeResult( + Socket socket, Attributes attributes, InternalChannelz.Security securityInfo) { + this.socket = Preconditions.checkNotNull(socket, "socket"); + this.attributes = Preconditions.checkNotNull(attributes, "attributes"); + this.securityInfo = securityInfo; + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/Headers.java b/okhttp/src/main/java/io/grpc/okhttp/Headers.java index 15008f8040..ff2033b35b 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/Headers.java +++ b/okhttp/src/main/java/io/grpc/okhttp/Headers.java @@ -16,9 +16,6 @@ package io.grpc.okhttp; -import static io.grpc.internal.GrpcUtil.CONTENT_TYPE_KEY; -import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY; - import com.google.common.base.Preconditions; import io.grpc.InternalMetadata; import io.grpc.Metadata; @@ -39,7 +36,7 @@ class Headers { public static final Header METHOD_HEADER = new Header(Header.TARGET_METHOD, GrpcUtil.HTTP_METHOD); public static final Header METHOD_GET_HEADER = new Header(Header.TARGET_METHOD, "GET"); public static final Header CONTENT_TYPE_HEADER = - new Header(CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC); + new Header(GrpcUtil.CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC); public static final Header TE_HEADER = new Header("te", GrpcUtil.TE_TRAILERS); /** @@ -58,10 +55,7 @@ class Headers { Preconditions.checkNotNull(defaultPath, "defaultPath"); Preconditions.checkNotNull(authority, "authority"); - // Discard any application supplied duplicates of the reserved headers - headers.discardAll(GrpcUtil.CONTENT_TYPE_KEY); - headers.discardAll(GrpcUtil.TE_HEADER); - headers.discardAll(GrpcUtil.USER_AGENT_KEY); + stripNonApplicationHeaders(headers); // 7 is the number of explicit add calls below. List
okhttpHeaders = new ArrayList<>(7 + InternalMetadata.headerCount(headers)); @@ -89,27 +83,72 @@ class Headers { okhttpHeaders.add(TE_HEADER); // Now add any application-provided headers. - byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(headers); - for (int i = 0; i < serializedHeaders.length; i += 2) { - ByteString key = ByteString.of(serializedHeaders[i]); - String keyString = key.utf8(); - if (isApplicationHeader(keyString)) { - ByteString value = ByteString.of(serializedHeaders[i + 1]); - okhttpHeaders.add(new Header(key, value)); - } - } - - return okhttpHeaders; + return addMetadata(okhttpHeaders, headers); } /** - * Returns {@code true} if the given header is an application-provided header. Otherwise, returns - * {@code false} if the header is reserved by GRPC. + * Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when + * starting a response. Since this serializes the headers, this method should be called in the + * application thread context. */ - private static boolean isApplicationHeader(String key) { - // Don't allow HTTP/2 pseudo headers or content-type to be added by the application. - return (!key.startsWith(":") - && !CONTENT_TYPE_KEY.name().equalsIgnoreCase(key)) - && !USER_AGENT_KEY.name().equalsIgnoreCase(key); + public static List
createResponseHeaders(Metadata headers) { + stripNonApplicationHeaders(headers); + + // 2 is the number of explicit add calls below. + List
okhttpHeaders = new ArrayList<>(2 + InternalMetadata.headerCount(headers)); + okhttpHeaders.add(new Header(Header.RESPONSE_STATUS, "200")); + // All non-pseudo headers must come after pseudo headers. + okhttpHeaders.add(CONTENT_TYPE_HEADER); + return addMetadata(okhttpHeaders, headers); + } + + /** + * Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when + * finishing a response. Since this serializes the headers, this method should be called in the + * application thread context. + */ + public static List
createResponseTrailers(Metadata trailers, boolean headersSent) { + if (!headersSent) { + return createResponseHeaders(trailers); + } + stripNonApplicationHeaders(trailers); + + List
okhttpTrailers = new ArrayList<>(InternalMetadata.headerCount(trailers)); + return addMetadata(okhttpTrailers, trailers); + } + + /** + * Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when + * failing with an HTTP response. + */ + public static List
createHttpResponseHeaders( + int httpCode, String contentType, Metadata headers) { + // 2 is the number of explicit add calls below. + List
okhttpHeaders = new ArrayList<>(2 + InternalMetadata.headerCount(headers)); + okhttpHeaders.add(new Header(Header.RESPONSE_STATUS, "" + httpCode)); + // All non-pseudo headers must come after pseudo headers. + okhttpHeaders.add(new Header(GrpcUtil.CONTENT_TYPE_KEY.name(), contentType)); + return addMetadata(okhttpHeaders, headers); + } + + private static List
addMetadata(List
okhttpHeaders, Metadata toAdd) { + byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(toAdd); + for (int i = 0; i < serializedHeaders.length; i += 2) { + ByteString key = ByteString.of(serializedHeaders[i]); + // Don't allow HTTP/2 pseudo headers to be added by the application. + if (key.size() == 0 || key.getByte(0) == ':') { + continue; + } + ByteString value = ByteString.of(serializedHeaders[i + 1]); + okhttpHeaders.add(new Header(key, value)); + } + return okhttpHeaders; + } + + /** Strips all non-pseudo headers reserved by gRPC, to avoid duplicates and misinterpretation. */ + private static void stripNonApplicationHeaders(Metadata headers) { + headers.discardAll(GrpcUtil.CONTENT_TYPE_KEY); + headers.discardAll(GrpcUtil.TE_HEADER); + headers.discardAll(GrpcUtil.USER_AGENT_KEY); } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index e4df0b50e2..a752885766 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -139,7 +139,7 @@ public final class OkHttpChannelBuilder extends ((ExecutorService) executor).shutdown(); } }; - private static final ObjectPool DEFAULT_TRANSPORT_EXECUTOR_POOL = + static final ObjectPool DEFAULT_TRANSPORT_EXECUTOR_POOL = SharedResourcePool.forResource(SHARED_EXECUTOR); /** Creates a new builder for the given server host and port. */ diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java index baf659a627..46396b2a41 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientStream.java @@ -53,8 +53,6 @@ class OkHttpClientStream extends AbstractClientStream { private final String userAgent; private final StatsTraceContext statsTraceCtx; private String authority; - private Object outboundFlowState; - private volatile int id = ABSENT_ID; private final TransportState state; private final Sink sink = new Sink(); private final Attributes attributes; @@ -120,10 +118,6 @@ class OkHttpClientStream extends AbstractClientStream { return method.getType(); } - public int id() { - return id; - } - /** * Returns whether the stream uses GET. This is not known until after {@link Sink#writeHeaders} is * invoked. @@ -198,7 +192,8 @@ class OkHttpClientStream extends AbstractClientStream { } } - class TransportState extends Http2ClientStreamTransportState { + class TransportState extends Http2ClientStreamTransportState + implements OutboundFlowController.Stream { private final int initialWindowSize; private final Object lock; @GuardedBy("lock") @@ -223,6 +218,9 @@ class OkHttpClientStream extends AbstractClientStream { @GuardedBy("lock") private boolean canStart = true; private final Tag tag; + @GuardedBy("lock") + private OutboundFlowController.StreamState outboundFlowState; + private int id = ABSENT_ID; public TransportState( int maxMessageSize, @@ -249,6 +247,7 @@ class OkHttpClientStream extends AbstractClientStream { public void start(int streamId) { checkState(id == ABSENT_ID, "the stream has been started with id %s", streamId); id = streamId; + outboundFlowState = outboundFlow.createState(this, streamId); // TODO(b/145386688): This access should be guarded by 'OkHttpClientStream.this.state.lock'; // instead found: 'this.lock' state.onStreamAllocated(); @@ -260,7 +259,9 @@ class OkHttpClientStream extends AbstractClientStream { requestHeaders = null; if (pendingData.size() > 0) { - outboundFlow.data(pendingDataHasEndOfStream, id, pendingData, flushPendingData); + outboundFlow.data( + pendingDataHasEndOfStream, outboundFlowState, pendingData, flushPendingData); + } canStart = false; } @@ -396,7 +397,7 @@ class OkHttpClientStream extends AbstractClientStream { checkState(id() != ABSENT_ID, "streamId should be set"); // If buffer > frameWriter.maxDataLength() the flow-controller will ensure that it is // properly chunked. - outboundFlow.data(endOfStream, id(), buffer, flush); + outboundFlow.data(endOfStream, outboundFlowState, buffer, flush); } } @@ -419,13 +420,15 @@ class OkHttpClientStream extends AbstractClientStream { Tag tag() { return tag; } - } - void setOutboundFlowState(Object outboundFlowState) { - this.outboundFlowState = outboundFlowState; - } + int id() { + return id; + } - Object getOutboundFlowState() { - return outboundFlowState; + OutboundFlowController.StreamState getOutboundFlowState() { + synchronized (lock) { + return outboundFlowState; + } + } } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index ee43f6f9e7..a8fc4d53c9 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -105,10 +105,10 @@ import okio.Timeout; /** * A okhttp-based {@link ConnectionClientTransport} implementation. */ -class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler { +class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler, + OutboundFlowController.Transport { private static final Map ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap(); private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); - private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0]; private static Map buildErrorCodeToStatusMap() { Map errorToStatus = new EnumMap<>(ErrorCode.class); @@ -424,7 +424,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep @GuardedBy("lock") private void startStream(OkHttpClientStream stream) { Preconditions.checkState( - stream.id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned"); + stream.transportState().id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned"); streams.put(nextStreamId, stream); setInUse(stream); // TODO(b/145386688): This access should be guarded by 'stream.transportState().lock'; instead @@ -808,9 +808,16 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep /** * Gets all active streams as an array. */ - OkHttpClientStream[] getActiveStreams() { + @Override + public OutboundFlowController.StreamState[] getActiveStreams() { synchronized (lock) { - return streams.values().toArray(EMPTY_STREAM_ARRAY); + OutboundFlowController.StreamState[] flowStreams = + new OutboundFlowController.StreamState[streams.size()]; + int i = 0; + for (OkHttpClientStream stream : streams.values()) { + flowStreams[i++] = stream.transportState().getOutboundFlowState(); + } + return flowStreams; } } @@ -1125,7 +1132,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep if (stream == null) { if (mayHaveCreatedStream(streamId)) { synchronized (lock) { - frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + frameWriter.rstStream(streamId, ErrorCode.STREAM_CLOSED); } in.skip(length); } else { @@ -1186,7 +1193,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep OkHttpClientStream stream = streams.get(streamId); if (stream == null) { if (mayHaveCreatedStream(streamId)) { - frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM); + frameWriter.rstStream(streamId, ErrorCode.STREAM_CLOSED); } else { unknownStream = true; } @@ -1365,7 +1372,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep OkHttpClientStream stream = streams.get(streamId); if (stream != null) { - outboundFlow.windowUpdate(stream, (int) delta); + outboundFlow.windowUpdate(stream.transportState().getOutboundFlowState(), (int) delta); } else if (!mayHaveCreatedStream(streamId)) { unknownStream = true; } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServer.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServer.java new file mode 100644 index 0000000000..f63950e4a0 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServer.java @@ -0,0 +1,189 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.InternalChannelz; +import io.grpc.InternalInstrumented; +import io.grpc.InternalLogId; +import io.grpc.ServerStreamTracer; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerListener; +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ServerSocketFactory; + +final class OkHttpServer implements InternalServer { + private static final Logger log = Logger.getLogger(OkHttpServer.class.getName()); + + private final SocketAddress originalListenAddress; + private final ServerSocketFactory socketFactory; + private final ObjectPool transportExecutorPool; + private final ObjectPool scheduledExecutorServicePool; + private final OkHttpServerTransport.Config transportConfig; + private final InternalChannelz channelz; + private ServerSocket serverSocket; + private SocketAddress actualListenAddress; + private InternalInstrumented listenInstrumented; + private Executor transportExecutor; + private ScheduledExecutorService scheduledExecutorService; + private ServerListener listener; + private boolean shutdown; + + public OkHttpServer( + OkHttpServerBuilder builder, + List streamTracerFactories, + InternalChannelz channelz) { + this.originalListenAddress = Preconditions.checkNotNull(builder.listenAddress, "listenAddress"); + this.socketFactory = Preconditions.checkNotNull(builder.socketFactory, "socketFactory"); + this.transportExecutorPool = + Preconditions.checkNotNull(builder.transportExecutorPool, "transportExecutorPool"); + this.scheduledExecutorServicePool = + Preconditions.checkNotNull( + builder.scheduledExecutorServicePool, "scheduledExecutorServicePool"); + this.transportConfig = new OkHttpServerTransport.Config(builder, streamTracerFactories); + this.channelz = Preconditions.checkNotNull(channelz, "channelz"); + } + + @Override + public void start(ServerListener listener) throws IOException { + this.listener = Preconditions.checkNotNull(listener, "listener"); + ServerSocket serverSocket = socketFactory.createServerSocket(); + try { + serverSocket.bind(originalListenAddress); + } catch (IOException t) { + serverSocket.close(); + throw t; + } + + this.serverSocket = serverSocket; + this.actualListenAddress = serverSocket.getLocalSocketAddress(); + this.listenInstrumented = new ListenSocket(serverSocket); + this.transportExecutor = transportExecutorPool.getObject(); + // Keep reference alive to avoid frequent re-creation by server transports + this.scheduledExecutorService = scheduledExecutorServicePool.getObject(); + channelz.addListenSocket(this.listenInstrumented); + transportExecutor.execute(this::acceptConnections); + } + + private void acceptConnections() { + try { + while (true) { + Socket socket; + try { + socket = serverSocket.accept(); + } catch (IOException ex) { + if (shutdown) { + break; + } + throw ex; + } + OkHttpServerTransport transport = new OkHttpServerTransport(transportConfig, socket); + transport.start(listener.transportCreated(transport)); + } + } catch (Throwable t) { + log.log(Level.SEVERE, "Accept loop failed", t); + } + listener.serverShutdown(); + } + + @Override + public void shutdown() { + if (shutdown) { + return; + } + shutdown = true; + + if (serverSocket == null) { + return; + } + channelz.removeListenSocket(this.listenInstrumented); + try { + serverSocket.close(); + } catch (IOException ex) { + log.log(Level.WARNING, "Failed closing server socket", serverSocket); + } + transportExecutor = transportExecutorPool.returnObject(transportExecutor); + scheduledExecutorService = scheduledExecutorServicePool.returnObject(scheduledExecutorService); + } + + @Override + public SocketAddress getListenSocketAddress() { + return actualListenAddress; + } + + @Override + public InternalInstrumented getListenSocketStats() { + return listenInstrumented; + } + + @Override + public List getListenSocketAddresses() { + return Collections.singletonList(getListenSocketAddress()); + } + + @Override + public List> getListenSocketStatsList() { + return Collections.singletonList(getListenSocketStats()); + } + + private static final class ListenSocket + implements InternalInstrumented { + private final InternalLogId id; + private final ServerSocket socket; + + public ListenSocket(ServerSocket socket) { + this.socket = socket; + this.id = InternalLogId.allocate(getClass(), String.valueOf(socket.getLocalSocketAddress())); + } + + @Override + public ListenableFuture getStats() { + return Futures.immediateFuture(new InternalChannelz.SocketStats( + /*data=*/ null, + socket.getLocalSocketAddress(), + /*remote=*/ null, + new InternalChannelz.SocketOptions.Builder().build(), + /*security=*/ null)); + } + + @Override + public InternalLogId getLogId() { + return id; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("logId", id.getId()) + .add("socket", socket) + .toString(); + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java new file mode 100644 index 0000000000..b72c506957 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -0,0 +1,387 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.DoNotCall; +import io.grpc.ChoiceServerCredentials; +import io.grpc.ExperimentalApi; +import io.grpc.ForwardingServerBuilder; +import io.grpc.InsecureServerCredentials; +import io.grpc.Internal; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.ServerStreamTracer; +import io.grpc.TlsServerCredentials; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.KeepAliveManager; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerImplBuilder; +import io.grpc.internal.SharedResourcePool; +import io.grpc.internal.TransportTracer; +import io.grpc.okhttp.internal.Platform; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ServerSocketFactory; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.security.auth.x500.X500Principal; + +/** + * Build servers with the OkHttp transport. + * + * @since 1.49.0 + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") +public final class OkHttpServerBuilder extends ForwardingServerBuilder { + private static final Logger log = Logger.getLogger(OkHttpServerBuilder.class.getName()); + private static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535; + private static final long AS_LARGE_AS_INFINITE = TimeUnit.DAYS.toNanos(1000L); + private static final ObjectPool DEFAULT_TRANSPORT_EXECUTOR_POOL = + OkHttpChannelBuilder.DEFAULT_TRANSPORT_EXECUTOR_POOL; + + /** + * Always throws, to shadow {@code ServerBuilder.forPort()}. + * + * @deprecated Use {@link #forPort(int, ServerCredentials)} instead + */ + @DoNotCall("Always throws. Use forPort(int, ServerCredentials) instead") + @Deprecated + public static OkHttpServerBuilder forPort(int port) { + throw new UnsupportedOperationException(); + } + + /** + * Creates a builder for a server listening on {@code port}. + */ + public static OkHttpServerBuilder forPort(int port, ServerCredentials creds) { + return forPort(new InetSocketAddress(port), creds); + } + + /** + * Creates a builder for a server listening on {@code address}. + */ + public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentials creds) { + HandshakerSocketFactoryResult result = handshakerSocketFactoryFrom(creds); + if (result.error != null) { + throw new IllegalArgumentException(result.error); + } + return new OkHttpServerBuilder(address, result.factory); + } + + final ServerImplBuilder serverImplBuilder = new ServerImplBuilder(this::buildTransportServers); + final SocketAddress listenAddress; + final HandshakerSocketFactory handshakerSocketFactory; + TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory(); + + ObjectPool transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL; + ObjectPool scheduledExecutorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + + ServerSocketFactory socketFactory = ServerSocketFactory.getDefault(); + long keepAliveTimeNanos = GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; + long keepAliveTimeoutNanos = GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS; + int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW; + int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE; + int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; + + @VisibleForTesting + OkHttpServerBuilder( + SocketAddress address, HandshakerSocketFactory handshakerSocketFactory) { + this.listenAddress = Preconditions.checkNotNull(address, "address"); + this.handshakerSocketFactory = + Preconditions.checkNotNull(handshakerSocketFactory, "handshakerSocketFactory"); + } + + @Internal + @Override + protected ServerBuilder delegate() { + return serverImplBuilder; + } + + @VisibleForTesting + OkHttpServerBuilder setTransportTracerFactory(TransportTracer.Factory transportTracerFactory) { + this.transportTracerFactory = transportTracerFactory; + return this; + } + + /** + * Override the default executor necessary for internal transport use. + * + *

The channel does not take ownership of the given executor. It is the caller' responsibility + * to shutdown the executor when appropriate. + */ + public OkHttpServerBuilder transportExecutor(Executor transportExecutor) { + if (transportExecutor == null) { + this.transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL; + } else { + this.transportExecutorPool = new FixedObjectPool<>(transportExecutor); + } + return this; + } + + /** + * Override the default {@link ServerSocketFactory} used to listen. If the socket factory is not + * set or set to null, a default one will be used. + */ + public OkHttpServerBuilder socketFactory(ServerSocketFactory socketFactory) { + if (socketFactory == null) { + this.socketFactory = ServerSocketFactory.getDefault(); + } else { + this.socketFactory = socketFactory; + } + return this; + } + + /** + * Sets the time without read activity before sending a keepalive ping. An unreasonably small + * value might be increased, and {@code Long.MAX_VALUE} nano seconds or an unreasonably large + * value will disable keepalive. Defaults to two hours. + * + * @throws IllegalArgumentException if time is not positive + */ + @Override + public OkHttpServerBuilder keepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + Preconditions.checkArgument(keepAliveTime > 0L, "keepalive time must be positive"); + keepAliveTimeNanos = timeUnit.toNanos(keepAliveTime); + keepAliveTimeNanos = KeepAliveManager.clampKeepAliveTimeInNanos(keepAliveTimeNanos); + if (keepAliveTimeNanos >= AS_LARGE_AS_INFINITE) { + // Bump keepalive time to infinite. This disables keepalive. + keepAliveTimeNanos = GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; + } + return this; + } + + /** + * Sets a time waiting for read activity after sending a keepalive ping. If the time expires + * without any read activity on the connection, the connection is considered dead. An unreasonably + * small value might be increased. Defaults to 20 seconds. + * + *

This value should be at least multiple times the RTT to allow for lost packets. + * + * @throws IllegalArgumentException if timeout is not positive + */ + @Override + public OkHttpServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) { + Preconditions.checkArgument(keepAliveTimeout > 0L, "keepalive timeout must be positive"); + keepAliveTimeoutNanos = timeUnit.toNanos(keepAliveTimeout); + keepAliveTimeoutNanos = KeepAliveManager.clampKeepAliveTimeoutInNanos(keepAliveTimeoutNanos); + return this; + } + + /** + * Sets the flow control window in bytes. If not called, the default value is 64 KiB. + */ + public OkHttpServerBuilder flowControlWindow(int flowControlWindow) { + Preconditions.checkState(flowControlWindow > 0, "flowControlWindow must be positive"); + this.flowControlWindow = flowControlWindow; + return this; + } + + /** + * Provides a custom scheduled executor service. + * + *

It's an optional parameter. If the user has not provided a scheduled executor service when + * the channel is built, the builder will use a static thread pool. + * + * @return this + */ + public OkHttpServerBuilder scheduledExecutorService( + ScheduledExecutorService scheduledExecutorService) { + this.scheduledExecutorServicePool = new FixedObjectPool<>( + Preconditions.checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + return this; + } + + /** + * Sets the maximum size of metadata allowed to be received. Defaults to 8 KiB. + * + *

The implementation does not currently limit memory usage; this value is checked only after + * the metadata is decoded from the wire. It does prevent large metadata from being passed to the + * application. + * + * @param bytes the maximum size of received metadata + * @return this + * @throws IllegalArgumentException if bytes is non-positive + */ + @Override + public OkHttpServerBuilder maxInboundMetadataSize(int bytes) { + Preconditions.checkArgument(bytes > 0, "maxInboundMetadataSize must be > 0"); + this.maxInboundMetadataSize = bytes; + return this; + } + + /** + * Sets the maximum message size allowed to be received on the server. If not called, defaults to + * defaults to 4 MiB. The default provides protection to servers who haven't considered the + * possibility of receiving large messages while trying to be large enough to not be hit in normal + * usage. + * + * @param bytes the maximum number of bytes a single message can be. + * @return this + * @throws IllegalArgumentException if bytes is negative. + */ + @Override + public OkHttpServerBuilder maxInboundMessageSize(int bytes) { + Preconditions.checkArgument(bytes >= 0, "negative max bytes"); + maxInboundMessageSize = bytes; + return this; + } + + void setStatsEnabled(boolean value) { + this.serverImplBuilder.setStatsEnabled(value); + } + + InternalServer buildTransportServers( + List streamTracerFactories) { + return new OkHttpServer(this, streamTracerFactories, serverImplBuilder.getChannelz()); + } + + private static final EnumSet understoodTlsFeatures = + EnumSet.of( + TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS); + + static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentials creds) { + if (creds instanceof TlsServerCredentials) { + TlsServerCredentials tlsCreds = (TlsServerCredentials) creds; + Set incomprehensible = + tlsCreds.incomprehensible(understoodTlsFeatures); + if (!incomprehensible.isEmpty()) { + return HandshakerSocketFactoryResult.error( + "TLS features not understood: " + incomprehensible); + } + KeyManager[] km = null; + if (tlsCreds.getKeyManagers() != null) { + km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]); + } else if (tlsCreds.getPrivateKey() != null) { + return HandshakerSocketFactoryResult.error( + "byte[]-based private key unsupported. Use KeyManager"); + } // else don't have a client cert + TrustManager[] tm = null; + if (tlsCreds.getTrustManagers() != null) { + tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); + } else if (tlsCreds.getRootCertificates() != null) { + try { + tm = createTrustManager(tlsCreds.getRootCertificates()); + } catch (GeneralSecurityException gse) { + log.log(Level.FINE, "Exception loading root certificates from credential", gse); + return HandshakerSocketFactoryResult.error( + "Unable to load root certificates: " + gse.getMessage()); + } + } // else use system default + SSLContext sslContext; + try { + sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider()); + sslContext.init(km, tm, null); + } catch (GeneralSecurityException gse) { + throw new RuntimeException("TLS Provider failure", gse); + } + return HandshakerSocketFactoryResult.factory(new TlsServerHandshakerSocketFactory( + new SslSocketFactoryServerCredentials.ServerCredentials( + sslContext.getSocketFactory()))); + + } else if (creds instanceof InsecureServerCredentials) { + return HandshakerSocketFactoryResult.factory(new PlaintextHandshakerSocketFactory()); + + } else if (creds instanceof SslSocketFactoryServerCredentials.ServerCredentials) { + SslSocketFactoryServerCredentials.ServerCredentials factoryCreds = + (SslSocketFactoryServerCredentials.ServerCredentials) creds; + return HandshakerSocketFactoryResult.factory( + new TlsServerHandshakerSocketFactory(factoryCreds)); + + } else if (creds instanceof ChoiceServerCredentials) { + ChoiceServerCredentials choiceCreds = (ChoiceServerCredentials) creds; + StringBuilder error = new StringBuilder(); + for (ServerCredentials innerCreds : choiceCreds.getCredentialsList()) { + HandshakerSocketFactoryResult result = handshakerSocketFactoryFrom(innerCreds); + if (result.error == null) { + return result; + } + error.append(", "); + error.append(result.error); + } + return HandshakerSocketFactoryResult.error(error.substring(2)); + + } else { + return HandshakerSocketFactoryResult.error( + "Unsupported credential type: " + creds.getClass().getName()); + } + } + + static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + ByteArrayInputStream in = new ByteArrayInputStream(rootCerts); + try { + X509Certificate cert = (X509Certificate) cf.generateCertificate(in); + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } finally { + GrpcUtil.closeQuietly(in); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return trustManagerFactory.getTrustManagers(); + } + + static final class HandshakerSocketFactoryResult { + public final HandshakerSocketFactory factory; + public final String error; + + private HandshakerSocketFactoryResult(HandshakerSocketFactory factory, String error) { + this.factory = factory; + this.error = error; + } + + public static HandshakerSocketFactoryResult error(String error) { + return new HandshakerSocketFactoryResult( + null, Preconditions.checkNotNull(error, "error")); + } + + public static HandshakerSocketFactoryResult factory(HandshakerSocketFactory factory) { + return new HandshakerSocketFactoryResult( + Preconditions.checkNotNull(factory, "factory"), null); + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java new file mode 100644 index 0000000000..1def5c17e0 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerStream.java @@ -0,0 +1,302 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.AbstractServerStream; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.TransportTracer; +import io.grpc.internal.WritableBuffer; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.Header; +import io.perfmark.PerfMark; +import io.perfmark.Tag; +import java.util.List; +import javax.annotation.concurrent.GuardedBy; +import okio.Buffer; + +/** + * Server stream for the okhttp transport. + */ +class OkHttpServerStream extends AbstractServerStream { + private final String authority; + private final TransportState state; + private final Sink sink = new Sink(); + private final TransportTracer transportTracer; + private final Attributes attributes; + + public OkHttpServerStream( + TransportState state, + Attributes transportAttrs, + String authority, + StatsTraceContext statsTraceCtx, + TransportTracer transportTracer) { + super(new OkHttpWritableBufferAllocator(), statsTraceCtx); + this.state = Preconditions.checkNotNull(state, "state"); + this.attributes = Preconditions.checkNotNull(transportAttrs, "transportAttrs"); + this.authority = authority; + this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer"); + } + + @Override + protected TransportState transportState() { + return state; + } + + @Override + protected Sink abstractServerStreamSink() { + return sink; + } + + @Override + public int streamId() { + return state.streamId; + } + + @Override + public String getAuthority() { + return authority; + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + class Sink implements AbstractServerStream.Sink { + @Override + public void writeHeaders(Metadata metadata) { + PerfMark.startTask("OkHttpServerStream$Sink.writeHeaders"); + try { + List

responseHeaders = Headers.createResponseHeaders(metadata); + synchronized (state.lock) { + state.sendHeaders(responseHeaders); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.writeHeaders"); + } + } + + @Override + public void writeFrame(WritableBuffer frame, boolean flush, int numMessages) { + PerfMark.startTask("OkHttpServerStream$Sink.writeFrame"); + Buffer buffer = ((OkHttpWritableBuffer) frame).buffer(); + int size = (int) buffer.size(); + if (size > 0) { + onSendingBytes(size); + } + + try { + synchronized (state.lock) { + state.sendBuffer(buffer, flush); + transportTracer.reportMessageSent(numMessages); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.writeFrame"); + } + } + + @Override + public void writeTrailers(Metadata trailers, boolean headersSent, Status status) { + PerfMark.startTask("OkHttpServerStream$Sink.writeTrailers"); + try { + List
responseTrailers = Headers.createResponseTrailers(trailers, headersSent); + synchronized (state.lock) { + state.sendTrailers(responseTrailers); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.writeTrailers"); + } + } + + @Override + public void cancel(Status reason) { + PerfMark.startTask("OkHttpServerStream$Sink.cancel"); + try { + synchronized (state.lock) { + state.cancel(ErrorCode.CANCEL, reason); + } + } finally { + PerfMark.stopTask("OkHttpServerStream$Sink.cancel"); + } + } + } + + static class TransportState extends AbstractServerStream.TransportState + implements OutboundFlowController.Stream, OkHttpServerTransport.StreamState { + @GuardedBy("lock") + private final OkHttpServerTransport transport; + private final int streamId; + private final int initialWindowSize; + private final Object lock; + @GuardedBy("lock") + private boolean cancelSent = false; + @GuardedBy("lock") + private int window; + @GuardedBy("lock") + private int processedWindow; + @GuardedBy("lock") + private final ExceptionHandlingFrameWriter frameWriter; + @GuardedBy("lock") + private final OutboundFlowController outboundFlow; + @GuardedBy("lock") + private boolean receivedEndOfStream; + private final Tag tag; + private final OutboundFlowController.StreamState outboundFlowState; + + public TransportState( + OkHttpServerTransport transport, + int streamId, + int maxMessageSize, + StatsTraceContext statsTraceCtx, + Object lock, + ExceptionHandlingFrameWriter frameWriter, + OutboundFlowController outboundFlow, + int initialWindowSize, + TransportTracer transportTracer, + String methodName) { + super(maxMessageSize, statsTraceCtx, transportTracer); + this.transport = Preconditions.checkNotNull(transport, "transport"); + this.streamId = streamId; + this.lock = Preconditions.checkNotNull(lock, "lock"); + this.frameWriter = frameWriter; + this.outboundFlow = outboundFlow; + this.window = initialWindowSize; + this.processedWindow = initialWindowSize; + this.initialWindowSize = initialWindowSize; + tag = PerfMark.createTag(methodName); + outboundFlowState = outboundFlow.createState(this, streamId); + } + + @Override + @GuardedBy("lock") + public void deframeFailed(Throwable cause) { + cancel(ErrorCode.INTERNAL_ERROR, Status.fromThrowable(cause)); + } + + @Override + @GuardedBy("lock") + public void bytesRead(int processedBytes) { + processedWindow -= processedBytes; + if (processedWindow <= initialWindowSize * Utils.DEFAULT_WINDOW_UPDATE_RATIO) { + int delta = initialWindowSize - processedWindow; + window += delta; + processedWindow += delta; + frameWriter.windowUpdate(streamId, delta); + frameWriter.flush(); + } + } + + @Override + @GuardedBy("lock") + public void runOnTransportThread(final Runnable r) { + synchronized (lock) { + r.run(); + } + } + + /** + * Must be called with holding the transport lock. + */ + @Override + public void inboundDataReceived(okio.Buffer frame, int windowConsumed, boolean endOfStream) { + synchronized (lock) { + PerfMark.event("OkHttpServerTransport$FrameHandler.data", tag); + if (endOfStream) { + this.receivedEndOfStream = true; + } + window -= windowConsumed; + super.inboundDataReceived(new OkHttpReadableBuffer(frame), endOfStream); + } + } + + /** Must be called with holding the transport lock. */ + @Override + public void inboundRstReceived(Status status) { + PerfMark.event("OkHttpServerTransport$FrameHandler.rstStream", tag); + transportReportStatus(status); + } + + /** Must be called with holding the transport lock. */ + @Override + public boolean hasReceivedEndOfStream() { + synchronized (lock) { + return receivedEndOfStream; + } + } + + /** Must be called with holding the transport lock. */ + @Override + public int inboundWindowAvailable() { + synchronized (lock) { + return window; + } + } + + @GuardedBy("lock") + private void sendBuffer(Buffer buffer, boolean flush) { + if (cancelSent) { + return; + } + // If buffer > frameWriter.maxDataLength() the flow-controller will ensure that it is + // properly chunked. + outboundFlow.data(false, outboundFlowState, buffer, flush); + } + + @GuardedBy("lock") + private void sendHeaders(List
responseHeaders) { + frameWriter.synReply(false, streamId, responseHeaders); + frameWriter.flush(); + } + + @GuardedBy("lock") + private void sendTrailers(List
responseTrailers) { + outboundFlow.notifyWhenNoPendingData( + outboundFlowState, () -> sendTrailersAfterFlowControlled(responseTrailers)); + } + + private void sendTrailersAfterFlowControlled(List
responseTrailers) { + synchronized (lock) { + frameWriter.synReply(true, streamId, responseTrailers); + if (!receivedEndOfStream) { + frameWriter.rstStream(streamId, ErrorCode.NO_ERROR); + } + transport.streamClosed(streamId, /*flush=*/ true); + complete(); + } + } + + @GuardedBy("lock") + private void cancel(ErrorCode http2Error, Status reason) { + if (cancelSent) { + return; + } + cancelSent = true; + frameWriter.rstStream(streamId, http2Error); + transportReportStatus(reason); + transport.streamClosed(streamId, /*flush=*/ true); + } + + @Override + public OutboundFlowController.StreamState getOutboundFlowState() { + return outboundFlowState; + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java new file mode 100644 index 0000000000..296f6a3736 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java @@ -0,0 +1,1088 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.Attributes; +import io.grpc.InternalChannelz; +import io.grpc.InternalLogId; +import io.grpc.InternalStatus; +import io.grpc.Metadata; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveManager; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SerializingExecutor; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.TransportTracer; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameReader; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import io.grpc.okhttp.internal.framed.HeadersMode; +import io.grpc.okhttp.internal.framed.Http2; +import io.grpc.okhttp.internal.framed.Settings; +import io.grpc.okhttp.internal.framed.Variant; +import java.io.IOException; +import java.net.Socket; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.GuardedBy; +import okio.Buffer; +import okio.BufferedSource; +import okio.ByteString; +import okio.Okio; + +/** + * OkHttp-based server transport. + */ +final class OkHttpServerTransport implements ServerTransport, + ExceptionHandlingFrameWriter.TransportExceptionHandler, OutboundFlowController.Transport { + private static final Logger log = Logger.getLogger(OkHttpServerTransport.class.getName()); + private static final int GRACEFUL_SHUTDOWN_PING = 0x1111; + private static final int KEEPALIVE_PING = 0xDEAD; + private static final ByteString HTTP_METHOD = ByteString.encodeUtf8(":method"); + private static final ByteString CONNECT_METHOD = ByteString.encodeUtf8("CONNECT"); + private static final ByteString POST_METHOD = ByteString.encodeUtf8("POST"); + private static final ByteString SCHEME = ByteString.encodeUtf8(":scheme"); + private static final ByteString PATH = ByteString.encodeUtf8(":path"); + private static final ByteString AUTHORITY = ByteString.encodeUtf8(":authority"); + private static final ByteString CONNECTION = ByteString.encodeUtf8("connection"); + private static final ByteString HOST = ByteString.encodeUtf8("host"); + private static final ByteString TE = ByteString.encodeUtf8("te"); + private static final ByteString TE_TRAILERS = ByteString.encodeUtf8("trailers"); + private static final ByteString CONTENT_TYPE = ByteString.encodeUtf8("content-type"); + private static final ByteString CONTENT_LENGTH = ByteString.encodeUtf8("content-length"); + + private final Config config; + private final Socket bareSocket; + private final Variant variant = new Http2(); + private final TransportTracer tracer; + private final InternalLogId logId; + private ServerTransportListener listener; + private Executor transportExecutor; + private ScheduledExecutorService scheduledExecutorService; + private Attributes attributes; + private KeepAliveManager keepAliveManager; + + private final Object lock = new Object(); + @GuardedBy("lock") + private boolean abruptShutdown; + @GuardedBy("lock") + private boolean gracefulShutdown; + @GuardedBy("lock") + private boolean handshakeShutdown; + @GuardedBy("lock") + private InternalChannelz.Security securityInfo; + @GuardedBy("lock") + private ExceptionHandlingFrameWriter frameWriter; + @GuardedBy("lock") + private OutboundFlowController outboundFlow; + @GuardedBy("lock") + private final Map streams = new TreeMap<>(); + @GuardedBy("lock") + private int lastStreamId; + @GuardedBy("lock") + private int goAwayStreamId = Integer.MAX_VALUE; + /** + * Indicates the transport is in go-away state: no new streams will be processed, but existing + * streams may continue. + */ + @GuardedBy("lock") + private Status goAwayStatus; + /** Non-{@code null} when gracefully shutting down and have not yet sent second GOAWAY. */ + @GuardedBy("lock") + private ScheduledFuture secondGoawayTimer; + /** Non-{@code null} when waiting for forceful close GOAWAY to be sent. */ + @GuardedBy("lock") + private ScheduledFuture forcefulCloseTimer; + + public OkHttpServerTransport(Config config, Socket bareSocket) { + this.config = Preconditions.checkNotNull(config, "config"); + this.bareSocket = Preconditions.checkNotNull(bareSocket, "bareSocket"); + + tracer = config.transportTracerFactory.create(); + tracer.setFlowControlWindowReader(this::readFlowControlWindow); + logId = InternalLogId.allocate(getClass(), bareSocket.getRemoteSocketAddress().toString()); + transportExecutor = config.transportExecutorPool.getObject(); + scheduledExecutorService = config.scheduledExecutorServicePool.getObject(); + } + + public void start(ServerTransportListener listener) { + this.listener = Preconditions.checkNotNull(listener, "listener"); + + SerializingExecutor serializingExecutor = new SerializingExecutor(transportExecutor); + serializingExecutor.execute(() -> startIo(serializingExecutor)); + } + + private void startIo(SerializingExecutor serializingExecutor) { + try { + bareSocket.setTcpNoDelay(true); + HandshakerSocketFactory.HandshakeResult result = + config.handshakerSocketFactory.handshake(bareSocket, Attributes.EMPTY); + Socket socket = result.socket; + this.attributes = result.attributes; + + AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this); + asyncSink.becomeConnected(Okio.sink(socket), socket); + FrameWriter rawFrameWriter = variant.newWriter(Okio.buffer(asyncSink), false); + synchronized (lock) { + this.securityInfo = result.securityInfo; + + // Handle FrameWriter exceptions centrally, since there are many callers. Note that + // errors coming from rawFrameWriter are generally broken invariants/bugs, as AsyncSink + // does not propagate syscall errors through the FrameWriter. But we handle the + // AsyncSink failures with the same TransportExceptionHandler instance so it is all + // mixed back together. + frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter); + outboundFlow = new OutboundFlowController(this, frameWriter); + + // These writes will be queued in the serializingExecutor waiting for this function to + // return. + frameWriter.connectionPreface(); + Settings settings = new Settings(); + OkHttpSettingsUtil.set(settings, + OkHttpSettingsUtil.INITIAL_WINDOW_SIZE, config.flowControlWindow); + OkHttpSettingsUtil.set(settings, + OkHttpSettingsUtil.MAX_HEADER_LIST_SIZE, config.maxInboundMetadataSize); + frameWriter.settings(settings); + if (config.flowControlWindow > Utils.DEFAULT_WINDOW_SIZE) { + frameWriter.windowUpdate( + Utils.CONNECTION_STREAM_ID, config.flowControlWindow - Utils.DEFAULT_WINDOW_SIZE); + } + frameWriter.flush(); + } + + if (config.keepAliveTimeNanos != GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED) { + keepAliveManager = new KeepAliveManager( + new KeepAlivePinger(), scheduledExecutorService, config.keepAliveTimeNanos, + config.keepAliveTimeoutNanos, true); + keepAliveManager.onTransportStarted(); + } + + transportExecutor.execute( + new FrameHandler(variant.newReader(Okio.buffer(Okio.source(socket)), false))); + } catch (Error | IOException | RuntimeException ex) { + synchronized (lock) { + if (!handshakeShutdown) { + log.log(Level.INFO, "Socket failed to handshake", ex); + } + } + GrpcUtil.closeQuietly(bareSocket); + terminated(); + } + } + + @Override + public void shutdown() { + synchronized (lock) { + if (gracefulShutdown || abruptShutdown) { + return; + } + gracefulShutdown = true; + if (frameWriter == null) { + handshakeShutdown = true; + GrpcUtil.closeQuietly(bareSocket); + } else { + // RFC7540 ยง6.8. Begin double-GOAWAY graceful shutdown. To wait one RTT we use a PING, but + // we also set a timer to limit the upper bound in case the PING is excessively stalled or + // the client is malicious. + secondGoawayTimer = scheduledExecutorService.schedule( + this::triggerGracefulSecondGoaway, 1, TimeUnit.SECONDS); + frameWriter.goAway(Integer.MAX_VALUE, ErrorCode.NO_ERROR, new byte[0]); + frameWriter.ping(false, 0, GRACEFUL_SHUTDOWN_PING); + frameWriter.flush(); + } + } + } + + private void triggerGracefulSecondGoaway() { + synchronized (lock) { + if (secondGoawayTimer == null) { + return; + } + secondGoawayTimer.cancel(false); + secondGoawayTimer = null; + frameWriter.goAway(lastStreamId, ErrorCode.NO_ERROR, new byte[0]); + goAwayStreamId = lastStreamId; + if (streams.isEmpty()) { + frameWriter.close(); + } else { + frameWriter.flush(); + } + } + } + + @Override + public void shutdownNow(Status reason) { + synchronized (lock) { + if (frameWriter == null) { + handshakeShutdown = true; + GrpcUtil.closeQuietly(bareSocket); + return; + } + } + abruptShutdown(ErrorCode.NO_ERROR, "", reason, true); + } + + /** + * Finish all active streams due to an IOException, then close the transport. + */ + @Override + public void onException(Throwable failureCause) { + Preconditions.checkNotNull(failureCause, "failureCause"); + Status status = Status.UNAVAILABLE.withCause(failureCause); + abruptShutdown(ErrorCode.INTERNAL_ERROR, "I/O failure", status, false); + } + + private void abruptShutdown( + ErrorCode errorCode, String moreDetail, Status reason, boolean rstStreams) { + synchronized (lock) { + if (abruptShutdown) { + return; + } + abruptShutdown = true; + goAwayStatus = reason; + + if (secondGoawayTimer != null) { + secondGoawayTimer.cancel(false); + secondGoawayTimer = null; + } + for (Map.Entry entry : streams.entrySet()) { + if (rstStreams) { + frameWriter.rstStream(entry.getKey(), ErrorCode.CANCEL); + } + entry.getValue().transportReportStatus(reason); + } + streams.clear(); + + // RFC7540 ยง5.4.1. Attempt to inform the client what went wrong. We try to write the GOAWAY + // _and then_ close our side of the connection. But place an upper-bound for how long we wait + // for I/O with a timer, which forcefully closes the socket. + frameWriter.goAway(lastStreamId, errorCode, moreDetail.getBytes(GrpcUtil.US_ASCII)); + goAwayStreamId = lastStreamId; + frameWriter.close(); + forcefulCloseTimer = scheduledExecutorService.schedule( + this::triggerForcefulClose, 1, TimeUnit.SECONDS); + } + } + + private void triggerForcefulClose() { + // Safe to do unconditionally; no need to check if timer cancellation raced + GrpcUtil.closeQuietly(bareSocket); + } + + private void terminated() { + synchronized (lock) { + if (forcefulCloseTimer != null) { + forcefulCloseTimer.cancel(false); + forcefulCloseTimer = null; + } + } + if (keepAliveManager != null) { + keepAliveManager.onTransportTermination(); + } + transportExecutor = config.transportExecutorPool.returnObject(transportExecutor); + scheduledExecutorService = + config.scheduledExecutorServicePool.returnObject(scheduledExecutorService); + listener.transportTerminated(); + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return scheduledExecutorService; + } + + @Override + public ListenableFuture getStats() { + synchronized (lock) { + return Futures.immediateFuture(new InternalChannelz.SocketStats( + tracer.getStats(), + bareSocket.getLocalSocketAddress(), + bareSocket.getRemoteSocketAddress(), + Utils.getSocketOptions(bareSocket), + securityInfo)); + } + } + + private TransportTracer.FlowControlWindows readFlowControlWindow() { + synchronized (lock) { + long local = outboundFlow == null ? -1 : outboundFlow.windowUpdate(null, 0); + // connectionUnacknowledgedBytesRead is only readable by FrameHandler, so we provide a lower + // bound. + long remote = (long) (config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO); + return new TransportTracer.FlowControlWindows(local, remote); + } + } + + @Override + public InternalLogId getLogId() { + return logId; + } + + @Override + public OutboundFlowController.StreamState[] getActiveStreams() { + synchronized (lock) { + OutboundFlowController.StreamState[] flowStreams = + new OutboundFlowController.StreamState[streams.size()]; + int i = 0; + for (StreamState stream : streams.values()) { + flowStreams[i++] = stream.getOutboundFlowState(); + } + return flowStreams; + } + } + + /** + * Notify the transport that the stream was closed. Any frames for the stream must be enqueued + * before calling. + */ + void streamClosed(int streamId, boolean flush) { + synchronized (lock) { + streams.remove(streamId); + if (gracefulShutdown && streams.isEmpty()) { + frameWriter.close(); + } else { + if (flush) { + frameWriter.flush(); + } + } + } + } + + private static String asciiString(ByteString value) { + // utf8() string is cached in ByteString, so we prefer it when the contents are ASCII. This + // provides benefit if the header was reused via HPACK. + for (int i = 0; i < value.size(); i++) { + if (value.getByte(i) >= 0x80) { + return value.string(GrpcUtil.US_ASCII); + } + } + return value.utf8(); + } + + private static int headerFind(List
header, ByteString key, int startIndex) { + for (int i = startIndex; i < header.size(); i++) { + if (header.get(i).name.equals(key)) { + return i; + } + } + return -1; + } + + private static boolean headerContains(List
header, ByteString key) { + return headerFind(header, key, 0) != -1; + } + + private static void headerRemove(List
header, ByteString key) { + int i = 0; + while ((i = headerFind(header, key, i)) != -1) { + header.remove(i); + } + } + + /** Assumes that caller requires this field, so duplicates are treated as missing. */ + private static ByteString headerGetRequiredSingle(List
header, ByteString key) { + int i = headerFind(header, key, 0); + if (i == -1) { + return null; + } + if (headerFind(header, key, i + 1) != -1) { + return null; + } + return header.get(i).value; + } + + static final class Config { + final List streamTracerFactories; + final ObjectPool transportExecutorPool; + final ObjectPool scheduledExecutorServicePool; + final TransportTracer.Factory transportTracerFactory; + final HandshakerSocketFactory handshakerSocketFactory; + final long keepAliveTimeNanos; + final long keepAliveTimeoutNanos; + final int flowControlWindow; + final int maxInboundMessageSize; + final int maxInboundMetadataSize; + + public Config( + OkHttpServerBuilder builder, + List streamTracerFactories) { + this.streamTracerFactories = Preconditions.checkNotNull( + streamTracerFactories, "streamTracerFactories"); + transportExecutorPool = Preconditions.checkNotNull( + builder.transportExecutorPool, "transportExecutorPool"); + scheduledExecutorServicePool = Preconditions.checkNotNull( + builder.scheduledExecutorServicePool, "scheduledExecutorServicePool"); + transportTracerFactory = Preconditions.checkNotNull( + builder.transportTracerFactory, "transportTracerFactory"); + handshakerSocketFactory = Preconditions.checkNotNull( + builder.handshakerSocketFactory, "handshakerSocketFactory"); + keepAliveTimeNanos = builder.keepAliveTimeNanos; + keepAliveTimeoutNanos = builder.keepAliveTimeoutNanos; + flowControlWindow = builder.flowControlWindow; + maxInboundMessageSize = builder.maxInboundMessageSize; + maxInboundMetadataSize = builder.maxInboundMetadataSize; + } + } + + /** + * Runnable which reads frames and dispatches them to in flight calls. + */ + class FrameHandler implements FrameReader.Handler, Runnable { + private final OkHttpFrameLogger frameLogger = + new OkHttpFrameLogger(Level.FINE, OkHttpServerTransport.class); + private final FrameReader frameReader; + private boolean receivedSettings; + private int connectionUnacknowledgedBytesRead; + + public FrameHandler(FrameReader frameReader) { + this.frameReader = frameReader; + } + + @Override + public void run() { + String threadName = Thread.currentThread().getName(); + Thread.currentThread().setName("OkHttpServerTransport"); + try { + frameReader.readConnectionPreface(); + if (!frameReader.nextFrame(this)) { + connectionError(ErrorCode.INTERNAL_ERROR, "Failed to read initial SETTINGS"); + return; + } + if (!receivedSettings) { + connectionError(ErrorCode.PROTOCOL_ERROR, + "First HTTP/2 frame must be SETTINGS. RFC7540 section 3.5"); + return; + } + // Read until the underlying socket closes. + while (frameReader.nextFrame(this)) { + if (keepAliveManager != null) { + keepAliveManager.onDataReceived(); + } + } + // frameReader.nextFrame() returns false when the underlying read encounters an IOException, + // it may be triggered by the socket closing, in such case, the startGoAway() will do + // nothing, otherwise, we finish all streams since it's a real IO issue. + Status status; + synchronized (lock) { + status = goAwayStatus; + } + if (status == null) { + status = Status.UNAVAILABLE.withDescription("TCP connection closed or IOException"); + } + abruptShutdown(ErrorCode.INTERNAL_ERROR, "I/O failure", status, false); + } catch (Throwable t) { + log.log(Level.WARNING, "Error decoding HTTP/2 frames", t); + abruptShutdown(ErrorCode.INTERNAL_ERROR, "Error in frame decoder", + Status.INTERNAL.withDescription("Error decoding HTTP/2 frames").withCause(t), false); + } finally { + // Wait for the abrupt shutdown to be processed by AsyncSink and close the socket + try { + GrpcUtil.exhaust(bareSocket.getInputStream()); + } catch (IOException ex) { + // Unable to wait, so just proceed to tear-down. The socket is probably already closed so + // the GOAWAY can't be sent anyway. + } + GrpcUtil.closeQuietly(bareSocket); + terminated(); + Thread.currentThread().setName(threadName); + } + } + + /** + * Handle HTTP2 HEADER and CONTINUATION frames. + */ + @Override + public void headers(boolean outFinished, + boolean inFinished, + int streamId, + int associatedStreamId, + List
headerBlock, + HeadersMode headersMode) { + frameLogger.logHeaders( + OkHttpFrameLogger.Direction.INBOUND, streamId, headerBlock, inFinished); + // streamId == 0 checking is in HTTP/2 decoder + if ((streamId & 1) == 0) { + // The server doesn't use PUSH_PROMISE, so all even streams are IDLE + connectionError(ErrorCode.PROTOCOL_ERROR, + "Clients cannot open even numbered streams. RFC7540 section 5.1.1"); + return; + } + boolean newStream; + synchronized (lock) { + if (streamId > goAwayStreamId) { + return; + } + newStream = streamId > lastStreamId; + if (newStream) { + lastStreamId = streamId; + } + } + + int metadataSize = headerBlockSize(headerBlock); + if (metadataSize > config.maxInboundMetadataSize) { + respondWithHttpError(streamId, inFinished, 431, Status.Code.RESOURCE_EXHAUSTED, + String.format( + "Request metadata larger than %d: %d", + config.maxInboundMetadataSize, + metadataSize)); + return; + } + + headerRemove(headerBlock, ByteString.EMPTY); + + ByteString httpMethod = null; + ByteString scheme = null; + ByteString path = null; + ByteString authority = null; + while (headerBlock.size() > 0 && headerBlock.get(0).name.getByte(0) == ':') { + Header header = headerBlock.remove(0); + if (HTTP_METHOD.equals(header.name) && httpMethod == null) { + httpMethod = header.value; + } else if (SCHEME.equals(header.name) && scheme == null) { + scheme = header.value; + } else if (PATH.equals(header.name) && path == null) { + path = header.value; + } else if (AUTHORITY.equals(header.name) && authority == null) { + authority = header.value; + } else { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Unexpected pseudo header. RFC7540 section 8.1.2.1"); + return; + } + } + for (int i = 0; i < headerBlock.size(); i++) { + if (headerBlock.get(i).name.getByte(0) == ':') { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Pseudo header not before regular headers. RFC7540 section 8.1.2.1"); + return; + } + } + if (!CONNECT_METHOD.equals(httpMethod) + && newStream + && (httpMethod == null || scheme == null || path == null)) { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Missing required pseudo header. RFC7540 section 8.1.2.3"); + return; + } + if (headerContains(headerBlock, CONNECTION)) { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Connection-specific headers not permitted. RFC7540 section 8.1.2.2"); + return; + } + + if (!newStream) { + if (inFinished) { + synchronized (lock) { + StreamState stream = streams.get(streamId); + if (stream == null) { + streamError(streamId, ErrorCode.STREAM_CLOSED, "Received headers for closed stream"); + return; + } + if (stream.hasReceivedEndOfStream()) { + streamError(streamId, ErrorCode.STREAM_CLOSED, + "Received HEADERS for half-closed (remote) stream. RFC7540 section 5.1"); + return; + } + // Ignore the trailers, but still half-close the stream + stream.inboundDataReceived(new Buffer(), 0, true); + return; + } + } else { + streamError(streamId, ErrorCode.PROTOCOL_ERROR, + "Headers disallowed in the middle of the stream. RFC7540 section 8.1"); + return; + } + } + + if (authority == null) { + int i = headerFind(headerBlock, HOST, 0); + if (i != -1) { + if (headerFind(headerBlock, HOST, i + 1) != -1) { + respondWithHttpError(streamId, inFinished, 400, Status.Code.INTERNAL, + "Multiple host headers disallowed. RFC7230 section 5.4"); + return; + } + authority = headerBlock.get(i).value; + } + } + headerRemove(headerBlock, HOST); + + // Remove the leading slash of the path and get the fully qualified method name + if (path.size() == 0 || path.getByte(0) != '/') { + respondWithHttpError(streamId, inFinished, 404, Status.Code.UNIMPLEMENTED, + "Expected path to start with /: " + asciiString(path)); + return; + } + String method = asciiString(path).substring(1); + + ByteString contentType = headerGetRequiredSingle(headerBlock, CONTENT_TYPE); + if (contentType == null) { + respondWithHttpError(streamId, inFinished, 415, Status.Code.INTERNAL, + "Content-Type is missing or duplicated"); + return; + } + String contentTypeString = asciiString(contentType); + if (!GrpcUtil.isGrpcContentType(contentTypeString)) { + respondWithHttpError(streamId, inFinished, 415, Status.Code.INTERNAL, + "Content-Type is not supported: " + contentTypeString); + return; + } + + if (!POST_METHOD.equals(httpMethod)) { + respondWithHttpError(streamId, inFinished, 405, Status.Code.INTERNAL, + "HTTP Method is not supported: " + asciiString(httpMethod)); + return; + } + + ByteString te = headerGetRequiredSingle(headerBlock, TE); + if (!TE_TRAILERS.equals(te)) { + respondWithGrpcError(streamId, inFinished, Status.Code.INTERNAL, + String.format("Expected header TE: %s, but %s is received. " + + "Some intermediate proxy may not support trailers", + asciiString(TE_TRAILERS), te == null ? "" : asciiString(te))); + return; + } + headerRemove(headerBlock, CONTENT_LENGTH); + + Metadata metadata = Utils.convertHeaders(headerBlock); + StatsTraceContext statsTraceCtx = + StatsTraceContext.newServerContext(config.streamTracerFactories, method, metadata); + synchronized (lock) { + OkHttpServerStream.TransportState stream = new OkHttpServerStream.TransportState( + OkHttpServerTransport.this, + streamId, + config.maxInboundMessageSize, + statsTraceCtx, + lock, + frameWriter, + outboundFlow, + config.flowControlWindow, + tracer, + method); + OkHttpServerStream streamForApp = new OkHttpServerStream( + stream, + attributes, + authority == null ? null : asciiString(authority), + statsTraceCtx, + tracer); + streams.put(streamId, stream); + listener.streamCreated(streamForApp, method, metadata); + stream.onStreamAllocated(); + if (inFinished) { + stream.inboundDataReceived(new Buffer(), 0, inFinished); + } + } + } + + private int headerBlockSize(List
headerBlock) { + // Calculate as defined for SETTINGS_MAX_HEADER_LIST_SIZE in RFC 7540 ยง6.5.2. + long size = 0; + for (int i = 0; i < headerBlock.size(); i++) { + Header header = headerBlock.get(i); + size += 32 + header.name.size() + header.value.size(); + } + size = Math.min(size, Integer.MAX_VALUE); + return (int) size; + } + + /** + * Handle an HTTP2 DATA frame. + */ + @Override + public void data(boolean inFinished, int streamId, BufferedSource in, int length) + throws IOException { + frameLogger.logData( + OkHttpFrameLogger.Direction.INBOUND, streamId, in.getBuffer(), length, inFinished); + if (streamId == 0) { + connectionError(ErrorCode.PROTOCOL_ERROR, + "Stream 0 is reserved for control messages. RFC7540 section 5.1.1"); + return; + } + if ((streamId & 1) == 0) { + // The server doesn't use PUSH_PROMISE, so all even streams are IDLE + connectionError(ErrorCode.PROTOCOL_ERROR, + "Clients cannot open even numbered streams. RFC7540 section 5.1.1"); + return; + } + + // Wait until the frame is complete. We only support 16 KiB frames, and the max permitted in + // HTTP/2 is 16 MiB. This is verified in OkHttp's Http2 deframer, so we don't need to be + // concerned with the window being exceeded at this point. + in.require(length); + + synchronized (lock) { + StreamState stream = streams.get(streamId); + if (stream == null) { + in.skip(length); + streamError(streamId, ErrorCode.STREAM_CLOSED, "Received data for closed stream"); + return; + } + if (stream.hasReceivedEndOfStream()) { + in.skip(length); + streamError(streamId, ErrorCode.STREAM_CLOSED, + "Received DATA for half-closed (remote) stream. RFC7540 section 5.1"); + return; + } + if (stream.inboundWindowAvailable() < length) { + in.skip(length); + streamError(streamId, ErrorCode.FLOW_CONTROL_ERROR, + "Received DATA size exceeded window size. RFC7540 section 6.9"); + return; + } + Buffer buf = new Buffer(); + buf.write(in.getBuffer(), length); + stream.inboundDataReceived(buf, length, inFinished); + } + + // connection window update + connectionUnacknowledgedBytesRead += length; + if (connectionUnacknowledgedBytesRead + >= config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO) { + synchronized (lock) { + frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead); + frameWriter.flush(); + } + connectionUnacknowledgedBytesRead = 0; + } + } + + @Override + public void rstStream(int streamId, ErrorCode errorCode) { + frameLogger.logRstStream(OkHttpFrameLogger.Direction.INBOUND, streamId, errorCode); + // streamId == 0 checking is in HTTP/2 decoder + + if (!(ErrorCode.NO_ERROR.equals(errorCode) + || ErrorCode.CANCEL.equals(errorCode) + || ErrorCode.STREAM_CLOSED.equals(errorCode))) { + log.log(Level.INFO, "Received RST_STREAM: " + errorCode); + } + Status status = GrpcUtil.Http2Error.statusForCode(errorCode.httpCode) + .withDescription("RST_STREAM"); + synchronized (lock) { + StreamState stream = streams.get(streamId); + if (stream != null) { + stream.inboundRstReceived(status); + streamClosed(streamId, /*flush=*/ false); + } + } + } + + @Override + public void settings(boolean clearPrevious, Settings settings) { + frameLogger.logSettings(OkHttpFrameLogger.Direction.INBOUND, settings); + synchronized (lock) { + boolean outboundWindowSizeIncreased = false; + if (OkHttpSettingsUtil.isSet(settings, OkHttpSettingsUtil.INITIAL_WINDOW_SIZE)) { + int initialWindowSize = OkHttpSettingsUtil.get( + settings, OkHttpSettingsUtil.INITIAL_WINDOW_SIZE); + outboundWindowSizeIncreased = outboundFlow.initialOutboundWindowSize(initialWindowSize); + } + + // The changed settings are not finalized until SETTINGS acknowledgment frame is sent. Any + // writes due to update in settings must be sent after SETTINGS acknowledgment frame, + // otherwise it will cause a stream error (RST_STREAM). + // FIXME: limit number of queued control frames + frameWriter.ackSettings(settings); + frameWriter.flush(); + if (!receivedSettings) { + receivedSettings = true; + attributes = listener.transportReady(attributes); + } + + // send any pending bytes / streams + if (outboundWindowSizeIncreased) { + outboundFlow.writeStreams(); + } + } + } + + @Override + public void ping(boolean ack, int payload1, int payload2) { + long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL); + if (!ack) { + frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload); + synchronized (lock) { + // FIXME: limit number of queued control frames + frameWriter.ping(true, payload1, payload2); + frameWriter.flush(); + } + } else { + frameLogger.logPingAck(OkHttpFrameLogger.Direction.INBOUND, payload); + if (KEEPALIVE_PING == payload) { + return; + } + if (GRACEFUL_SHUTDOWN_PING == payload) { + triggerGracefulSecondGoaway(); + return; + } + log.log(Level.INFO, "Received unexpected ping ack: " + payload); + } + } + + @Override + public void ackSettings() {} + + @Override + public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) { + frameLogger.logGoAway( + OkHttpFrameLogger.Direction.INBOUND, lastGoodStreamId, errorCode, debugData); + String description = String.format("Received GOAWAY: %s '%s'", errorCode, debugData.utf8()); + Status status = GrpcUtil.Http2Error.statusForCode(errorCode.httpCode) + .withDescription(description); + if (!ErrorCode.NO_ERROR.equals(errorCode)) { + log.log( + Level.WARNING, "Received GOAWAY: {0} {1}", new Object[] {errorCode, debugData.utf8()}); + } + synchronized (lock) { + goAwayStatus = status; + } + } + + @Override + public void pushPromise(int streamId, int promisedStreamId, List
requestHeaders) + throws IOException { + frameLogger.logPushPromise(OkHttpFrameLogger.Direction.INBOUND, + streamId, promisedStreamId, requestHeaders); + // streamId == 0 checking is in HTTP/2 decoder. + // The server doesn't use PUSH_PROMISE, so all even streams are IDLE, and odd streams are not + // peer-initiated. + connectionError(ErrorCode.PROTOCOL_ERROR, + "PUSH_PROMISE only allowed on peer-initiated streams. RFC7540 section 6.6"); + } + + @Override + public void windowUpdate(int streamId, long delta) { + frameLogger.logWindowsUpdate(OkHttpFrameLogger.Direction.INBOUND, streamId, delta); + // delta == 0 checking is in HTTP/2 decoder. And it isn't quite right, as it will always cause + // a GOAWAY. RFC7540 section 6.9 says to use RST_STREAM if the stream id isn't 0. Doesn't + // matter much though. + synchronized (lock) { + if (streamId == Utils.CONNECTION_STREAM_ID) { + outboundFlow.windowUpdate(null, (int) delta); + } else { + StreamState stream = streams.get(streamId); + if (stream != null) { + outboundFlow.windowUpdate(stream.getOutboundFlowState(), (int) delta); + } + } + } + } + + @Override + public void priority(int streamId, int streamDependency, int weight, boolean exclusive) { + frameLogger.logPriority( + OkHttpFrameLogger.Direction.INBOUND, streamId, streamDependency, weight, exclusive); + // streamId == 0 checking is in HTTP/2 decoder. + // Ignore priority change. + } + + @Override + public void alternateService(int streamId, String origin, ByteString protocol, String host, + int port, long maxAge) {} + + /** + * Send GOAWAY to the server, then finish all streams and close the transport. RFC7540 ยง5.4.1. + */ + private void connectionError(ErrorCode errorCode, String moreDetail) { + Status status = GrpcUtil.Http2Error.statusForCode(errorCode.httpCode) + .withDescription(String.format("HTTP2 connection error: %s '%s'", errorCode, moreDetail)); + abruptShutdown(errorCode, moreDetail, status, false); + } + + /** + * Respond with RST_STREAM, making sure to kill the associated stream if it exists. Reason will + * rarely be seen. RFC7540 ยง5.4.2. + */ + private void streamError(int streamId, ErrorCode errorCode, String reason) { + if (errorCode == ErrorCode.PROTOCOL_ERROR) { + log.log( + Level.FINE, "Responding with RST_STREAM {0}: {1}", new Object[] {errorCode, reason}); + } + synchronized (lock) { + // FIXME: limit number of queued control frames + frameWriter.rstStream(streamId, errorCode); + frameWriter.flush(); + StreamState stream = streams.get(streamId); + if (stream != null) { + stream.transportReportStatus( + Status.INTERNAL.withDescription( + String.format("Responded with RST_STREAM %s: %s", errorCode, reason))); + streamClosed(streamId, /*flush=*/ false); + } + } + } + + private void respondWithHttpError( + int streamId, boolean inFinished, int httpCode, Status.Code statusCode, String msg) { + Metadata metadata = new Metadata(); + metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); + metadata.put(InternalStatus.MESSAGE_KEY, msg); + List
headers = + Headers.createHttpResponseHeaders(httpCode, "text/plain; charset=utf-8", metadata); + Buffer data = new Buffer().writeUtf8(msg); + + synchronized (lock) { + Http2ErrorStreamState stream = + new Http2ErrorStreamState(streamId, lock, outboundFlow, config.flowControlWindow); + streams.put(streamId, stream); + if (inFinished) { + stream.inboundDataReceived(new Buffer(), 0, true); + } + frameWriter.headers(streamId, headers); + outboundFlow.data( + /*outFinished=*/true, stream.getOutboundFlowState(), data, /*flush=*/true); + outboundFlow.notifyWhenNoPendingData( + stream.getOutboundFlowState(), () -> rstOkAtEndOfHttpError(stream)); + } + } + + private void rstOkAtEndOfHttpError(Http2ErrorStreamState stream) { + synchronized (lock) { + if (!stream.hasReceivedEndOfStream()) { + frameWriter.rstStream(stream.streamId, ErrorCode.NO_ERROR); + } + streamClosed(stream.streamId, /*flush=*/ true); + } + } + + private void respondWithGrpcError( + int streamId, boolean inFinished, Status.Code statusCode, String msg) { + Metadata metadata = new Metadata(); + metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus()); + metadata.put(InternalStatus.MESSAGE_KEY, msg); + List
headers = Headers.createResponseTrailers(metadata, false); + + synchronized (lock) { + frameWriter.synReply(true, streamId, headers); + if (!inFinished) { + frameWriter.rstStream(streamId, ErrorCode.NO_ERROR); + } + frameWriter.flush(); + } + } + } + + private final class KeepAlivePinger implements KeepAliveManager.KeepAlivePinger { + @Override + public void ping() { + synchronized (lock) { + frameWriter.ping(false, 0, KEEPALIVE_PING); + frameWriter.flush(); + } + tracer.reportKeepAliveSent(); + } + + @Override + public void onPingTimeout() { + synchronized (lock) { + goAwayStatus = Status.UNAVAILABLE + .withDescription("Keepalive failed. Considering connection dead"); + GrpcUtil.closeQuietly(bareSocket); + } + } + } + + interface StreamState { + /** Must be holding 'lock' when calling. */ + void inboundDataReceived(Buffer frame, int windowConsumed, boolean endOfStream); + + /** Must be holding 'lock' when calling. */ + boolean hasReceivedEndOfStream(); + + /** Must be holding 'lock' when calling. */ + int inboundWindowAvailable(); + + /** Must be holding 'lock' when calling. */ + void transportReportStatus(Status status); + + /** Must be holding 'lock' when calling. */ + void inboundRstReceived(Status status); + + OutboundFlowController.StreamState getOutboundFlowState(); + } + + static class Http2ErrorStreamState implements StreamState, OutboundFlowController.Stream { + private final int streamId; + private final Object lock; + private final OutboundFlowController.StreamState outboundFlowState; + @GuardedBy("lock") + private int window; + @GuardedBy("lock") + private boolean receivedEndOfStream; + + Http2ErrorStreamState( + int streamId, Object lock, OutboundFlowController outboundFlow, int initialWindowSize) { + this.streamId = streamId; + this.lock = lock; + this.outboundFlowState = outboundFlow.createState(this, streamId); + this.window = initialWindowSize; + } + + @Override public void onSentBytes(int frameBytes) {} + + @Override public void inboundDataReceived( + Buffer frame, int windowConsumed, boolean endOfStream) { + synchronized (lock) { + if (endOfStream) { + receivedEndOfStream = true; + } + window -= windowConsumed; + try { + frame.skip(frame.size()); // Recycle segments + } catch (IOException ex) { + throw new AssertionError(ex); + } + } + } + + @Override public boolean hasReceivedEndOfStream() { + synchronized (lock) { + return receivedEndOfStream; + } + } + + @Override public int inboundWindowAvailable() { + synchronized (lock) { + return window; + } + } + + @Override public void transportReportStatus(Status status) {} + + @Override public void inboundRstReceived(Status status) {} + + @Override public OutboundFlowController.StreamState getOutboundFlowState() { + synchronized (lock) { + return outboundFlowState; + } + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java index 5df85732ed..1406b39adf 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpSettingsUtil.java @@ -24,6 +24,8 @@ import io.grpc.okhttp.internal.framed.Settings; class OkHttpSettingsUtil { public static final int MAX_CONCURRENT_STREAMS = Settings.MAX_CONCURRENT_STREAMS; public static final int INITIAL_WINDOW_SIZE = Settings.INITIAL_WINDOW_SIZE; + public static final int MAX_HEADER_LIST_SIZE = Settings.MAX_HEADER_LIST_SIZE; + public static final int ENABLE_PUSH = Settings.ENABLE_PUSH; public static boolean isSet(Settings settings, int id) { return settings.isSet(id); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java index c935363213..35117c1b22 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OutboundFlowController.java @@ -33,17 +33,16 @@ import okio.Buffer; * streams. */ class OutboundFlowController { - private final OkHttpClientTransport transport; + private final Transport transport; private final FrameWriter frameWriter; private int initialWindowSize; - private final OutboundFlowState connectionState; + private final StreamState connectionState; - OutboundFlowController( - OkHttpClientTransport transport, FrameWriter frameWriter) { + public OutboundFlowController(Transport transport, FrameWriter frameWriter) { this.transport = Preconditions.checkNotNull(transport, "transport"); this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter"); this.initialWindowSize = DEFAULT_WINDOW_SIZE; - connectionState = new OutboundFlowState(CONNECTION_STREAM_ID, DEFAULT_WINDOW_SIZE); + connectionState = new StreamState(CONNECTION_STREAM_ID, DEFAULT_WINDOW_SIZE, null); } /** @@ -55,22 +54,15 @@ class OutboundFlowController { * * @return true, if new window size is increased, false otherwise. */ - boolean initialOutboundWindowSize(int newWindowSize) { + public boolean initialOutboundWindowSize(int newWindowSize) { if (newWindowSize < 0) { throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize); } int delta = newWindowSize - initialWindowSize; initialWindowSize = newWindowSize; - for (OkHttpClientStream stream : transport.getActiveStreams()) { - OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState(); - if (state == null) { - // Create the OutboundFlowState with the new window size. - state = new OutboundFlowState(stream, initialWindowSize); - stream.setOutboundFlowState(state); - } else { - state.incrementStreamWindow(delta); - } + for (StreamState state : transport.getActiveStreams()) { + state.incrementStreamWindow(delta); } return delta > 0; @@ -82,15 +74,14 @@ class OutboundFlowController { * *

Must be called with holding transport lock. */ - int windowUpdate(@Nullable OkHttpClientStream stream, int delta) { + public int windowUpdate(@Nullable StreamState state, int delta) { final int updatedWindow; - if (stream == null) { + if (state == null) { // Update the connection window and write any pending frames for all streams. updatedWindow = connectionState.incrementStreamWindow(delta); writeStreams(); } else { // Update the stream window and write any pending frames for the stream. - OutboundFlowState state = state(stream); updatedWindow = state.incrementStreamWindow(delta); WriteStatus writeStatus = new WriteStatus(); @@ -105,18 +96,9 @@ class OutboundFlowController { /** * Must be called with holding transport lock. */ - void data(boolean outFinished, int streamId, Buffer source, boolean flush) { + public void data(boolean outFinished, StreamState state, Buffer source, boolean flush) { Preconditions.checkNotNull(source, "source"); - OkHttpClientStream stream = transport.getStream(streamId); - if (stream == null) { - // This is possible for a stream that has received end-of-stream from server (but hasn't sent - // end-of-stream), and was removed from the transport stream map. - // In such case, we just throw away the data. - return; - } - - OutboundFlowState state = state(stream); int window = state.writableWindow(); boolean framesAlreadyQueued = state.hasPendingData(); int size = (int) source.size(); @@ -130,7 +112,7 @@ class OutboundFlowController { state.write(source, window, false); } // Queue remaining data in the buffer - state.enqueue(source, (int) source.size(), outFinished); + state.enqueueData(source, (int) source.size(), outFinished); } if (flush) { @@ -138,7 +120,19 @@ class OutboundFlowController { } } - void flush() { + /** + * Transport lock must be held when calling. + */ + public void notifyWhenNoPendingData(StreamState state, Runnable noPendingDataRunnable) { + Preconditions.checkNotNull(noPendingDataRunnable, "noPendingDataRunnable"); + if (state.hasPendingData()) { + state.notifyWhenNoPendingData(noPendingDataRunnable); + } else { + noPendingDataRunnable.run(); + } + } + + public void flush() { try { frameWriter.flush(); } catch (IOException e) { @@ -146,13 +140,9 @@ class OutboundFlowController { } } - private OutboundFlowState state(OkHttpClientStream stream) { - OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState(); - if (state == null) { - state = new OutboundFlowState(stream, initialWindowSize); - stream.setOutboundFlowState(state); - } - return state; + public StreamState createState(Stream stream, int streamId) { + return new StreamState( + streamId, initialWindowSize, Preconditions.checkNotNull(stream, "stream")); } /** @@ -160,15 +150,14 @@ class OutboundFlowController { * *

Must be called with holding transport lock. */ - void writeStreams() { - OkHttpClientStream[] streams = transport.getActiveStreams(); + public void writeStreams() { + StreamState[] states = transport.getActiveStreams(); int connectionWindow = connectionState.window(); - for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) { + for (int numStreams = states.length; numStreams > 0 && connectionWindow > 0;) { int nextNumStreams = 0; int windowSlice = (int) ceil(connectionWindow / (float) numStreams); for (int index = 0; index < numStreams && connectionWindow > 0; ++index) { - OkHttpClientStream stream = streams[index]; - OutboundFlowState state = state(stream); + StreamState state = states[index]; int bytesForStream = min(connectionWindow, min(state.unallocatedBytes(), windowSlice)); if (bytesForStream > 0) { @@ -179,7 +168,7 @@ class OutboundFlowController { if (state.unallocatedBytes() > 0) { // There is more data to process for this stream. Add it to the next // pass. - streams[nextNumStreams++] = stream; + states[nextNumStreams++] = state; } } numStreams = nextNumStreams; @@ -187,8 +176,7 @@ class OutboundFlowController { // Now take one last pass through all of the streams and write any allocated bytes. WriteStatus writeStatus = new WriteStatus(); - for (OkHttpClientStream stream : transport.getActiveStreams()) { - OutboundFlowState state = state(stream); + for (StreamState state : transport.getActiveStreams()) { state.writeBytes(state.allocatedBytes(), writeStatus); state.clearAllocatedBytes(); } @@ -213,25 +201,29 @@ class OutboundFlowController { } } + public interface Transport { + StreamState[] getActiveStreams(); + } + + public interface Stream { + void onSentBytes(int frameBytes); + } + /** * The outbound flow control state for a single stream. */ - private final class OutboundFlowState { - final Buffer pendingWriteBuffer; - final int streamId; - int window; - int allocatedBytes; - OkHttpClientStream stream; - boolean pendingBufferHasEndOfStream = false; + public final class StreamState { + private final Buffer pendingWriteBuffer = new Buffer(); + private Runnable noPendingDataRunnable; + private final int streamId; + private int window; + private int allocatedBytes; + private final Stream stream; + private boolean pendingBufferHasEndOfStream = false; - OutboundFlowState(int streamId, int initialWindowSize) { + StreamState(int streamId, int initialWindowSize, Stream stream) { this.streamId = streamId; window = initialWindowSize; - pendingWriteBuffer = new Buffer(); - } - - OutboundFlowState(OkHttpClientStream stream, int initialWindowSize) { - this(stream.id(), initialWindowSize); this.stream = stream; } @@ -305,6 +297,10 @@ class OutboundFlowController { // Update the threshold. maxBytes = min(bytes - bytesAttempted, writableWindow()); } + if (!hasPendingData() && noPendingDataRunnable != null) { + noPendingDataRunnable.run(); + noPendingDataRunnable = null; + } return bytesAttempted; } @@ -328,14 +324,20 @@ class OutboundFlowController { } catch (IOException e) { throw new RuntimeException(e); } - stream.transportState().onSentBytes(frameBytes); + stream.onSentBytes(frameBytes); bytesToWrite -= frameBytes; } while (bytesToWrite > 0); } - void enqueue(Buffer buffer, int size, boolean endOfStream) { + void enqueueData(Buffer buffer, int size, boolean endOfStream) { this.pendingWriteBuffer.write(buffer, size); this.pendingBufferHasEndOfStream |= endOfStream; } + + void notifyWhenNoPendingData(Runnable noPendingDataRunnable) { + Preconditions.checkState( + this.noPendingDataRunnable == null, "pending data notification already requested"); + this.noPendingDataRunnable = noPendingDataRunnable; + } } -} \ No newline at end of file +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/PlaintextHandshakerSocketFactory.java b/okhttp/src/main/java/io/grpc/okhttp/PlaintextHandshakerSocketFactory.java new file mode 100644 index 0000000000..5338536213 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/PlaintextHandshakerSocketFactory.java @@ -0,0 +1,39 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.SecurityLevel; +import io.grpc.internal.GrpcAttributes; +import java.io.IOException; +import java.net.Socket; + +/** + * No-thrills plaintext handshaker. + */ +final class PlaintextHandshakerSocketFactory implements HandshakerSocketFactory { + @Override + public HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException { + attributes = attributes.toBuilder() + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, socket.getLocalSocketAddress()) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, socket.getRemoteSocketAddress()) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) + .build(); + return new HandshakeResult(socket, attributes, null); + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java new file mode 100644 index 0000000000..63c6f33ff7 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryServerCredentials.java @@ -0,0 +1,60 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import com.google.common.base.Preconditions; +import io.grpc.ExperimentalApi; +import io.grpc.okhttp.internal.ConnectionSpec; +import javax.net.ssl.SSLSocketFactory; + +/** A credential with full control over the SSLSocketFactory. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785") +public final class SslSocketFactoryServerCredentials { + private SslSocketFactoryServerCredentials() {} + + public static io.grpc.ServerCredentials create(SSLSocketFactory factory) { + return new ServerCredentials(factory); + } + + public static io.grpc.ServerCredentials create( + SSLSocketFactory factory, com.squareup.okhttp.ConnectionSpec connectionSpec) { + return new ServerCredentials(factory, Utils.convertSpec(connectionSpec)); + } + + // Hide implementation detail of how these credentials operate + static final class ServerCredentials extends io.grpc.ServerCredentials { + private final SSLSocketFactory factory; + private final ConnectionSpec connectionSpec; + + ServerCredentials(SSLSocketFactory factory) { + this(factory, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC); + } + + ServerCredentials(SSLSocketFactory factory, ConnectionSpec connectionSpec) { + this.factory = Preconditions.checkNotNull(factory, "factory"); + this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec"); + } + + public SSLSocketFactory getFactory() { + return factory; + } + + public ConnectionSpec getConnectionSpec() { + return connectionSpec; + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java b/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java new file mode 100644 index 0000000000..c375d6246c --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/TlsServerHandshakerSocketFactory.java @@ -0,0 +1,72 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.InternalChannelz; +import io.grpc.SecurityLevel; +import io.grpc.internal.GrpcAttributes; +import io.grpc.okhttp.internal.ConnectionSpec; +import io.grpc.okhttp.internal.Protocol; +import java.io.IOException; +import java.net.Socket; +import java.util.Arrays; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +/** + * TLS handshaker. + */ +final class TlsServerHandshakerSocketFactory implements HandshakerSocketFactory { + private final PlaintextHandshakerSocketFactory delegate = new PlaintextHandshakerSocketFactory(); + private final SSLSocketFactory socketFactory; + private final ConnectionSpec connectionSpec; + + public TlsServerHandshakerSocketFactory( + SslSocketFactoryServerCredentials.ServerCredentials credentials) { + this.socketFactory = credentials.getFactory(); + this.connectionSpec = credentials.getConnectionSpec(); + } + + @Override + public HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException { + HandshakeResult result = delegate.handshake(socket, attributes); + socket = socketFactory.createSocket(result.socket, null, -1, true); + if (!(socket instanceof SSLSocket)) { + throw new IOException( + "SocketFactory " + socketFactory + " did not produce an SSLSocket: " + socket.getClass()); + } + SSLSocket sslSocket = (SSLSocket) socket; + sslSocket.setUseClientMode(false); + connectionSpec.apply(sslSocket, false); + Protocol expectedProtocol = Protocol.HTTP_2; + String negotiatedProtocol = OkHttpProtocolNegotiator.get().negotiate( + sslSocket, + null, + connectionSpec.supportsTlsExtensions() ? Arrays.asList(expectedProtocol) : null); + if (!expectedProtocol.toString().equals(negotiatedProtocol)) { + throw new IOException("Expected NPN/ALPN " + expectedProtocol + ": " + negotiatedProtocol); + } + attributes = result.attributes.toBuilder() + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSocket.getSession()) + .build(); + return new HandshakeResult(socket, attributes, + new InternalChannelz.Security(new InternalChannelz.Tls(sslSocket.getSession()))); + } +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java new file mode 100644 index 0000000000..201829d38a --- /dev/null +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -0,0 +1,1239 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import static com.google.common.base.Charsets.UTF_8; +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.okhttp.Headers.CONTENT_TYPE_HEADER; +import static io.grpc.okhttp.Headers.HTTP_SCHEME_HEADER; +import static io.grpc.okhttp.Headers.METHOD_HEADER; +import static io.grpc.okhttp.Headers.TE_HEADER; +import static org.mockito.AdditionalAnswers.answerVoid; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import com.google.common.io.ByteStreams; +import io.grpc.Attributes; +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.FakeClock; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerStreamListener; +import io.grpc.internal.ServerTransportListener; +import io.grpc.okhttp.internal.framed.ErrorCode; +import io.grpc.okhttp.internal.framed.FrameReader; +import io.grpc.okhttp.internal.framed.FrameWriter; +import io.grpc.okhttp.internal.framed.Header; +import io.grpc.okhttp.internal.framed.HeadersMode; +import io.grpc.okhttp.internal.framed.Http2; +import io.grpc.okhttp.internal.framed.Settings; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import okio.Buffer; +import okio.BufferedSource; +import okio.ByteString; +import okio.Okio; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +/** + * Tests for {@link OkHttpServerTransport}. + */ +@RunWith(JUnit4.class) +public class OkHttpServerTransportTest { + private static final int TIME_OUT_MS = 2000; + private static final int INITIAL_WINDOW_SIZE = 65535; + + private MockServerTransportListener mockTransportListener = new MockServerTransportListener(); + private ServerTransportListener transportListener + = mock(ServerTransportListener.class, delegatesTo(mockTransportListener)); + private OkHttpServerTransport serverTransport; + private final PipeSocket socket = new PipeSocket(); + private final FrameWriter clientFrameWriter + = new Http2().newWriter(Okio.buffer(Okio.sink(socket.inputStreamSource)), true); + private final FrameReader clientFrameReader + = new Http2().newReader(Okio.buffer(Okio.source(socket.outputStreamSink)), true); + private final FrameReader.Handler clientFramesRead = mock(FrameReader.Handler.class); + private final DataFrameHandler clientDataFrames = mock(DataFrameHandler.class); + private ExecutorService threadPool = Executors.newCachedThreadPool(); + private HandshakerSocketFactory handshakerSocketFactory + = mock(HandshakerSocketFactory.class, delegatesTo(new PlaintextHandshakerSocketFactory())); + private OkHttpServerBuilder serverBuilder + = new OkHttpServerBuilder(new InetSocketAddress(1234), handshakerSocketFactory) + .executor(new FakeClock().getScheduledExecutorService()) // Executor unused + .scheduledExecutorService(new FakeClock().getScheduledExecutorService()) + .transportExecutor(new Executor() { + @Override public void execute(Runnable runnable) { + if (runnable instanceof OkHttpServerTransport.FrameHandler) { + threadPool.execute(runnable); + } else { + // Writing is buffered in the PipeSocket, so AsyncSinc can be executed immediately + runnable.run(); + } + } + }) + .flowControlWindow(INITIAL_WINDOW_SIZE); + + @Rule public final Timeout globalTimeout = Timeout.seconds(10); + + @Before + public void setUp() throws Exception { + doAnswer(answerVoid((Boolean outDone, Integer streamId, BufferedSource in, Integer length) -> { + in.require(length); + Buffer buf = new Buffer(); + buf.write(in.getBuffer(), length); + clientDataFrames.data(outDone, streamId, buf); + })).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt()); + } + + @After + public void tearDown() throws Exception { + threadPool.shutdownNow(); + socket.closeSourceAndSink(); + } + + @Test + public void startThenShutdown() throws Exception { + initTransport(); + handshake(); + shutdownAndTerminate(/*lastStreamId=*/ 0); + } + + @Test + public void startThenShutdownTwice() throws Exception { + initTransport(); + handshake(); + serverTransport.shutdown(); + shutdownAndTerminate(/*lastStreamId=*/ 0); + } + + @Test + public void shutdownDuringHandshake() throws Exception { + doAnswer(invocation -> { + socket.getInputStream().read(); + throw new IOException("handshake purposefully failed"); + }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); + serverBuilder.transportExecutor(threadPool); + initTransport(); + serverTransport.shutdown(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void shutdownNowDuringHandshake() throws Exception { + doAnswer(invocation -> { + socket.getInputStream().read(); + throw new IOException("handshake purposefully failed"); + }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); + serverBuilder.transportExecutor(threadPool); + initTransport(); + serverTransport.shutdownNow(Status.UNAVAILABLE.withDescription("shutdown now")); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void clientCloseDuringHandshake() throws Exception { + doAnswer(invocation -> { + socket.getInputStream().read(); + throw new IOException("handshake purposefully failed"); + }).when(handshakerSocketFactory).handshake(any(Socket.class), any(Attributes.class)); + serverBuilder.transportExecutor(threadPool); + initTransport(); + socket.close(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void closeDuringHttp2Preface() throws Exception { + initTransport(); + socket.close(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void noSettingsDuringHttp2HandshakeSettings() throws Exception { + initTransport(); + clientFrameWriter.connectionPreface(); + clientFrameWriter.flush(); + socket.close(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void noSettingsDuringHttp2Handshake() throws Exception { + initTransport(); + clientFrameWriter.connectionPreface(); + clientFrameWriter.ping(false, 0, 0x1234); + clientFrameWriter.flush(); + + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + verify(transportListener, never()).transportReady(any(Attributes.class)); + } + + @Test + public void startThenClientDisconnect() throws Exception { + initTransport(); + handshake(); + + socket.closeSourceAndSink(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void basicRpc_succeeds() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80"); + assertThat(streamListener.method).isEqualTo("com.example/SimpleService.doit"); + assertThat(streamListener.headers.get( + Metadata.Key.of("Some-Metadata", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("this could be anything"); + streamListener.stream.request(1); + pingPong(); + assertThat(streamListener.messages.pop()).isEqualTo("Hello server"); + assertThat(streamListener.halfClosedCalled).isTrue(); + + streamListener.stream.writeHeaders(metadata("User-Data", "best data")); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.close(Status.OK, metadata("End-Metadata", "bye")); + + List

responseHeaders = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER, + new Header("user-data", "best data")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseMessageFrame = createMessageFrame("Howdy client"); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size())); + verify(clientDataFrames).data(false, 1, responseMessageFrame); + + List
responseTrailers = Arrays.asList( + new Header("end-metadata", "bye"), + new Header("grpc-status", "0")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseTrailers, HeadersMode.HTTP_20_HEADERS); + + SocketStats stats = serverTransport.getStats().get(); + assertThat(stats.data.streamsStarted).isEqualTo(1); + assertThat(stats.data.streamsSucceeded).isEqualTo(1); + assertThat(stats.data.streamsFailed).isEqualTo(0); + assertThat(stats.data.messagesSent).isEqualTo(1); + assertThat(stats.data.messagesReceived).isEqualTo(1); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void activeRpc_delaysShutdownTermination() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + pingPong(); + + shutdownAndVerifyGraceful(1); + verify(transportListener, never()).transportTerminated(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + streamListener.stream.request(1); + pingPong(); + assertThat(streamListener.messages.pop()).isEqualTo("Hello server"); + assertThat(streamListener.halfClosedCalled).isTrue(); + + streamListener.stream.writeHeaders(new Metadata()); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.flush(); + + List
responseHeaders = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseMessageFrame = createMessageFrame("Howdy client"); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size())); + verify(clientDataFrames).data(false, 1, responseMessageFrame); + pingPong(); + assertThat(serverTransport.getActiveStreams().length).isEqualTo(1); + verify(transportListener, never()).transportTerminated(); + + streamListener.stream.close(Status.OK, new Metadata()); + List
responseTrailers = Arrays.asList( + new Header("grpc-status", "0")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseTrailers, HeadersMode.HTTP_20_HEADERS); + + assertThat(serverTransport.getActiveStreams().length).isEqualTo(0); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void headersForStream0_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(0, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.INTERNAL_ERROR, + ByteString.encodeUtf8("Error in frame decoder")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void headersForEvenStream_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(2, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8("Clients cannot open even numbered streams. RFC7540 section 5.1.1")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void headersTooLarge_failsWith431() throws Exception { + initTransport(); + handshake(); + + StringBuilder largeString = new StringBuilder(); + for (int i = 0; i < 100; i++) { + largeString.append( + "Row, row, row your boat, gently down the stream. Merrily, merrily, merrily, merrily, " + + "life is but a dream. "); + } + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("too-large", largeString.toString()))); + clientFrameWriter.flush(); + + verifyHttpError( + 1, 431, Status.Code.RESOURCE_EXHAUSTED, "Request metadata larger than 8192: 10953"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void invalidPseudoHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header(":status", "999"), // Invalid for requests + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void multipleAuthorityHeaders_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_AUTHORITY, "example.com:8080"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void pseudoHeaderAfterRegularHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + CONTENT_TYPE_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingSchemeHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void connectionHeader_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + new Header("connection", "content-type"), + TE_HEADER)); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void trailersAfterEndStream_failsWithRst() throws Exception { + initTransport(); + handshake(); + + List
headers = Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER); + clientFrameWriter.synStream(true, false, 1, -1, headers); + clientFrameWriter.synStream(true, false, 1, -1, headers); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.STREAM_CLOSED); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.status).isNotNull(); + assertThat(streamListener.status.getCode()).isNotEqualTo(Status.Code.OK); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void trailers_endStream() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.synStream(true, false, 1, -1, Arrays.asList( + new Header("some-client-sent-trailer", "trailer-value"))); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.messages.peek()).isNull(); + assertThat(streamListener.halfClosedCalled).isTrue(); + + streamListener.stream.close(Status.OK, new Metadata()); + + List
responseTrailers = Arrays.asList( + new Header(":status", "200"), + CONTENT_TYPE_HEADER, + new Header("grpc-status", "0")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseTrailers, HeadersMode.HTTP_20_HEADERS); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void headersInMiddleOfRequest_failsWithRst() throws Exception { + initTransport(); + handshake(); + + List
headers = Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER); + clientFrameWriter.headers(1, headers); + clientFrameWriter.headers(1, headers); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.PROTOCOL_ERROR); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.status).isNotNull(); + assertThat(streamListener.status.getCode()).isNotEqualTo(Status.Code.OK); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void multipleHostHeaders_failsWith400() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("host", "example.com:80"), + new Header("host", "example.com:80"))); + clientFrameWriter.flush(); + + verifyHttpError( + 1, 400, Status.Code.INTERNAL, "Multiple host headers disallowed. RFC7230 section 5.4"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void hostWithoutAuthority_usesHost() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header("host", "example.com:80"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.rstStream(1, ErrorCode.CANCEL); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void authorityAndHost_usesAuthority() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header("host", "example2.com:8080"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.rstStream(1, ErrorCode.CANCEL); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingAuthorityAndHost_hasNullAuthority() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.rstStream(1, ErrorCode.CANCEL); + pingPong(); + + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + assertThat(streamListener.stream.getAuthority()).isNull(); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void emptyPath_failsWith404() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, ""), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 404, Status.Code.UNIMPLEMENTED, "Expected path to start with /: "); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void nonAbsolutePath_failsWith404() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "https://example.com/"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError( + 1, 404, Status.Code.UNIMPLEMENTED, "Expected path to start with /: https://example.com/"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingContentType_failsWith415() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 415, Status.Code.INTERNAL, "Content-Type is missing or duplicated"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void repeatedContentType_failsWith415() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 415, Status.Code.INTERNAL, "Content-Type is missing or duplicated"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void textContentType_failsWith415() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + new Header("content-type", "text/plain"), + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 415, Status.Code.INTERNAL, "Content-Type is not supported: text/plain"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void httpGet_failsWith405() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + new Header(":method", "GET"), + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + verifyHttpError(1, 405, Status.Code.INTERNAL, "HTTP Method is not supported: GET"); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void missingTeTrailers_failsWithInternal() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER)); + clientFrameWriter.flush(); + + List
responseHeaders = Arrays.asList( + new Header(":status", "200"), + new Header("content-type", "application/grpc"), + new Header("grpc-status", "" + Status.Code.INTERNAL.value()), + new Header("grpc-message", "Expected header TE: trailers, but is received. " + + "Some intermediate proxy may not support trailers")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, true, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.NO_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void httpErrorsAdhereToFlowControl() throws Exception { + Settings settings = new Settings(); + OkHttpSettingsUtil.set(settings, OkHttpSettingsUtil.INITIAL_WINDOW_SIZE, 1); + + initTransport(); + handshake(settings); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + new Header(":method", "GET"), // Invalid + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.flush(); + + String errorDescription = "HTTP Method is not supported: GET"; + List
responseHeaders = Arrays.asList( + new Header(":status", "405"), + new Header("content-type", "text/plain; charset=utf-8"), + new Header("grpc-status", "" + Status.Code.INTERNAL.value()), + new Header("grpc-message", errorDescription)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(0, 1)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).data( + eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size())); + verify(clientDataFrames).data(false, 1, responseDataFrame); + + clientFrameWriter.windowUpdate(1, 1000); + clientFrameWriter.flush(); + + responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(1)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).data( + eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size())); + verify(clientDataFrames).data(true, 1, responseDataFrame); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.NO_ERROR); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void dataForStream0_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + Buffer requestMessageFrame = createMessageFrame("Nope"); + clientFrameWriter.data(true, 0, requestMessageFrame, (int) requestMessageFrame.size()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8("Stream 0 is reserved for control messages. RFC7540 section 5.1.1")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void dataForEvenStream_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + Buffer requestMessageFrame = createMessageFrame("Nope"); + clientFrameWriter.data(true, 2, requestMessageFrame, (int) requestMessageFrame.size()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8("Clients cannot open even numbered streams. RFC7540 section 5.1.1")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void dataAfterHalfClose_failsWithRst() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + requestMessageFrame = createMessageFrame("oh, I forgot"); + clientFrameWriter.data(true, 1, requestMessageFrame, (int) requestMessageFrame.size()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(1, ErrorCode.STREAM_CLOSED); + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + pingPong(); + assertThat(streamListener.status).isNotNull(); + assertThat(streamListener.status.getCode()).isNotEqualTo(Status.Code.OK); + + shutdownAndTerminate(/*lastStreamId=*/ 1); + } + + @Test + public void pushPromise_failsWithGoAway() throws Exception { + initTransport(); + handshake(); + + clientFrameWriter.pushPromise(2, 3, Arrays.asList()); + clientFrameWriter.flush(); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway( + 0, ErrorCode.PROTOCOL_ERROR, + ByteString.encodeUtf8( + "PUSH_PROMISE only allowed on peer-initiated streams. RFC7540 section 6.6")); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + @Test + public void channelzStats() throws Exception { + serverBuilder.flowControlWindow(60000); + initTransport(); + handshake(); + clientFrameWriter.windowUpdate(0, 1000); // connection stream id + pingPong(); + + SocketStats stats = serverTransport.getStats().get(); + assertThat(stats.data.streamsStarted).isEqualTo(0); + assertThat(stats.data.streamsSucceeded).isEqualTo(0); + assertThat(stats.data.streamsFailed).isEqualTo(0); + assertThat(stats.data.messagesSent).isEqualTo(0); + assertThat(stats.data.messagesReceived).isEqualTo(0); + assertThat(stats.data.remoteFlowControlWindow).isEqualTo(30000); // Lower bound + assertThat(stats.data.localFlowControlWindow).isEqualTo(66535); + assertThat(stats.local).isEqualTo(new InetSocketAddress("127.0.0.1", 4000)); + assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000)); + } + + private void initTransport() throws Exception { + serverTransport = new OkHttpServerTransport( + new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()), + socket); + serverTransport.start(transportListener); + } + + private void handshake() throws Exception { + handshake(new Settings()); + } + + private void handshake(Settings settings) throws Exception { + clientFrameWriter.connectionPreface(); + clientFrameWriter.settings(settings); + clientFrameWriter.flush(); + clientFrameReader.readConnectionPreface(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + ArgumentCaptor settingsCaptor = ArgumentCaptor.forClass(Settings.class); + verify(clientFramesRead).settings(eq(false), settingsCaptor.capture()); + clientFrameWriter.ackSettings(settingsCaptor.getValue()); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).ackSettings(); + verify(transportListener, timeout(TIME_OUT_MS)).transportReady(any(Attributes.class)); + } + + private static Buffer createMessageFrame(String stringMessage) { + byte[] message = stringMessage.getBytes(UTF_8); + Buffer buffer = new Buffer(); + buffer.writeByte(0 /* UNCOMPRESSED */); + buffer.writeInt(message.length); + buffer.write(message); + return buffer; + } + + private Metadata metadata(String... keysAndValues) { + Metadata metadata = new Metadata(); + assertThat(keysAndValues.length % 2).isEqualTo(0); + for (int i = 0; i < keysAndValues.length; i += 2) { + metadata.put( + Metadata.Key.of(keysAndValues[i], Metadata.ASCII_STRING_MARSHALLER), + keysAndValues[i + 1]); + } + return metadata; + } + + private void shutdownAndVerifyGraceful(int lastStreamId) throws IOException { + serverTransport.shutdown(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(2147483647, ErrorCode.NO_ERROR, ByteString.EMPTY); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).ping(false, 0, 0x1111); + clientFrameWriter.ping(true, 0, 0x1111); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(lastStreamId, ErrorCode.NO_ERROR, ByteString.EMPTY); + } + + private void shutdownAndTerminate(int lastStreamId) throws IOException { + assertThat(serverTransport.getActiveStreams().length).isEqualTo(0); + shutdownAndVerifyGraceful(lastStreamId); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isFalse(); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + } + + private int pingPongId = 0; + + /** Send a ping and wait for the ping ack. */ + private void pingPong() throws IOException { + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).ping(true, pingPongId, 0); + } + + private void verifyHttpError( + int streamId, int httpCode, Status.Code grpcCode, String errorDescription) throws Exception { + List
responseHeaders = Arrays.asList( + new Header(":status", "" + httpCode), + new Header("content-type", "text/plain; charset=utf-8"), + new Header("grpc-status", "" + grpcCode.value()), + new Header("grpc-message", errorDescription)); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead) + .headers(false, false, streamId, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS); + + Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).data( + eq(true), eq(streamId), any(BufferedSource.class), eq((int) responseDataFrame.size())); + verify(clientDataFrames).data(true, streamId, responseDataFrame); + + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).rstStream(streamId, ErrorCode.NO_ERROR); + } + + private static class MockServerTransportListener implements ServerTransportListener { + Deque newStreams = new ArrayDeque<>(); + + @Override public void streamCreated(ServerStream stream, String method, Metadata headers) { + MockStreamListener streamListener = new MockStreamListener(stream, method, headers); + stream.setListener(streamListener); + newStreams.add(streamListener); + } + + @Override public Attributes transportReady(Attributes attributes) { + return attributes; + } + + @Override public void transportTerminated() {} + } + + private static class MockStreamListener implements ServerStreamListener { + final ServerStream stream; + final String method; + final Metadata headers; + + Deque messages = new ArrayDeque<>(); + boolean halfClosedCalled; + boolean onReadyCalled; + Status status; + CountDownLatch closed = new CountDownLatch(1); + + MockStreamListener(ServerStream stream, String method, Metadata headers) { + this.stream = stream; + this.method = method; + this.headers = headers; + } + + @Override + public void messagesAvailable(MessageProducer producer) { + InputStream inputStream; + while ((inputStream = producer.next()) != null) { + try { + String msg = getContent(inputStream); + if (msg != null) { + messages.add(msg); + } + } catch (IOException ex) { + while ((inputStream = producer.next()) != null) { + GrpcUtil.closeQuietly(inputStream); + } + throw new RuntimeException(ex); + } + } + } + + @Override + public void halfClosed() { + halfClosedCalled = true; + } + + @Override + public void closed(Status status) { + this.status = status; + closed.countDown(); + } + + @Override + public void onReady() { + onReadyCalled = true; + } + + boolean isOnReadyCalled() { + boolean value = onReadyCalled; + onReadyCalled = false; + return value; + } + + void waitUntilStreamClosed() throws InterruptedException, TimeoutException { + if (!closed.await(TIME_OUT_MS, TimeUnit.MILLISECONDS)) { + throw new TimeoutException("Failed waiting stream to be closed."); + } + } + + static String getContent(InputStream message) throws IOException { + try { + return new String(ByteStreams.toByteArray(message), UTF_8); + } finally { + message.close(); + } + } + } + + private static class PipeSocket extends Socket { + private final PipedOutputStream outputStream = new PipedOutputStream(); + private final PipedInputStream outputStreamSink = new PipedInputStream(); + private final PipedOutputStream inputStreamSource = new PipedOutputStream(); + private final PipedInputStream inputStream = new PipedInputStream(); + + public PipeSocket() { + try { + outputStreamSink.connect(outputStream); + inputStream.connect(inputStreamSource); + } catch (IOException ex) { + throw new AssertionError(ex); + } + } + + @Override + public synchronized void close() throws IOException { + try { + outputStream.close(); + } finally { + inputStream.close(); + // PipedInputStream can only be woken by PipedOutputStream, so PipedOutputStream.close() is + // a better imitation of Socket.close(). + inputStreamSource.close(); + } + } + + public void closeSourceAndSink() throws IOException { + try { + outputStreamSink.close(); + } finally { + inputStreamSource.close(); + } + } + + @Override + public SocketAddress getLocalSocketAddress() { + return new InetSocketAddress("127.0.0.1", 4000); + } + + @Override + public SocketAddress getRemoteSocketAddress() { + return new InetSocketAddress("127.0.0.2", 5000); + } + + @Override + public OutputStream getOutputStream() { + return outputStream; + } + + @Override + public InputStream getInputStream() { + return inputStream; + } + } + + private interface DataFrameHandler { + void data(boolean inFinished, int streamId, Buffer payload); + } +} diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java index 44e493c259..076eea3349 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpTransportTest.java @@ -16,6 +16,7 @@ package io.grpc.okhttp; +import io.grpc.InsecureServerCredentials; import io.grpc.ServerStreamTracer; import io.grpc.internal.AbstractTransportTest; import io.grpc.internal.ClientTransportFactory; @@ -23,8 +24,6 @@ import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import io.grpc.internal.InternalServer; import io.grpc.internal.ManagedClientTransport; -import io.grpc.netty.InternalNettyServerBuilder; -import io.grpc.netty.NettyServerBuilder; import java.net.InetSocketAddress; import java.util.List; import java.util.concurrent.TimeUnit; @@ -53,21 +52,17 @@ public class OkHttpTransportTest extends AbstractTransportTest { @Override protected InternalServer newServer( List streamTracerFactories) { - NettyServerBuilder builder = NettyServerBuilder - .forPort(0) - .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW); - InternalNettyServerBuilder.setTransportTracerFactory(builder, fakeClockTransportTracer); - return InternalNettyServerBuilder.buildTransportServers(builder, streamTracerFactories); + return newServer(0, streamTracerFactories); } @Override protected InternalServer newServer( int port, List streamTracerFactories) { - NettyServerBuilder builder = NettyServerBuilder - .forAddress(new InetSocketAddress(port)) - .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW); - InternalNettyServerBuilder.setTransportTracerFactory(builder, fakeClockTransportTracer); - return InternalNettyServerBuilder.buildTransportServers(builder, streamTracerFactories); + return OkHttpServerBuilder + .forPort(port, InsecureServerCredentials.create()) + .flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW) + .setTransportTracerFactory(fakeClockTransportTracer) + .buildTransportServers(streamTracerFactories); } @Override @@ -100,11 +95,4 @@ public class OkHttpTransportTest extends AbstractTransportTest { protected boolean haveTransportTracer() { return true; } - - @Override - @org.junit.Test - @org.junit.Ignore - public void clientChecksInboundMetadataSize_trailer() { - // Server-side is flaky due to https://github.com/netty/netty/pull/8332 - } } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java index 197a7f72fc..e30264b9e3 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Http2.java @@ -231,6 +231,7 @@ public final class Http2 implements Variant { short padding = (flags & FLAG_PADDED) != 0 ? (short) (source.readByte() & 0xff) : 0; length = lengthWithoutPadding(length, flags, padding); + // FIXME: pass padding length to handler because it should be included for flow control handler.data(inFinished, streamId, source, length); source.skip(padding); } diff --git a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java index 0d0ecce998..591b59129e 100644 --- a/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java +++ b/okhttp/third_party/okhttp/main/java/io/grpc/okhttp/internal/framed/Settings.java @@ -46,7 +46,7 @@ public final class Settings { /** spdy/3: Sender's estimate of max outgoing kbps. */ static final int DOWNLOAD_BANDWIDTH = 2; /** HTTP/2: The peer must not send a PUSH_PROMISE frame when this is 0. */ - static final int ENABLE_PUSH = 2; + public static final int ENABLE_PUSH = 2; /** spdy/3: Sender's estimate of millis between sending a request and receiving a response. */ static final int ROUND_TRIP_TIME = 3; /** Sender's maximum number of concurrent streams. */ @@ -58,7 +58,7 @@ public final class Settings { /** spdy/3: Retransmission rate. Percentage */ static final int DOWNLOAD_RETRANS_RATE = 6; /** HTTP/2: Advisory only. Size in bytes of the largest header list the sender will accept. */ - static final int MAX_HEADER_LIST_SIZE = 6; + public static final int MAX_HEADER_LIST_SIZE = 6; /** Window size in bytes. */ public static final int INITIAL_WINDOW_SIZE = 7; /** spdy/3: Size of the client certificate vector. Unsupported. */