netty: Netty server poorly handles unknown content type

This commit is contained in:
ramaraochavali 2017-11-16 01:02:50 +05:30 committed by Eric Anderson
parent 66f9ef5d69
commit df357cb8d3
2 changed files with 131 additions and 39 deletions

View File

@ -32,6 +32,8 @@ import static io.netty.handler.codec.http2.DefaultHttp2LocalFlowController.DEFAU
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.InternalMetadata;
import io.grpc.InternalStatus;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.Status; import io.grpc.Status;
@ -54,6 +56,7 @@ import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader; import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
import io.netty.handler.codec.http2.DefaultHttp2LocalFlowController; import io.netty.handler.codec.http2.DefaultHttp2LocalFlowController;
import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController; import io.netty.handler.codec.http2.DefaultHttp2RemoteFlowController;
import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2Connection;
@ -375,15 +378,49 @@ class NettyServerHandler extends AbstractNettyHandler {
throws Http2Exception { throws Http2Exception {
if (!teWarningLogged && !TE_TRAILERS.equals(headers.get(TE_HEADER))) { if (!teWarningLogged && !TE_TRAILERS.equals(headers.get(TE_HEADER))) {
logger.warning(String.format("Expected header TE: %s, but %s is received. This means " logger.warning(String.format("Expected header TE: %s, but %s is received. This means "
+ "some intermediate proxy may not support trailers", + "some intermediate proxy may not support trailers",
TE_TRAILERS, headers.get(TE_HEADER))); TE_TRAILERS, headers.get(TE_HEADER)));
teWarningLogged = true; teWarningLogged = true;
} }
try { try {
// Remove the leading slash of the path and get the fully qualified method name
CharSequence path = headers.path();
if (path == null) {
respondWithHttpError(ctx, streamId, 404, Status.Code.UNIMPLEMENTED,
"Expected path but is missing");
return;
}
if (path.charAt(0) != '/') {
respondWithHttpError(ctx, streamId, 404, Status.Code.UNIMPLEMENTED,
String.format("Expected path to start with /: %s", path));
return;
}
String method = path.subSequence(1, path.length()).toString();
// Verify that the Content-Type is correct in the request. // Verify that the Content-Type is correct in the request.
verifyContentType(streamId, headers); CharSequence contentType = headers.get(CONTENT_TYPE_HEADER);
String method = determineMethod(streamId, headers); if (contentType == null) {
respondWithHttpError(
ctx, streamId, 415, Status.Code.INTERNAL, "Content-Type is missing from the request");
return;
}
String contentTypeString = contentType.toString();
if (!GrpcUtil.isGrpcContentType(contentTypeString)) {
respondWithHttpError(ctx, streamId, 415, Status.Code.INTERNAL,
String.format("Content-Type '%s' is not supported", contentTypeString));
return;
}
if (!HTTP_METHOD.equals(headers.method())) {
respondWithHttpError(ctx, streamId, 405, Status.Code.INTERNAL,
String.format("Method '%s' is not supported", headers.method()));
return;
}
// The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this // The Http2Stream object was put by AbstractHttp2ConnectionHandler before calling this
// method. // method.
@ -400,7 +437,7 @@ class NettyServerHandler extends AbstractNettyHandler {
maxMessageSize, maxMessageSize,
statsTraceCtx, statsTraceCtx,
transportTracer); transportTracer);
String authority = getOrUpdateAuthority((AsciiString)headers.authority()); String authority = getOrUpdateAuthority((AsciiString) headers.authority());
NettyServerStream stream = new NettyServerStream( NettyServerStream stream = new NettyServerStream(
ctx.channel(), ctx.channel(),
state, state,
@ -411,10 +448,7 @@ class NettyServerHandler extends AbstractNettyHandler {
transportListener.streamCreated(stream, method, metadata); transportListener.streamCreated(stream, method, metadata);
state.onStreamAllocated(); state.onStreamAllocated();
http2Stream.setProperty(streamKey, state); http2Stream.setProperty(streamKey, state);
} catch (Exception e) {
} catch (Http2Exception e) {
throw e;
} catch (Throwable e) {
logger.log(Level.WARNING, "Exception in onHeadersRead()", e); logger.log(Level.WARNING, "Exception in onHeadersRead()", e);
// Throw an exception that will get handled by onStreamError. // Throw an exception that will get handled by onStreamError.
throw newStreamException(streamId, e); throw newStreamException(streamId, e);
@ -634,17 +668,22 @@ class NettyServerHandler extends AbstractNettyHandler {
}); });
} }
private void verifyContentType(int streamId, Http2Headers headers) throws Http2Exception { private void respondWithHttpError(
CharSequence contentType = headers.get(CONTENT_TYPE_HEADER); ChannelHandlerContext ctx, int streamId, int code, Status.Code statusCode, String msg) {
if (contentType == null) { Metadata metadata = new Metadata();
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM, metadata.put(InternalStatus.CODE_KEY, statusCode.toStatus());
"Content-Type is missing from the request"); metadata.put(InternalStatus.MESSAGE_KEY, msg);
} byte[][] serialized = InternalMetadata.serialize(metadata);
String contentTypeString = contentType.toString();
if (!GrpcUtil.isGrpcContentType(contentTypeString)) { Http2Headers headers = new DefaultHttp2Headers(true, serialized.length / 2)
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM, .status("" + code)
"Content-Type '%s' is not supported", contentTypeString); .set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");
for (int i = 0; i < serialized.length; i += 2) {
headers.add(new AsciiString(serialized[i], false), new AsciiString(serialized[i + 1], false));
} }
encoder().writeHeaders(ctx, streamId, headers, 0, false, ctx.newPromise());
ByteBuf msgBuf = ByteBufUtil.writeUtf8(ctx.alloc(), msg);
encoder().writeData(ctx, streamId, msgBuf, 0, true, ctx.newPromise());
} }
private Http2Stream requireHttp2Stream(int streamId) { private Http2Stream requireHttp2Stream(int streamId) {
@ -656,20 +695,6 @@ class NettyServerHandler extends AbstractNettyHandler {
return stream; return stream;
} }
private String determineMethod(int streamId, Http2Headers headers) throws Http2Exception {
if (!HTTP_METHOD.equals(headers.method())) {
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM,
"Method '%s' is not supported", headers.method());
}
// Remove the leading slash of the path and get the fully qualified method name
CharSequence path = headers.path();
if (path.charAt(0) != '/') {
throw Http2Exception.streamError(streamId, Http2Error.REFUSED_STREAM,
"Malformatted path: %s", path);
}
return path.subSequence(1, path.length()).toString();
}
/** /**
* Returns the server stream associated to the given HTTP/2 stream object. * Returns the server stream associated to the given HTTP/2 stream object.
*/ */

