okhttp: Add server implementation

This commit is contained in:
Eric Anderson 2022-05-16 11:53:20 -07:00
parent 0099b06739
commit e96d04774b
26 changed files with 3622 additions and 141 deletions

View File

@ -246,7 +246,7 @@ public abstract class ServerBuilder<T extends ServerBuilder<T>> {
/**
* Sets the time without read activity before sending a keepalive ping. An unreasonably small
* value might be increased, and {@code Long.MAX_VALUE} nano seconds or an unreasonably large
* value will disable keepalive. The typical default is infinite when supported.
* value will disable keepalive. The typical default is two hours when supported.
*
* @throws IllegalArgumentException if time is not positive
* @throws UnsupportedOperationException if unsupported

View File

@ -50,7 +50,7 @@ public abstract class AbstractServerStream extends AbstractStream
* @param flush {@code true} if more data may not be arriving soon
* @param numMessages the number of messages this frame represents
*/
void writeFrame(@Nullable WritableBuffer frame, boolean flush, int numMessages);
void writeFrame(WritableBuffer frame, boolean flush, int numMessages);
/**
* Sends trailers to the remote end point. This call implies end of stream.
@ -108,7 +108,14 @@ public abstract class AbstractServerStream extends AbstractStream
WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) {
// Since endOfStream is triggered by the sending of trailers, avoid flush here and just flush
// after the trailers.
abstractServerStreamSink().writeFrame(frame, endOfStream ? false : flush, numMessages);
if (frame == null) {
assert endOfStream;
return;
}
if (endOfStream) {
flush = false;
}
abstractServerStreamSink().writeFrame(frame, flush, numMessages);
}
@Override

View File

@ -799,6 +799,12 @@ public final class GrpcUtil {
}
}
/** Reads {@code in} until end of stream. */
public static void exhaust(InputStream in) throws IOException {
byte[] buf = new byte[256];
while (in.read(buf) != -1) {}
}
/**
* Checks whether the given item exists in the iterable. This is copied from Guava Collect's
* {@code Iterables.contains()} because Guava Collect is not Android-friendly thus core can't

View File

@ -408,12 +408,13 @@ public abstract class AbstractTransportTest {
}
assumeTrue("transport is not using InetSocketAddress", port != -1);
server.shutdown();
assertTrue(serverListener.waitForShutdown(TIMEOUT_MS, TimeUnit.MILLISECONDS));
server = newServer(port, Arrays.asList(serverStreamTracerFactory));
boolean success;
Thread.currentThread().interrupt();
try {
server.start(serverListener);
server.start(serverListener = new MockServerListener());
success = true;
} catch (Exception ex) {
success = false;

View File

@ -52,6 +52,7 @@ import io.grpc.internal.TransportTracer;
import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2ServerHeadersDecoder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
@ -854,6 +855,9 @@ class NettyServerHandler extends AbstractNettyHandler {
keepAliveManager.onDataReceived();
}
NettyServerHandler.this.onHeadersRead(ctx, streamId, headers);
if (endStream) {
NettyServerHandler.this.onDataRead(streamId, Unpooled.EMPTY_BUFFER, 0, endStream);
}
}
@Override

View File

@ -108,10 +108,6 @@ class NettyServerStream extends AbstractServerStream {
private void writeFrameInternal(WritableBuffer frame, boolean flush, final int numMessages) {
Preconditions.checkArgument(numMessages >= 0);
if (frame == null) {
writeQueue.scheduleFlush();
return;
}
ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf().touch();
final int numBytes = bytebuf.readableBytes();
// Add the bytes to outbound flow control.

View File

@ -21,7 +21,7 @@ dependencies {
testImplementation project(':grpc-core').sourceSets.test.output,
project(':grpc-api').sourceSets.test.output,
project(':grpc-testing'),
project(':grpc-netty'),
libraries.netty.codec.http2,
libraries.okhttp
signature "org.codehaus.mojo.signature:java17:1.0@signature"
signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature"

View File

@ -164,6 +164,13 @@ final class AsyncSink implements Sink {
serializingExecutor.execute(new Runnable() {
@Override
public void run() {
try {
if (buffer.size() > 0) {
sink.write(buffer, buffer.size());
}
} catch (IOException e) {
transportExceptionHandler.onException(e);
}
buffer.close();
try {
if (sink != null) {

View File

@ -0,0 +1,42 @@
/*
* Copyright 2022 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.InternalChannelz;
import java.io.IOException;
import java.net.Socket;
/** Handshakes new connections. */
interface HandshakerSocketFactory {
/** When the returned socket is closed, {@code socket} must be closed. */
HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException;
static final class HandshakeResult {
public final Socket socket;
public final Attributes attributes;
public final InternalChannelz.Security securityInfo;
public HandshakeResult(
Socket socket, Attributes attributes, InternalChannelz.Security securityInfo) {
this.socket = Preconditions.checkNotNull(socket, "socket");
this.attributes = Preconditions.checkNotNull(attributes, "attributes");
this.securityInfo = securityInfo;
}
}
}

View File

@ -16,9 +16,6 @@
package io.grpc.okhttp;
import static io.grpc.internal.GrpcUtil.CONTENT_TYPE_KEY;
import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;
import com.google.common.base.Preconditions;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
@ -39,7 +36,7 @@ class Headers {
public static final Header METHOD_HEADER = new Header(Header.TARGET_METHOD, GrpcUtil.HTTP_METHOD);
public static final Header METHOD_GET_HEADER = new Header(Header.TARGET_METHOD, "GET");
public static final Header CONTENT_TYPE_HEADER =
new Header(CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC);
new Header(GrpcUtil.CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC);
public static final Header TE_HEADER = new Header("te", GrpcUtil.TE_TRAILERS);
/**
@ -58,10 +55,7 @@ class Headers {
Preconditions.checkNotNull(defaultPath, "defaultPath");
Preconditions.checkNotNull(authority, "authority");
// Discard any application supplied duplicates of the reserved headers
headers.discardAll(GrpcUtil.CONTENT_TYPE_KEY);
headers.discardAll(GrpcUtil.TE_HEADER);
headers.discardAll(GrpcUtil.USER_AGENT_KEY);
stripNonApplicationHeaders(headers);
// 7 is the number of explicit add calls below.
List<Header> okhttpHeaders = new ArrayList<>(7 + InternalMetadata.headerCount(headers));
@ -89,27 +83,72 @@ class Headers {
okhttpHeaders.add(TE_HEADER);
// Now add any application-provided headers.
byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(headers);
for (int i = 0; i < serializedHeaders.length; i += 2) {
ByteString key = ByteString.of(serializedHeaders[i]);
String keyString = key.utf8();
if (isApplicationHeader(keyString)) {
ByteString value = ByteString.of(serializedHeaders[i + 1]);
okhttpHeaders.add(new Header(key, value));
}
}
return okhttpHeaders;
return addMetadata(okhttpHeaders, headers);
}
/**
* Returns {@code true} if the given header is an application-provided header. Otherwise, returns
* {@code false} if the header is reserved by GRPC.
* Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when
* starting a response. Since this serializes the headers, this method should be called in the
* application thread context.
*/
private static boolean isApplicationHeader(String key) {
// Don't allow HTTP/2 pseudo headers or content-type to be added by the application.
return (!key.startsWith(":")
&& !CONTENT_TYPE_KEY.name().equalsIgnoreCase(key))
&& !USER_AGENT_KEY.name().equalsIgnoreCase(key);
public static List<Header> createResponseHeaders(Metadata headers) {
stripNonApplicationHeaders(headers);
// 2 is the number of explicit add calls below.
List<Header> okhttpHeaders = new ArrayList<>(2 + InternalMetadata.headerCount(headers));
okhttpHeaders.add(new Header(Header.RESPONSE_STATUS, "200"));
// All non-pseudo headers must come after pseudo headers.
okhttpHeaders.add(CONTENT_TYPE_HEADER);
return addMetadata(okhttpHeaders, headers);
}
/**
* Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when
* finishing a response. Since this serializes the headers, this method should be called in the
* application thread context.
*/
public static List<Header> createResponseTrailers(Metadata trailers, boolean headersSent) {
if (!headersSent) {
return createResponseHeaders(trailers);
}
stripNonApplicationHeaders(trailers);
List<Header> okhttpTrailers = new ArrayList<>(InternalMetadata.headerCount(trailers));
return addMetadata(okhttpTrailers, trailers);
}
/**
* Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when
* failing with an HTTP response.
*/
public static List<Header> createHttpResponseHeaders(
int httpCode, String contentType, Metadata headers) {
// 2 is the number of explicit add calls below.
List<Header> okhttpHeaders = new ArrayList<>(2 + InternalMetadata.headerCount(headers));
okhttpHeaders.add(new Header(Header.RESPONSE_STATUS, "" + httpCode));
// All non-pseudo headers must come after pseudo headers.
okhttpHeaders.add(new Header(GrpcUtil.CONTENT_TYPE_KEY.name(), contentType));
return addMetadata(okhttpHeaders, headers);
}
private static List<Header> addMetadata(List<Header> okhttpHeaders, Metadata toAdd) {
byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(toAdd);
for (int i = 0; i < serializedHeaders.length; i += 2) {
ByteString key = ByteString.of(serializedHeaders[i]);
// Don't allow HTTP/2 pseudo headers to be added by the application.
if (key.size() == 0 || key.getByte(0) == ':') {
continue;
}
ByteString value = ByteString.of(serializedHeaders[i + 1]);
okhttpHeaders.add(new Header(key, value));
}
return okhttpHeaders;
}
/** Strips all non-pseudo headers reserved by gRPC, to avoid duplicates and misinterpretation. */
private static void stripNonApplicationHeaders(Metadata headers) {
headers.discardAll(GrpcUtil.CONTENT_TYPE_KEY);
headers.discardAll(GrpcUtil.TE_HEADER);
headers.discardAll(GrpcUtil.USER_AGENT_KEY);
}
}

View File

@ -139,7 +139,7 @@ public final class OkHttpChannelBuilder extends
((ExecutorService) executor).shutdown();
}
};
private static final ObjectPool<Executor> DEFAULT_TRANSPORT_EXECUTOR_POOL =
static final ObjectPool<Executor> DEFAULT_TRANSPORT_EXECUTOR_POOL =
SharedResourcePool.forResource(SHARED_EXECUTOR);
/** Creates a new builder for the given server host and port. */