View File

@ -29,6 +29,7 @@ import static io.grpc.netty.Utils.HTTP_METHOD;
import static io.grpc.netty.Utils.TE_HEADER; import static io.grpc.netty.Utils.TE_HEADER;
import static io.grpc.netty.Utils.TE_TRAILERS; import static io.grpc.netty.Utils.TE_TRAILERS;
import static io.netty.buffer.Unpooled.directBuffer; import static io.netty.buffer.Unpooled.directBuffer;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -53,6 +54,7 @@ import static org.mockito.Mockito.when;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.truth.Truth; import com.google.common.truth.Truth;
import io.grpc.Attributes; import io.grpc.Attributes;
import io.grpc.InternalStatus;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.Status; import io.grpc.Status;
@ -110,6 +112,9 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
private static final int STREAM_ID = 3; private static final int STREAM_ID = 3;
private static final AsciiString HTTP_FAKE_METHOD = AsciiString.of("FAKE");
@Mock @Mock
private ServerStreamListener streamListener; private ServerStreamListener streamListener;
@ -406,14 +411,76 @@ public class NettyServerHandlerTest extends NettyHandlerTestBase<NettyServerHand
public void headersWithInvalidContentTypeShouldFail() throws Exception { public void headersWithInvalidContentTypeShouldFail() throws Exception {
manualSetUp(); manualSetUp();
Http2Headers headers = new DefaultHttp2Headers() Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD) .method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, new AsciiString("application/bad", UTF_8)) .set(CONTENT_TYPE_HEADER, new AsciiString("application/bad", UTF_8))
.set(TE_HEADER, TE_TRAILERS) .set(TE_HEADER, TE_TRAILERS)
.path(new AsciiString("/foo/bar")); .path(new AsciiString("/foo/bar"));
ByteBuf headersFrame = headersFrame(STREAM_ID, headers); ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame); channelRead(headersFrame);
verifyWrite().writeRstStream(eq(ctx()), eq(STREAM_ID), eq(Http2Error.REFUSED_STREAM.code()), Http2Headers responseHeaders = new DefaultHttp2Headers()
any(ChannelPromise.class)); .set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Content-Type 'application/bad' is not supported")
.status("" + 415)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");
verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
}
@Test
public void headersWithInvalidMethodShouldFail() throws Exception {
manualSetUp();
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_FAKE_METHOD)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.path(new AsciiString("/foo/bar"));
ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame);
Http2Headers responseHeaders = new DefaultHttp2Headers()
.set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.INTERNAL.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Method 'FAKE' is not supported")
.status("" + 405)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");
verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
}
@Test
public void headersWithMissingPathShouldFail() throws Exception {
manualSetUp();
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC);
ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame);
Http2Headers responseHeaders = new DefaultHttp2Headers()
.set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.UNIMPLEMENTED.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Expected path but is missing")
.status("" + 404)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");
verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
}
@Test
public void headersWithInvalidPathShouldFail() throws Exception {
manualSetUp();
Http2Headers headers = new DefaultHttp2Headers()
.method(HTTP_METHOD)
.set(CONTENT_TYPE_HEADER, CONTENT_TYPE_GRPC)
.path(new AsciiString("foo/bar"));
ByteBuf headersFrame = headersFrame(STREAM_ID, headers);
channelRead(headersFrame);
Http2Headers responseHeaders = new DefaultHttp2Headers()
.set(InternalStatus.CODE_KEY.name(), String.valueOf(Code.UNIMPLEMENTED.value()))
.set(InternalStatus.MESSAGE_KEY.name(), "Expected path to start with /: foo/bar")
.status("" + 404)
.set(CONTENT_TYPE_HEADER, "text/plain; encoding=utf-8");
verifyWrite().writeHeaders(eq(ctx()), eq(STREAM_ID), eq(responseHeaders), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), any(ChannelPromise.class));
} }
@Test @Test