View File

@ -53,8 +53,6 @@ class OkHttpClientStream extends AbstractClientStream {
private final String userAgent;
private final StatsTraceContext statsTraceCtx;
private String authority;
private Object outboundFlowState;
private volatile int id = ABSENT_ID;
private final TransportState state;
private final Sink sink = new Sink();
private final Attributes attributes;
@ -120,10 +118,6 @@ class OkHttpClientStream extends AbstractClientStream {
return method.getType();
}
public int id() {
return id;
}
/**
* Returns whether the stream uses GET. This is not known until after {@link Sink#writeHeaders} is
* invoked.
@ -198,7 +192,8 @@ class OkHttpClientStream extends AbstractClientStream {
}
}
class TransportState extends Http2ClientStreamTransportState {
class TransportState extends Http2ClientStreamTransportState
implements OutboundFlowController.Stream {
private final int initialWindowSize;
private final Object lock;
@GuardedBy("lock")
@ -223,6 +218,9 @@ class OkHttpClientStream extends AbstractClientStream {
@GuardedBy("lock")
private boolean canStart = true;
private final Tag tag;
@GuardedBy("lock")
private OutboundFlowController.StreamState outboundFlowState;
private int id = ABSENT_ID;
public TransportState(
int maxMessageSize,
@ -249,6 +247,7 @@ class OkHttpClientStream extends AbstractClientStream {
public void start(int streamId) {
checkState(id == ABSENT_ID, "the stream has been started with id %s", streamId);
id = streamId;
outboundFlowState = outboundFlow.createState(this, streamId);
// TODO(b/145386688): This access should be guarded by 'OkHttpClientStream.this.state.lock';
// instead found: 'this.lock'
state.onStreamAllocated();
@ -260,7 +259,9 @@ class OkHttpClientStream extends AbstractClientStream {
requestHeaders = null;
if (pendingData.size() > 0) {
outboundFlow.data(pendingDataHasEndOfStream, id, pendingData, flushPendingData);
outboundFlow.data(
pendingDataHasEndOfStream, outboundFlowState, pendingData, flushPendingData);
}
canStart = false;
}
@ -396,7 +397,7 @@ class OkHttpClientStream extends AbstractClientStream {
checkState(id() != ABSENT_ID, "streamId should be set");
// If buffer > frameWriter.maxDataLength() the flow-controller will ensure that it is
// properly chunked.
outboundFlow.data(endOfStream, id(), buffer, flush);
outboundFlow.data(endOfStream, outboundFlowState, buffer, flush);
}
}
@ -419,13 +420,15 @@ class OkHttpClientStream extends AbstractClientStream {
Tag tag() {
return tag;
}
}
void setOutboundFlowState(Object outboundFlowState) {
this.outboundFlowState = outboundFlowState;
}
int id() {
return id;
}
Object getOutboundFlowState() {
return outboundFlowState;
OutboundFlowController.StreamState getOutboundFlowState() {
synchronized (lock) {
return outboundFlowState;
}
}
}
}

View File

@ -105,10 +105,10 @@ import okio.Timeout;
/**
* A okhttp-based {@link ConnectionClientTransport} implementation.
*/
class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler {
class OkHttpClientTransport implements ConnectionClientTransport, TransportExceptionHandler,
OutboundFlowController.Transport {
private static final Map<ErrorCode, Status> ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap();
private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName());
private static final OkHttpClientStream[] EMPTY_STREAM_ARRAY = new OkHttpClientStream[0];
private static Map<ErrorCode, Status> buildErrorCodeToStatusMap() {
Map<ErrorCode, Status> errorToStatus = new EnumMap<>(ErrorCode.class);
@ -424,7 +424,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
@GuardedBy("lock")
private void startStream(OkHttpClientStream stream) {
Preconditions.checkState(
stream.id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned");
stream.transportState().id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned");
streams.put(nextStreamId, stream);
setInUse(stream);
// TODO(b/145386688): This access should be guarded by 'stream.transportState().lock'; instead
@ -808,9 +808,16 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
/**
* Gets all active streams as an array.
*/
OkHttpClientStream[] getActiveStreams() {
@Override
public OutboundFlowController.StreamState[] getActiveStreams() {
synchronized (lock) {
return streams.values().toArray(EMPTY_STREAM_ARRAY);
OutboundFlowController.StreamState[] flowStreams =
new OutboundFlowController.StreamState[streams.size()];
int i = 0;
for (OkHttpClientStream stream : streams.values()) {
flowStreams[i++] = stream.transportState().getOutboundFlowState();
}
return flowStreams;
}
}
@ -1125,7 +1132,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
if (stream == null) {
if (mayHaveCreatedStream(streamId)) {
synchronized (lock) {
frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
frameWriter.rstStream(streamId, ErrorCode.STREAM_CLOSED);
}
in.skip(length);
} else {
@ -1186,7 +1193,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
OkHttpClientStream stream = streams.get(streamId);
if (stream == null) {
if (mayHaveCreatedStream(streamId)) {
frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
frameWriter.rstStream(streamId, ErrorCode.STREAM_CLOSED);
} else {
unknownStream = true;
}
@ -1365,7 +1372,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
OkHttpClientStream stream = streams.get(streamId);
if (stream != null) {
outboundFlow.windowUpdate(stream, (int) delta);
outboundFlow.windowUpdate(stream.transportState().getOutboundFlowState(), (int) delta);
} else if (!mayHaveCreatedStream(streamId)) {
unknownStream = true;
}

View File

@ -0,0 +1,189 @@
/*
* Copyright 2022 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.InternalChannelz;
import io.grpc.InternalInstrumented;
import io.grpc.InternalLogId;
import io.grpc.ServerStreamTracer;
import io.grpc.internal.InternalServer;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.ServerListener;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ServerSocketFactory;
final class OkHttpServer implements InternalServer {
private static final Logger log = Logger.getLogger(OkHttpServer.class.getName());
private final SocketAddress originalListenAddress;
private final ServerSocketFactory socketFactory;
private final ObjectPool<Executor> transportExecutorPool;
private final ObjectPool<ScheduledExecutorService> scheduledExecutorServicePool;
private final OkHttpServerTransport.Config transportConfig;
private final InternalChannelz channelz;
private ServerSocket serverSocket;
private SocketAddress actualListenAddress;
private InternalInstrumented<InternalChannelz.SocketStats> listenInstrumented;
private Executor transportExecutor;
private ScheduledExecutorService scheduledExecutorService;
private ServerListener listener;
private boolean shutdown;
public OkHttpServer(
OkHttpServerBuilder builder,
List<? extends ServerStreamTracer.Factory> streamTracerFactories,
InternalChannelz channelz) {
this.originalListenAddress = Preconditions.checkNotNull(builder.listenAddress, "listenAddress");
this.socketFactory = Preconditions.checkNotNull(builder.socketFactory, "socketFactory");
this.transportExecutorPool =
Preconditions.checkNotNull(builder.transportExecutorPool, "transportExecutorPool");
this.scheduledExecutorServicePool =
Preconditions.checkNotNull(
builder.scheduledExecutorServicePool, "scheduledExecutorServicePool");
this.transportConfig = new OkHttpServerTransport.Config(builder, streamTracerFactories);
this.channelz = Preconditions.checkNotNull(channelz, "channelz");
}
@Override
public void start(ServerListener listener) throws IOException {
this.listener = Preconditions.checkNotNull(listener, "listener");
ServerSocket serverSocket = socketFactory.createServerSocket();
try {
serverSocket.bind(originalListenAddress);
} catch (IOException t) {
serverSocket.close();
throw t;
}
this.serverSocket = serverSocket;
this.actualListenAddress = serverSocket.getLocalSocketAddress();
this.listenInstrumented = new ListenSocket(serverSocket);
this.transportExecutor = transportExecutorPool.getObject();
// Keep reference alive to avoid frequent re-creation by server transports
this.scheduledExecutorService = scheduledExecutorServicePool.getObject();
channelz.addListenSocket(this.listenInstrumented);
transportExecutor.execute(this::acceptConnections);
}
private void acceptConnections() {
try {
while (true) {
Socket socket;
try {
socket = serverSocket.accept();
} catch (IOException ex) {
if (shutdown) {
break;
}
throw ex;
}
OkHttpServerTransport transport = new OkHttpServerTransport(transportConfig, socket);
transport.start(listener.transportCreated(transport));
}
} catch (Throwable t) {
log.log(Level.SEVERE, "Accept loop failed", t);
}
listener.serverShutdown();
}
@Override
public void shutdown() {
if (shutdown) {
return;
}
shutdown = true;
if (serverSocket == null) {
return;
}
channelz.removeListenSocket(this.listenInstrumented);
try {
serverSocket.close();
} catch (IOException ex) {
log.log(Level.WARNING, "Failed closing server socket", serverSocket);
}
transportExecutor = transportExecutorPool.returnObject(transportExecutor);
scheduledExecutorService = scheduledExecutorServicePool.returnObject(scheduledExecutorService);
}
@Override
public SocketAddress getListenSocketAddress() {
return actualListenAddress;
}
@Override
public InternalInstrumented<InternalChannelz.SocketStats> getListenSocketStats() {
return listenInstrumented;
}
@Override
public List<? extends SocketAddress> getListenSocketAddresses() {
return Collections.singletonList(getListenSocketAddress());
}
@Override
public List<InternalInstrumented<InternalChannelz.SocketStats>> getListenSocketStatsList() {
return Collections.singletonList(getListenSocketStats());
}
private static final class ListenSocket
implements InternalInstrumented<InternalChannelz.SocketStats> {
private final InternalLogId id;
private final ServerSocket socket;
public ListenSocket(ServerSocket socket) {
this.socket = socket;
this.id = InternalLogId.allocate(getClass(), String.valueOf(socket.getLocalSocketAddress()));
}
@Override
public ListenableFuture<InternalChannelz.SocketStats> getStats() {
return Futures.immediateFuture(new InternalChannelz.SocketStats(
/*data=*/ null,
socket.getLocalSocketAddress(),
/*remote=*/ null,
new InternalChannelz.SocketOptions.Builder().build(),
/*security=*/ null));
}
@Override
public InternalLogId getLogId() {
return id;
}
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("logId", id.getId())
.add("socket", socket)
.toString();
}
}
}

View File

@ -0,0 +1,387 @@
/*
* Copyright 2022 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.DoNotCall;
import io.grpc.ChoiceServerCredentials;
import io.grpc.ExperimentalApi;
import io.grpc.ForwardingServerBuilder;
import io.grpc.InsecureServerCredentials;
import io.grpc.Internal;
import io.grpc.ServerBuilder;
import io.grpc.ServerCredentials;
import io.grpc.ServerStreamTracer;
import io.grpc.TlsServerCredentials;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.InternalServer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.ObjectPool;
import io.grpc.internal.ServerImplBuilder;
import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.internal.Platform;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ServerSocketFactory;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;
/**
* Build servers with the OkHttp transport.
*
* @since 1.49.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785")
public final class OkHttpServerBuilder extends ForwardingServerBuilder<OkHttpServerBuilder> {
private static final Logger log = Logger.getLogger(OkHttpServerBuilder.class.getName());
private static final int DEFAULT_FLOW_CONTROL_WINDOW = 65535;
private static final long AS_LARGE_AS_INFINITE = TimeUnit.DAYS.toNanos(1000L);
private static final ObjectPool<Executor> DEFAULT_TRANSPORT_EXECUTOR_POOL =
OkHttpChannelBuilder.DEFAULT_TRANSPORT_EXECUTOR_POOL;
/**
* Always throws, to shadow {@code ServerBuilder.forPort()}.
*
* @deprecated Use {@link #forPort(int, ServerCredentials)} instead
*/
@DoNotCall("Always throws. Use forPort(int, ServerCredentials) instead")
@Deprecated
public static OkHttpServerBuilder forPort(int port) {
throw new UnsupportedOperationException();
}
/**
* Creates a builder for a server listening on {@code port}.
*/
public static OkHttpServerBuilder forPort(int port, ServerCredentials creds) {
return forPort(new InetSocketAddress(port), creds);
}
/**
* Creates a builder for a server listening on {@code address}.
*/
public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentials creds) {
HandshakerSocketFactoryResult result = handshakerSocketFactoryFrom(creds);
if (result.error != null) {
throw new IllegalArgumentException(result.error);
}
return new OkHttpServerBuilder(address, result.factory);
}
final ServerImplBuilder serverImplBuilder = new ServerImplBuilder(this::buildTransportServers);
final SocketAddress listenAddress;
final HandshakerSocketFactory handshakerSocketFactory;
TransportTracer.Factory transportTracerFactory = TransportTracer.getDefaultFactory();
ObjectPool<Executor> transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL;
ObjectPool<ScheduledExecutorService> scheduledExecutorServicePool =
SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE);
ServerSocketFactory socketFactory = ServerSocketFactory.getDefault();
long keepAliveTimeNanos = GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS;
long keepAliveTimeoutNanos = GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS;
int flowControlWindow = DEFAULT_FLOW_CONTROL_WINDOW;
int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE;
int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
@VisibleForTesting
OkHttpServerBuilder(
SocketAddress address, HandshakerSocketFactory handshakerSocketFactory) {
this.listenAddress = Preconditions.checkNotNull(address, "address");
this.handshakerSocketFactory =
Preconditions.checkNotNull(handshakerSocketFactory, "handshakerSocketFactory");
}
@Internal
@Override
protected ServerBuilder<?> delegate() {
return serverImplBuilder;
}
@VisibleForTesting
OkHttpServerBuilder setTransportTracerFactory(TransportTracer.Factory transportTracerFactory) {
this.transportTracerFactory = transportTracerFactory;
return this;
}
/**
* Override the default executor necessary for internal transport use.
*
* <p>The channel does not take ownership of the given executor. It is the caller' responsibility
* to shutdown the executor when appropriate.
*/
public OkHttpServerBuilder transportExecutor(Executor transportExecutor) {
if (transportExecutor == null) {
this.transportExecutorPool = DEFAULT_TRANSPORT_EXECUTOR_POOL;
} else {
this.transportExecutorPool = new FixedObjectPool<>(transportExecutor);
}
return this;
}
/**
* Override the default {@link ServerSocketFactory} used to listen. If the socket factory is not
* set or set to null, a default one will be used.
*/
public OkHttpServerBuilder socketFactory(ServerSocketFactory socketFactory) {
if (socketFactory == null) {
this.socketFactory = ServerSocketFactory.getDefault();
} else {
this.socketFactory = socketFactory;
}
return this;
}
/**
* Sets the time without read activity before sending a keepalive ping. An unreasonably small
* value might be increased, and {@code Long.MAX_VALUE} nano seconds or an unreasonably large
* value will disable keepalive. Defaults to two hours.
*
* @throws IllegalArgumentException if time is not positive
*/
@Override
public OkHttpServerBuilder keepAliveTime(long keepAliveTime, TimeUnit timeUnit) {
Preconditions.checkArgument(keepAliveTime > 0L, "keepalive time must be positive");
keepAliveTimeNanos = timeUnit.toNanos(keepAliveTime);
keepAliveTimeNanos = KeepAliveManager.clampKeepAliveTimeInNanos(keepAliveTimeNanos);
if (keepAliveTimeNanos >= AS_LARGE_AS_INFINITE) {
// Bump keepalive time to infinite. This disables keepalive.
keepAliveTimeNanos = GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED;
}
return this;
}
/**
* Sets a time waiting for read activity after sending a keepalive ping. If the time expires
* without any read activity on the connection, the connection is considered dead. An unreasonably
* small value might be increased. Defaults to 20 seconds.
*
* <p>This value should be at least multiple times the RTT to allow for lost packets.
*
* @throws IllegalArgumentException if timeout is not positive
*/
@Override
public OkHttpServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit timeUnit) {
Preconditions.checkArgument(keepAliveTimeout > 0L, "keepalive timeout must be positive");
keepAliveTimeoutNanos = timeUnit.toNanos(keepAliveTimeout);
keepAliveTimeoutNanos = KeepAliveManager.clampKeepAliveTimeoutInNanos(keepAliveTimeoutNanos);
return this;
}
/**
* Sets the flow control window in bytes. If not called, the default value is 64 KiB.
*/
public OkHttpServerBuilder flowControlWindow(int flowControlWindow) {
Preconditions.checkState(flowControlWindow > 0, "flowControlWindow must be positive");
this.flowControlWindow = flowControlWindow;
return this;
}
/**
* Provides a custom scheduled executor service.
*
* <p>It's an optional parameter. If the user has not provided a scheduled executor service when
* the channel is built, the builder will use a static thread pool.
*
* @return this
*/
public OkHttpServerBuilder scheduledExecutorService(
ScheduledExecutorService scheduledExecutorService) {
this.scheduledExecutorServicePool = new FixedObjectPool<>(
Preconditions.checkNotNull(scheduledExecutorService, "scheduledExecutorService"));
return this;
}
/**
* Sets the maximum size of metadata allowed to be received. Defaults to 8 KiB.
*
* <p>The implementation does not currently limit memory usage; this value is checked only after
* the metadata is decoded from the wire. It does prevent large metadata from being passed to the
* application.
*
* @param bytes the maximum size of received metadata
* @return this
* @throws IllegalArgumentException if bytes is non-positive
*/
@Override
public OkHttpServerBuilder maxInboundMetadataSize(int bytes) {
Preconditions.checkArgument(bytes > 0, "maxInboundMetadataSize must be > 0");
this.maxInboundMetadataSize = bytes;
return this;
}
/**
* Sets the maximum message size allowed to be received on the server. If not called, defaults to
* defaults to 4 MiB. The default provides protection to servers who haven't considered the
* possibility of receiving large messages while trying to be large enough to not be hit in normal
* usage.
*
* @param bytes the maximum number of bytes a single message can be.
* @return this
* @throws IllegalArgumentException if bytes is negative.
*/
@Override
public OkHttpServerBuilder maxInboundMessageSize(int bytes) {
Preconditions.checkArgument(bytes >= 0, "negative max bytes");
maxInboundMessageSize = bytes;
return this;
}
void setStatsEnabled(boolean value) {
this.serverImplBuilder.setStatsEnabled(value);
}
InternalServer buildTransportServers(
List<? extends ServerStreamTracer.Factory> streamTracerFactories) {
return new OkHttpServer(this, streamTracerFactories, serverImplBuilder.getChannelz());
}
private static final EnumSet<TlsServerCredentials.Feature> understoodTlsFeatures =
EnumSet.of(
TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS);
static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentials creds) {
if (creds instanceof TlsServerCredentials) {
TlsServerCredentials tlsCreds = (TlsServerCredentials) creds;
Set<TlsServerCredentials.Feature> incomprehensible =
tlsCreds.incomprehensible(understoodTlsFeatures);
if (!incomprehensible.isEmpty()) {
return HandshakerSocketFactoryResult.error(
"TLS features not understood: " + incomprehensible);
}
KeyManager[] km = null;
if (tlsCreds.getKeyManagers() != null) {
km = tlsCreds.getKeyManagers().toArray(new KeyManager[0]);
} else if (tlsCreds.getPrivateKey() != null) {
return HandshakerSocketFactoryResult.error(
"byte[]-based private key unsupported. Use KeyManager");
} // else don't have a client cert
TrustManager[] tm = null;
if (tlsCreds.getTrustManagers() != null) {
tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]);
} else if (tlsCreds.getRootCertificates() != null) {
try {
tm = createTrustManager(tlsCreds.getRootCertificates());
} catch (GeneralSecurityException gse) {
log.log(Level.FINE, "Exception loading root certificates from credential", gse);
return HandshakerSocketFactoryResult.error(
"Unable to load root certificates: " + gse.getMessage());
}
} // else use system default
SSLContext sslContext;
try {
sslContext = SSLContext.getInstance("TLS", Platform.get().getProvider());
sslContext.init(km, tm, null);
} catch (GeneralSecurityException gse) {
throw new RuntimeException("TLS Provider failure", gse);
}
return HandshakerSocketFactoryResult.factory(new TlsServerHandshakerSocketFactory(
new SslSocketFactoryServerCredentials.ServerCredentials(
sslContext.getSocketFactory())));
} else if (creds instanceof InsecureServerCredentials) {
return HandshakerSocketFactoryResult.factory(new PlaintextHandshakerSocketFactory());
} else if (creds instanceof SslSocketFactoryServerCredentials.ServerCredentials) {
SslSocketFactoryServerCredentials.ServerCredentials factoryCreds =
(SslSocketFactoryServerCredentials.ServerCredentials) creds;
return HandshakerSocketFactoryResult.factory(
new TlsServerHandshakerSocketFactory(factoryCreds));
} else if (creds instanceof ChoiceServerCredentials) {
ChoiceServerCredentials choiceCreds = (ChoiceServerCredentials) creds;
StringBuilder error = new StringBuilder();
for (ServerCredentials innerCreds : choiceCreds.getCredentialsList()) {
HandshakerSocketFactoryResult result = handshakerSocketFactoryFrom(innerCreds);
if (result.error == null) {
return result;
}
error.append(", ");
error.append(result.error);
}
return HandshakerSocketFactoryResult.error(error.substring(2));
} else {
return HandshakerSocketFactoryResult.error(
"Unsupported credential type: " + creds.getClass().getName());
}
}
static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
try {
ks.load(null, null);
} catch (IOException ex) {
// Shouldn't really happen, as we're not loading any data.
throw new GeneralSecurityException(ex);
}
CertificateFactory cf = CertificateFactory.getInstance("X.509");
ByteArrayInputStream in = new ByteArrayInputStream(rootCerts);
try {
X509Certificate cert = (X509Certificate) cf.generateCertificate(in);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
} finally {
GrpcUtil.closeQuietly(in);
}
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
static final class HandshakerSocketFactoryResult {
public final HandshakerSocketFactory factory;
public final String error;
private HandshakerSocketFactoryResult(HandshakerSocketFactory factory, String error) {
this.factory = factory;
this.error = error;
}
public static HandshakerSocketFactoryResult error(String error) {
return new HandshakerSocketFactoryResult(
null, Preconditions.checkNotNull(error, "error"));
}
public static HandshakerSocketFactoryResult factory(HandshakerSocketFactory factory) {
return new HandshakerSocketFactoryResult(
Preconditions.checkNotNull(factory, "factory"), null);
}
}
}

View File

@ -0,0 +1,302 @@
/*
* Copyright 2022 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.AbstractServerStream;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBuffer;
import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.Header;
import io.perfmark.PerfMark;
import io.perfmark.Tag;
import java.util.List;
import javax.annotation.concurrent.GuardedBy;
import okio.Buffer;
/**
* Server stream for the okhttp transport.
*/
class OkHttpServerStream extends AbstractServerStream {
private final String authority;
private final TransportState state;
private final Sink sink = new Sink();
private final TransportTracer transportTracer;
private final Attributes attributes;
public OkHttpServerStream(
TransportState state,
Attributes transportAttrs,
String authority,
StatsTraceContext statsTraceCtx,
TransportTracer transportTracer) {
super(new OkHttpWritableBufferAllocator(), statsTraceCtx);
this.state = Preconditions.checkNotNull(state, "state");
this.attributes = Preconditions.checkNotNull(transportAttrs, "transportAttrs");
this.authority = authority;
this.transportTracer = Preconditions.checkNotNull(transportTracer, "transportTracer");
}
@Override
protected TransportState transportState() {
return state;
}
@Override
protected Sink abstractServerStreamSink() {
return sink;
}
@Override
public int streamId() {
return state.streamId;
}
@Override
public String getAuthority() {
return authority;
}
@Override
public Attributes getAttributes() {
return attributes;
}
class Sink implements AbstractServerStream.Sink {
@Override
public void writeHeaders(Metadata metadata) {
PerfMark.startTask("OkHttpServerStream$Sink.writeHeaders");
try {
List<Header> responseHeaders = Headers.createResponseHeaders(metadata);
synchronized (state.lock) {
state.sendHeaders(responseHeaders);
}
} finally {
PerfMark.stopTask("OkHttpServerStream$Sink.writeHeaders");
}
}
@Override
public void writeFrame(WritableBuffer frame, boolean flush, int numMessages) {
PerfMark.startTask("OkHttpServerStream$Sink.writeFrame");
Buffer buffer = ((OkHttpWritableBuffer) frame).buffer();
int size = (int) buffer.size();
if (size > 0) {
onSendingBytes(size);
}
try {
synchronized (state.lock) {
state.sendBuffer(buffer, flush);
transportTracer.reportMessageSent(numMessages);
}
} finally {
PerfMark.stopTask("OkHttpServerStream$Sink.writeFrame");
}
}
@Override
public void writeTrailers(Metadata trailers, boolean headersSent, Status status) {
PerfMark.startTask("OkHttpServerStream$Sink.writeTrailers");
try {
List<Header> responseTrailers = Headers.createResponseTrailers(trailers, headersSent);
synchronized (state.lock) {
state.sendTrailers(responseTrailers);
}
} finally {
PerfMark.stopTask("OkHttpServerStream$Sink.writeTrailers");
}
}
@Override
public void cancel(Status reason) {
PerfMark.startTask("OkHttpServerStream$Sink.cancel");
try {
synchronized (state.lock) {
state.cancel(ErrorCode.CANCEL, reason);
}
} finally {
PerfMark.stopTask("OkHttpServerStream$Sink.cancel");
}
}
}
static class TransportState extends AbstractServerStream.TransportState
implements OutboundFlowController.Stream, OkHttpServerTransport.StreamState {
@GuardedBy("lock")
private final OkHttpServerTransport transport;
private final int streamId;
private final int initialWindowSize;
private final Object lock;
@GuardedBy("lock")
private boolean cancelSent = false;
@GuardedBy("lock")
private int window;
@GuardedBy("lock")
private int processedWindow;
@GuardedBy("lock")
private final ExceptionHandlingFrameWriter frameWriter;
@GuardedBy("lock")
private final OutboundFlowController outboundFlow;
@GuardedBy("lock")
private boolean receivedEndOfStream;
private final Tag tag;
private final OutboundFlowController.StreamState outboundFlowState;
public TransportState(
OkHttpServerTransport transport,
int streamId,
int maxMessageSize,
StatsTraceContext statsTraceCtx,
Object lock,
ExceptionHandlingFrameWriter frameWriter,
OutboundFlowController outboundFlow,
int initialWindowSize,
TransportTracer transportTracer,
String methodName) {
super(maxMessageSize, statsTraceCtx, transportTracer);
this.transport = Preconditions.checkNotNull(transport, "transport");
this.streamId = streamId;
this.lock = Preconditions.checkNotNull(lock, "lock");
this.frameWriter = frameWriter;
this.outboundFlow = outboundFlow;
this.window = initialWindowSize;
this.processedWindow = initialWindowSize;
this.initialWindowSize = initialWindowSize;
tag = PerfMark.createTag(methodName);
outboundFlowState = outboundFlow.createState(this, streamId);
}
@Override
@GuardedBy("lock")
public void deframeFailed(Throwable cause) {
cancel(ErrorCode.INTERNAL_ERROR, Status.fromThrowable(cause));
}
@Override
@GuardedBy("lock")
public void bytesRead(int processedBytes) {
processedWindow -= processedBytes;
if (processedWindow <= initialWindowSize * Utils.DEFAULT_WINDOW_UPDATE_RATIO) {
int delta = initialWindowSize - processedWindow;
window += delta;
processedWindow += delta;
frameWriter.windowUpdate(streamId, delta);
frameWriter.flush();
}
}
@Override
@GuardedBy("lock")
public void runOnTransportThread(final Runnable r) {
synchronized (lock) {
r.run();
}
}
/**
* Must be called with holding the transport lock.
*/
@Override
public void inboundDataReceived(okio.Buffer frame, int windowConsumed, boolean endOfStream) {
synchronized (lock) {
PerfMark.event("OkHttpServerTransport$FrameHandler.data", tag);
if (endOfStream) {
this.receivedEndOfStream = true;
}
window -= windowConsumed;
super.inboundDataReceived(new OkHttpReadableBuffer(frame), endOfStream);
}
}
/** Must be called with holding the transport lock. */
@Override
public void inboundRstReceived(Status status) {
PerfMark.event("OkHttpServerTransport$FrameHandler.rstStream", tag);
transportReportStatus(status);
}
/** Must be called with holding the transport lock. */
@Override
public boolean hasReceivedEndOfStream() {
synchronized (lock) {
return receivedEndOfStream;
}
}
/** Must be called with holding the transport lock. */
@Override
public int inboundWindowAvailable() {
synchronized (lock) {
return window;
}
}
@GuardedBy("lock")
private void sendBuffer(Buffer buffer, boolean flush) {
if (cancelSent) {
return;
}
// If buffer > frameWriter.maxDataLength() the flow-controller will ensure that it is
// properly chunked.
outboundFlow.data(false, outboundFlowState, buffer, flush);
}
@GuardedBy("lock")
private void sendHeaders(List<Header> responseHeaders) {
frameWriter.synReply(false, streamId, responseHeaders);
frameWriter.flush();
}
@GuardedBy("lock")
private void sendTrailers(List<Header> responseTrailers) {
outboundFlow.notifyWhenNoPendingData(
outboundFlowState, () -> sendTrailersAfterFlowControlled(responseTrailers));
}
private void sendTrailersAfterFlowControlled(List<Header> responseTrailers) {
synchronized (lock) {
frameWriter.synReply(true, streamId, responseTrailers);
if (!receivedEndOfStream) {
frameWriter.rstStream(streamId, ErrorCode.NO_ERROR);
}
transport.streamClosed(streamId, /*flush=*/ true);
complete();
}
}
@GuardedBy("lock")
private void cancel(ErrorCode http2Error, Status reason) {
if (cancelSent) {
return;
}
cancelSent = true;
frameWriter.rstStream(streamId, http2Error);
transportReportStatus(reason);
transport.streamClosed(streamId, /*flush=*/ true);
}
@Override
public OutboundFlowController.StreamState getOutboundFlowState() {
return outboundFlowState;
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -24,6 +24,8 @@ import io.grpc.okhttp.internal.framed.Settings;
class OkHttpSettingsUtil {
public static final int MAX_CONCURRENT_STREAMS = Settings.MAX_CONCURRENT_STREAMS;
public static final int INITIAL_WINDOW_SIZE = Settings.INITIAL_WINDOW_SIZE;
public static final int MAX_HEADER_LIST_SIZE = Settings.MAX_HEADER_LIST_SIZE;
public static final int ENABLE_PUSH = Settings.ENABLE_PUSH;
public static boolean isSet(Settings settings, int id) {
return settings.isSet(id);

View File

@ -33,17 +33,16 @@ import okio.Buffer;
* streams.
*/
class OutboundFlowController {
private final OkHttpClientTransport transport;
private final Transport transport;
private final FrameWriter frameWriter;
private int initialWindowSize;
private final OutboundFlowState connectionState;
private final StreamState connectionState;
OutboundFlowController(
OkHttpClientTransport transport, FrameWriter frameWriter) {
public OutboundFlowController(Transport transport, FrameWriter frameWriter) {
this.transport = Preconditions.checkNotNull(transport, "transport");
this.frameWriter = Preconditions.checkNotNull(frameWriter, "frameWriter");
this.initialWindowSize = DEFAULT_WINDOW_SIZE;
connectionState = new OutboundFlowState(CONNECTION_STREAM_ID, DEFAULT_WINDOW_SIZE);
connectionState = new StreamState(CONNECTION_STREAM_ID, DEFAULT_WINDOW_SIZE, null);
}
/**
@ -55,22 +54,15 @@ class OutboundFlowController {
*
* @return true, if new window size is increased, false otherwise.
*/
boolean initialOutboundWindowSize(int newWindowSize) {
public boolean initialOutboundWindowSize(int newWindowSize) {
if (newWindowSize < 0) {
throw new IllegalArgumentException("Invalid initial window size: " + newWindowSize);
}
int delta = newWindowSize - initialWindowSize;
initialWindowSize = newWindowSize;
for (OkHttpClientStream stream : transport.getActiveStreams()) {
OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState();
if (state == null) {
// Create the OutboundFlowState with the new window size.
state = new OutboundFlowState(stream, initialWindowSize);
stream.setOutboundFlowState(state);
} else {
state.incrementStreamWindow(delta);
}
for (StreamState state : transport.getActiveStreams()) {
state.incrementStreamWindow(delta);
}
return delta > 0;
@ -82,15 +74,14 @@ class OutboundFlowController {
*
* <p>Must be called with holding transport lock.
*/
int windowUpdate(@Nullable OkHttpClientStream stream, int delta) {
public int windowUpdate(@Nullable StreamState state, int delta) {
final int updatedWindow;
if (stream == null) {
if (state == null) {
// Update the connection window and write any pending frames for all streams.
updatedWindow = connectionState.incrementStreamWindow(delta);
writeStreams();
} else {
// Update the stream window and write any pending frames for the stream.
OutboundFlowState state = state(stream);
updatedWindow = state.incrementStreamWindow(delta);
WriteStatus writeStatus = new WriteStatus();
@ -105,18 +96,9 @@ class OutboundFlowController {
/**
* Must be called with holding transport lock.
*/
void data(boolean outFinished, int streamId, Buffer source, boolean flush) {
public void data(boolean outFinished, StreamState state, Buffer source, boolean flush) {
Preconditions.checkNotNull(source, "source");
OkHttpClientStream stream = transport.getStream(streamId);
if (stream == null) {
// This is possible for a stream that has received end-of-stream from server (but hasn't sent
// end-of-stream), and was removed from the transport stream map.
// In such case, we just throw away the data.
return;
}
OutboundFlowState state = state(stream);
int window = state.writableWindow();
boolean framesAlreadyQueued = state.hasPendingData();
int size = (int) source.size();
@ -130,7 +112,7 @@ class OutboundFlowController {
state.write(source, window, false);
}
// Queue remaining data in the buffer
state.enqueue(source, (int) source.size(), outFinished);
state.enqueueData(source, (int) source.size(), outFinished);
}
if (flush) {
@ -138,7 +120,19 @@ class OutboundFlowController {
}
}
void flush() {
/**
* Transport lock must be held when calling.
*/
public void notifyWhenNoPendingData(StreamState state, Runnable noPendingDataRunnable) {
Preconditions.checkNotNull(noPendingDataRunnable, "noPendingDataRunnable");
if (state.hasPendingData()) {
state.notifyWhenNoPendingData(noPendingDataRunnable);
} else {
noPendingDataRunnable.run();
}
}
public void flush() {
try {
frameWriter.flush();
} catch (IOException e) {
@ -146,13 +140,9 @@ class OutboundFlowController {
}
}
private OutboundFlowState state(OkHttpClientStream stream) {
OutboundFlowState state = (OutboundFlowState) stream.getOutboundFlowState();
if (state == null) {
state = new OutboundFlowState(stream, initialWindowSize);
stream.setOutboundFlowState(state);
}
return state;
public StreamState createState(Stream stream, int streamId) {
return new StreamState(
streamId, initialWindowSize, Preconditions.checkNotNull(stream, "stream"));
}
/**
@ -160,15 +150,14 @@ class OutboundFlowController {
*
* <p>Must be called with holding transport lock.
*/
void writeStreams() {
OkHttpClientStream[] streams = transport.getActiveStreams();
public void writeStreams() {
StreamState[] states = transport.getActiveStreams();
int connectionWindow = connectionState.window();
for (int numStreams = streams.length; numStreams > 0 && connectionWindow > 0;) {
for (int numStreams = states.length; numStreams > 0 && connectionWindow > 0;) {
int nextNumStreams = 0;
int windowSlice = (int) ceil(connectionWindow / (float) numStreams);
for (int index = 0; index < numStreams && connectionWindow > 0; ++index) {
OkHttpClientStream stream = streams[index];
OutboundFlowState state = state(stream);
StreamState state = states[index];
int bytesForStream = min(connectionWindow, min(state.unallocatedBytes(), windowSlice));
if (bytesForStream > 0) {
@ -179,7 +168,7 @@ class OutboundFlowController {
if (state.unallocatedBytes() > 0) {
// There is more data to process for this stream. Add it to the next
// pass.
streams[nextNumStreams++] = stream;
states[nextNumStreams++] = state;
}
}
numStreams = nextNumStreams;
@ -187,8 +176,7 @@ class OutboundFlowController {
// Now take one last pass through all of the streams and write any allocated bytes.
WriteStatus writeStatus = new WriteStatus();
for (OkHttpClientStream stream : transport.getActiveStreams()) {
OutboundFlowState state = state(stream);
for (StreamState state : transport.getActiveStreams()) {
state.writeBytes(state.allocatedBytes(), writeStatus);
state.clearAllocatedBytes();
}
@ -213,25 +201,29 @@ class OutboundFlowController {
}
}
public interface Transport {
StreamState[] getActiveStreams();
}
public interface Stream {
void onSentBytes(int frameBytes);
}
/**
* The outbound flow control state for a single stream.
*/
private final class OutboundFlowState {
final Buffer pendingWriteBuffer;
final int streamId;
int window;
int allocatedBytes;
OkHttpClientStream stream;
boolean pendingBufferHasEndOfStream = false;
public final class StreamState {
private final Buffer pendingWriteBuffer = new Buffer();
private Runnable noPendingDataRunnable;
private final int streamId;
private int window;
private int allocatedBytes;
private final Stream stream;
private boolean pendingBufferHasEndOfStream = false;
OutboundFlowState(int streamId, int initialWindowSize) {
StreamState(int streamId, int initialWindowSize, Stream stream) {
this.streamId = streamId;
window = initialWindowSize;
pendingWriteBuffer = new Buffer();
}
OutboundFlowState(OkHttpClientStream stream, int initialWindowSize) {
this(stream.id(), initialWindowSize);
this.stream = stream;
}
@ -305,6 +297,10 @@ class OutboundFlowController {
// Update the threshold.
maxBytes = min(bytes - bytesAttempted, writableWindow());
}
if (!hasPendingData() && noPendingDataRunnable != null) {
noPendingDataRunnable.run();
noPendingDataRunnable = null;
}
return bytesAttempted;
}
@ -328,14 +324,20 @@ class OutboundFlowController {
} catch (IOException e) {
throw new RuntimeException(e);
}
stream.transportState().onSentBytes(frameBytes);
stream.onSentBytes(frameBytes);
bytesToWrite -= frameBytes;
} while (bytesToWrite > 0);
}
void enqueue(Buffer buffer, int size, boolean endOfStream) {
void enqueueData(Buffer buffer, int size, boolean endOfStream) {
this.pendingWriteBuffer.write(buffer, size);
this.pendingBufferHasEndOfStream |= endOfStream;
}
void notifyWhenNoPendingData(Runnable noPendingDataRunnable) {
Preconditions.checkState(
this.noPendingDataRunnable == null, "pending data notification already requested");
this.noPendingDataRunnable = noPendingDataRunnable;
}
}
}
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2022 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import io.grpc.Attributes;
import io.grpc.Grpc;
import io.grpc.SecurityLevel;
import io.grpc.internal.GrpcAttributes;
import java.io.IOException;
import java.net.Socket;
/**
* No-thrills plaintext handshaker.
*/
final class PlaintextHandshakerSocketFactory implements HandshakerSocketFactory {
@Override
public HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException {
attributes = attributes.toBuilder()
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, socket.getLocalSocketAddress())
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, socket.getRemoteSocketAddress())
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE)
.build();
return new HandshakeResult(socket, attributes, null);
}
}

View File

@ -0,0 +1,60 @@
/*
* Copyright 2022 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import com.google.common.base.Preconditions;
import io.grpc.ExperimentalApi;
import io.grpc.okhttp.internal.ConnectionSpec;
import javax.net.ssl.SSLSocketFactory;
/** A credential with full control over the SSLSocketFactory. */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/1785")
public final class SslSocketFactoryServerCredentials {
private SslSocketFactoryServerCredentials() {}
public static io.grpc.ServerCredentials create(SSLSocketFactory factory) {
return new ServerCredentials(factory);
}
public static io.grpc.ServerCredentials create(
SSLSocketFactory factory, com.squareup.okhttp.ConnectionSpec connectionSpec) {
return new ServerCredentials(factory, Utils.convertSpec(connectionSpec));
}
// Hide implementation detail of how these credentials operate
static final class ServerCredentials extends io.grpc.ServerCredentials {
private final SSLSocketFactory factory;
private final ConnectionSpec connectionSpec;
ServerCredentials(SSLSocketFactory factory) {
this(factory, OkHttpChannelBuilder.INTERNAL_DEFAULT_CONNECTION_SPEC);
}
ServerCredentials(SSLSocketFactory factory, ConnectionSpec connectionSpec) {
this.factory = Preconditions.checkNotNull(factory, "factory");
this.connectionSpec = Preconditions.checkNotNull(connectionSpec, "connectionSpec");
}
public SSLSocketFactory getFactory() {
return factory;
}
public ConnectionSpec getConnectionSpec() {
return connectionSpec;
}
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2022 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import io.grpc.Attributes;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.SecurityLevel;
import io.grpc.internal.GrpcAttributes;
import io.grpc.okhttp.internal.ConnectionSpec;
import io.grpc.okhttp.internal.Protocol;
import java.io.IOException;
import java.net.Socket;
import java.util.Arrays;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
/**
* TLS handshaker.
*/
final class TlsServerHandshakerSocketFactory implements HandshakerSocketFactory {
private final PlaintextHandshakerSocketFactory delegate = new PlaintextHandshakerSocketFactory();
private final SSLSocketFactory socketFactory;
private final ConnectionSpec connectionSpec;
public TlsServerHandshakerSocketFactory(
SslSocketFactoryServerCredentials.ServerCredentials credentials) {
this.socketFactory = credentials.getFactory();
this.connectionSpec = credentials.getConnectionSpec();
}
@Override
public HandshakeResult handshake(Socket socket, Attributes attributes) throws IOException {
HandshakeResult result = delegate.handshake(socket, attributes);
socket = socketFactory.createSocket(result.socket, null, -1, true);
if (!(socket instanceof SSLSocket)) {
throw new IOException(
"SocketFactory " + socketFactory + " did not produce an SSLSocket: " + socket.getClass());
}
SSLSocket sslSocket = (SSLSocket) socket;
sslSocket.setUseClientMode(false);
connectionSpec.apply(sslSocket, false);
Protocol expectedProtocol = Protocol.HTTP_2;
String negotiatedProtocol = OkHttpProtocolNegotiator.get().negotiate(
sslSocket,
null,
connectionSpec.supportsTlsExtensions() ? Arrays.asList(expectedProtocol) : null);
if (!expectedProtocol.toString().equals(negotiatedProtocol)) {
throw new IOException("Expected NPN/ALPN " + expectedProtocol + ": " + negotiatedProtocol);
}
attributes = result.attributes.toBuilder()
.set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSocket.getSession())
.build();
return new HandshakeResult(socket, attributes,
new InternalChannelz.Security(new InternalChannelz.Tls(sslSocket.getSession())));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@
package io.grpc.okhttp;
import io.grpc.InsecureServerCredentials;
import io.grpc.ServerStreamTracer;
import io.grpc.internal.AbstractTransportTest;
import io.grpc.internal.ClientTransportFactory;
@ -23,8 +24,6 @@ import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.InternalServer;
import io.grpc.internal.ManagedClientTransport;
import io.grpc.netty.InternalNettyServerBuilder;
import io.grpc.netty.NettyServerBuilder;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.TimeUnit;
@ -53,21 +52,17 @@ public class OkHttpTransportTest extends AbstractTransportTest {
@Override
protected InternalServer newServer(
List<ServerStreamTracer.Factory> streamTracerFactories) {
NettyServerBuilder builder = NettyServerBuilder
.forPort(0)
.flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW);
InternalNettyServerBuilder.setTransportTracerFactory(builder, fakeClockTransportTracer);
return InternalNettyServerBuilder.buildTransportServers(builder, streamTracerFactories);
return newServer(0, streamTracerFactories);
}
@Override
protected InternalServer newServer(
int port, List<ServerStreamTracer.Factory> streamTracerFactories) {
NettyServerBuilder builder = NettyServerBuilder
.forAddress(new InetSocketAddress(port))
.flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW);
InternalNettyServerBuilder.setTransportTracerFactory(builder, fakeClockTransportTracer);
return InternalNettyServerBuilder.buildTransportServers(builder, streamTracerFactories);
return OkHttpServerBuilder
.forPort(port, InsecureServerCredentials.create())
.flowControlWindow(AbstractTransportTest.TEST_FLOW_CONTROL_WINDOW)
.setTransportTracerFactory(fakeClockTransportTracer)
.buildTransportServers(streamTracerFactories);
}
@Override
@ -100,11 +95,4 @@ public class OkHttpTransportTest extends AbstractTransportTest {
protected boolean haveTransportTracer() {
return true;
}
@Override
@org.junit.Test
@org.junit.Ignore
public void clientChecksInboundMetadataSize_trailer() {
// Server-side is flaky due to https://github.com/netty/netty/pull/8332
}
}

View File

@ -231,6 +231,7 @@ public final class Http2 implements Variant {
short padding = (flags & FLAG_PADDED) != 0 ? (short) (source.readByte() & 0xff) : 0;
length = lengthWithoutPadding(length, flags, padding);
// FIXME: pass padding length to handler because it should be included for flow control
handler.data(inFinished, streamId, source, length);
source.skip(padding);
}

View File

@ -46,7 +46,7 @@ public final class Settings {
/** spdy/3: Sender's estimate of max outgoing kbps. */
static final int DOWNLOAD_BANDWIDTH = 2;
/** HTTP/2: The peer must not send a PUSH_PROMISE frame when this is 0. */
static final int ENABLE_PUSH = 2;
public static final int ENABLE_PUSH = 2;
/** spdy/3: Sender's estimate of millis between sending a request and receiving a response. */
static final int ROUND_TRIP_TIME = 3;
/** Sender's maximum number of concurrent streams. */
@ -58,7 +58,7 @@ public final class Settings {
/** spdy/3: Retransmission rate. Percentage */
static final int DOWNLOAD_RETRANS_RATE = 6;
/** HTTP/2: Advisory only. Size in bytes of the largest header list the sender will accept. */
static final int MAX_HEADER_LIST_SIZE = 6;
public static final int MAX_HEADER_LIST_SIZE = 6;
/** Window size in bytes. */
public static final int INITIAL_WINDOW_SIZE = 7;
/** spdy/3: Size of the client certificate vector. Unsupported. */