okhttp: okhttp client and server transport should use padded length for flow control (#10422)

This commit is contained in:
yifeizhuang 2023-08-16 15:12:19 -07:00 committed by GitHub
parent 93118f4075
commit 5f34c600c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 288 additions and 58 deletions

View File

@ -321,11 +321,12 @@ class OkHttpClientStream extends AbstractClientStream {
* Must be called with holding the transport lock.
*/
@GuardedBy("lock")
public void transportDataReceived(okio.Buffer frame, boolean endOfStream) {
public void transportDataReceived(okio.Buffer frame, boolean endOfStream, int paddingLen) {
// We only support 16 KiB frames, and the max permitted in HTTP/2 is 16 MiB. This is verified
// in OkHttp's Http2 deframer. In addition, this code is after the data has been read.
int length = (int) frame.size();
window -= length;
window -= length + paddingLen;
processedWindow -= paddingLen;
if (window < 0) {
frameWriter.rstStream(id(), ErrorCode.FLOW_CONTROL_ERROR);
transport.finishStream(

View File

@ -1140,7 +1140,8 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
*/
@SuppressWarnings("GuardedBy")
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
public void data(boolean inFinished, int streamId, BufferedSource in, int length,
int paddedLength)
throws IOException {
logger.logData(OkHttpFrameLogger.Direction.INBOUND,
streamId, in.getBuffer(), length, inFinished);
@ -1166,12 +1167,12 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
synchronized (lock) {
// TODO(b/145386688): This access should be guarded by 'stream.transportState().lock';
// instead found: 'OkHttpClientTransport.this.lock'
stream.transportState().transportDataReceived(buf, inFinished);
stream.transportState().transportDataReceived(buf, inFinished, paddedLength - length);
}
}
// connection window update
connectionUnacknowledgedBytesRead += length;
connectionUnacknowledgedBytesRead += paddedLength;
if (connectionUnacknowledgedBytesRead >= initialWindowSize * DEFAULT_WINDOW_UPDATE_RATIO) {
synchronized (lock) {
frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead);

View File

@ -208,13 +208,15 @@ class OkHttpServerStream extends AbstractServerStream {
* Must be called with holding the transport lock.
*/
@Override
public void inboundDataReceived(okio.Buffer frame, int windowConsumed, boolean endOfStream) {
public void inboundDataReceived(okio.Buffer frame, int dataLength, int paddingLength,
boolean endOfStream) {
synchronized (lock) {
PerfMark.event("OkHttpServerTransport$FrameHandler.data", tag);
if (endOfStream) {
this.receivedEndOfStream = true;
}
window -= windowConsumed;
window -= dataLength + paddingLength;
processedWindow -= paddingLength;
super.inboundDataReceived(new OkHttpReadableBuffer(frame), endOfStream);
}
}

View File

@ -248,8 +248,8 @@ final class OkHttpServerTransport implements ServerTransport,
TimeUnit.NANOSECONDS);
}
transportExecutor.execute(
new FrameHandler(variant.newReader(Okio.buffer(Okio.source(socket)), false)));
transportExecutor.execute(new FrameHandler(
variant.newReader(Okio.buffer(Okio.source(socket)), false)));
} catch (Error | IOException | RuntimeException ex) {
synchronized (lock) {
if (!handshakeShutdown) {
@ -708,7 +708,7 @@ final class OkHttpServerTransport implements ServerTransport,
return;
}
// Ignore the trailers, but still half-close the stream
stream.inboundDataReceived(new Buffer(), 0, true);
stream.inboundDataReceived(new Buffer(), 0, 0, true);
return;
}
} else {
@ -799,7 +799,7 @@ final class OkHttpServerTransport implements ServerTransport,
listener.streamCreated(streamForApp, method, metadata);
stream.onStreamAllocated();
if (inFinished) {
stream.inboundDataReceived(new Buffer(), 0, inFinished);
stream.inboundDataReceived(new Buffer(), 0, 0, inFinished);
}
}
}
@ -819,7 +819,8 @@ final class OkHttpServerTransport implements ServerTransport,
* Handle an HTTP2 DATA frame.
*/
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
public void data(boolean inFinished, int streamId, BufferedSource in, int length,
int paddedLength)
throws IOException {
frameLogger.logData(
OkHttpFrameLogger.Direction.INBOUND, streamId, in.getBuffer(), length, inFinished);
@ -853,7 +854,7 @@ final class OkHttpServerTransport implements ServerTransport,
"Received DATA for half-closed (remote) stream. RFC7540 section 5.1");
return;
}
if (stream.inboundWindowAvailable() < length) {
if (stream.inboundWindowAvailable() < paddedLength) {
in.skip(length);
streamError(streamId, ErrorCode.FLOW_CONTROL_ERROR,
"Received DATA size exceeded window size. RFC7540 section 6.9");
@ -861,11 +862,11 @@ final class OkHttpServerTransport implements ServerTransport,
}
Buffer buf = new Buffer();
buf.write(in.getBuffer(), length);
stream.inboundDataReceived(buf, length, inFinished);
stream.inboundDataReceived(buf, length, paddedLength - length, inFinished);
}
// connection window update
connectionUnacknowledgedBytesRead += length;
connectionUnacknowledgedBytesRead += paddedLength;
if (connectionUnacknowledgedBytesRead
>= config.flowControlWindow * Utils.DEFAULT_WINDOW_UPDATE_RATIO) {
synchronized (lock) {
@ -1064,7 +1065,7 @@ final class OkHttpServerTransport implements ServerTransport,
}
streams.put(streamId, stream);
if (inFinished) {
stream.inboundDataReceived(new Buffer(), 0, true);
stream.inboundDataReceived(new Buffer(), 0, 0, true);
}
frameWriter.headers(streamId, headers);
outboundFlow.data(
@ -1122,7 +1123,7 @@ final class OkHttpServerTransport implements ServerTransport,
interface StreamState {
/** Must be holding 'lock' when calling. */
void inboundDataReceived(Buffer frame, int windowConsumed, boolean endOfStream);
void inboundDataReceived(Buffer frame, int dataLength, int paddingLength, boolean endOfStream);
/** Must be holding 'lock' when calling. */
boolean hasReceivedEndOfStream();
@ -1159,12 +1160,12 @@ final class OkHttpServerTransport implements ServerTransport,
@Override public void onSentBytes(int frameBytes) {}
@Override public void inboundDataReceived(
Buffer frame, int windowConsumed, boolean endOfStream) {
Buffer frame, int dataLength, int paddingLength, boolean endOfStream) {
synchronized (lock) {
if (endOfStream) {
receivedEndOfStream = true;
}
window -= windowConsumed;
window -= dataLength + paddingLength;
try {
frame.skip(frame.size()); // Recycle segments
} catch (IOException ex) {

View File

@ -291,14 +291,16 @@ public class OkHttpClientTransportTest {
final String message = "Hello Client";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
assertThat(logs).hasSize(1);
log = logs.remove(0);
assertThat(log.getMessage()).startsWith(Direction.INBOUND + " DATA: streamId=" + 3);
assertThat(log.getLevel()).isEqualTo(Level.FINE);
// At most 64 bytes of data frame will be logged.
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])),
1000, 1000);
assertThat(logs).hasSize(1);
log = logs.remove(0);
String data = log.getMessage();
@ -377,7 +379,8 @@ public class OkHttpClientTransportTest {
// Receive the message.
final String message = "Hello Client";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
listener.waitUntilStreamClosed();
assertEquals(Code.RESOURCE_EXHAUSTED, listener.status.getCode());
@ -500,7 +503,8 @@ public class OkHttpClientTransportTest {
assertNotNull(listener.headers);
for (int i = 0; i < numMessages; i++) {
Buffer buffer = createMessageFrame(message + i);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
}
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
listener.waitUntilStreamClosed();
@ -529,7 +533,8 @@ public class OkHttpClientTransportTest {
@Test
public void receivedDataForInvalidStreamShouldKillConnection() throws Exception {
initTransport();
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])),
1000, 1000);
verify(frameWriter, timeout(TIME_OUT_MS))
.goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class));
verify(transportListener).transportShutdown(isA(Status.class));
@ -551,7 +556,8 @@ public class OkHttpClientTransportTest {
HeadersMode.HTTP_20_HEADERS);
// Now wait to receive 1000 bytes of data so we can have a better error message before
// cancelling the streaam.
frameHandler().data(false, 3, createMessageFrame(new String(new char[1000])), 1000);
frameHandler().data(false, 3,
createMessageFrame(new String(new char[1000])), 1000, 1000);
verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL));
assertNull(listener.headers);
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
@ -622,7 +628,8 @@ public class OkHttpClientTransportTest {
assertContainStream(3);
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
Buffer buffer = createMessageFrame("a message");
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
frameHandler().rstStream(3, ErrorCode.NO_ERROR);
stream.request(1);
@ -762,15 +769,18 @@ public class OkHttpClientTransportTest {
int messageLength = INITIAL_WINDOW_SIZE / 4;
byte[] fakeMessage = new byte[messageLength];
int paddingLength = 2;
// Stream 1 receives a message
Buffer buffer = createMessageFrame(fakeMessage);
Buffer buffer = createMessageFrame(fakeMessage, paddingLength);
int messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength);
frameHandler().data(false, 3, buffer, messageFrameLength - paddingLength,
messageFrameLength);
// Stream 2 receives a message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 5, buffer, messageFrameLength);
buffer = createMessageFrame(fakeMessage, paddingLength);
frameHandler().data(false, 5, buffer, messageFrameLength - paddingLength,
messageFrameLength);
verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
@ -778,17 +788,18 @@ public class OkHttpClientTransportTest {
// Stream 1 receives another message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 3, buffer, messageFrameLength);
messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength);
verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(3), eq((long) 2 * messageFrameLength));
.windowUpdate(eq(3), eq((long) 2 * messageFrameLength + paddingLength));
// Stream 2 receives another message
buffer = createMessageFrame(fakeMessage);
frameHandler().data(false, 5, buffer, messageFrameLength);
frameHandler().data(false, 5, buffer, messageFrameLength, messageFrameLength);
verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(5), eq((long) 2 * messageFrameLength));
.windowUpdate(eq(5), eq((long) 2 * messageFrameLength + paddingLength));
verify(frameWriter, timeout(TIME_OUT_MS))
.windowUpdate(eq(0), eq((long) 2 * messageFrameLength));
@ -819,7 +830,8 @@ public class OkHttpClientTransportTest {
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
Buffer buffer = createMessageFrame(fakeMessage);
long messageFrameLength = buffer.size();
frameHandler().data(false, 3, buffer, (int) messageFrameLength);
frameHandler().data(false, 3, buffer, (int) messageFrameLength,
(int) messageFrameLength);
ArgumentCaptor<Integer> idCaptor = ArgumentCaptor.forClass(Integer.class);
verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(
idCaptor.capture(), eq(messageFrameLength));
@ -1123,7 +1135,8 @@ public class OkHttpClientTransportTest {
frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS);
final String receivedMessage = "No, you are fine.";
Buffer buffer = createMessageFrame(receivedMessage);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS);
listener1.waitUntilStreamClosed();
assertEquals(1, listener1.messages.size());
@ -1154,12 +1167,12 @@ public class OkHttpClientTransportTest {
assertNotNull(listener.headers);
String message = "hello";
Buffer buffer = createMessageFrame(message);
frameHandler().data(false, startId, buffer, (int) buffer.size());
frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size());
getStream(startId).cancel(Status.CANCELLED);
// Receives the second message after be cancelled.
buffer = createMessageFrame(message);
frameHandler().data(false, startId, buffer, (int) buffer.size());
frameHandler().data(false, startId, buffer, (int) buffer.size(), (int) buffer.size());
listener.waitUntilStreamClosed();
// Should only have the first message delivered.
@ -1329,7 +1342,7 @@ public class OkHttpClientTransportTest {
byte[] fakeMessage = new byte[messageLength];
Buffer buffer = createMessageFrame(fakeMessage);
int messageFrameLength = (int) buffer.size();
frameHandler().data(false, 3, buffer, messageFrameLength);
frameHandler().data(false, 3, buffer, messageFrameLength, messageFrameLength);
listener.waitUntilStreamClosed();
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
@ -1392,7 +1405,8 @@ public class OkHttpClientTransportTest {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
// Trigger the failure by a trailer.
frameHandler().headers(
@ -1414,11 +1428,13 @@ public class OkHttpClientTransportTest {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
// Trigger the failure by a data frame.
buffer = createMessageFrame(new byte[1]);
frameHandler().data(true, 3, buffer, (int) buffer.size());
frameHandler().data(true, 3, buffer, (int) buffer.size(),
(int) buffer.size());
listener.waitUntilStreamClosed();
assertEquals(Status.INTERNAL.getCode(), listener.status.getCode());
@ -1436,7 +1452,8 @@ public class OkHttpClientTransportTest {
stream.start(listener);
stream.request(1);
Buffer buffer = createMessageFrame(new byte[1000]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
// Once we receive enough detail, we cancel the stream. so we should have sent cancel.
verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL));
@ -1459,7 +1476,8 @@ public class OkHttpClientTransportTest {
Buffer buffer = createMessageFrame(
new byte[INITIAL_WINDOW_SIZE / 2 + 1]);
frameHandler().data(false, 3, buffer, (int) buffer.size());
frameHandler().data(false, 3, buffer, (int) buffer.size(),
(int) buffer.size());
// Should still update the connection window even stream 3 is gone.
verify(frameWriter, timeout(TIME_OUT_MS)).windowUpdate(0,
HEADER_LENGTH + INITIAL_WINDOW_SIZE / 2 + 1);
@ -1467,7 +1485,8 @@ public class OkHttpClientTransportTest {
new byte[INITIAL_WINDOW_SIZE / 2 + 1]);
// This should kill the connection, since we never created stream 5.
frameHandler().data(false, 5, buffer, (int) buffer.size());
frameHandler().data(false, 5, buffer, (int) buffer.size(),
(int) buffer.size());
verify(frameWriter, timeout(TIME_OUT_MS))
.goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class));
verify(transportListener).transportShutdown(isA(Status.class));
@ -2114,10 +2133,15 @@ public class OkHttpClientTransportTest {
}
private static Buffer createMessageFrame(byte[] message) {
return createMessageFrame(message,0);
}
private static Buffer createMessageFrame(byte[] message, int paddingLength) {
Buffer buffer = new Buffer();
buffer.writeByte(0 /* UNCOMPRESSED */);
buffer.writeInt(message.length);
buffer.write(message);
buffer.write(new byte[paddingLength]);
return buffer;
}

View File

@ -60,6 +60,7 @@ import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
@ -70,6 +71,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.ByteString;
import okio.Okio;
@ -90,15 +92,19 @@ public class OkHttpServerTransportTest {
private static final int TIME_OUT_MS = 2000;
private static final int INITIAL_WINDOW_SIZE = 65535;
private static final long MAX_CONNECTION_IDLE = TimeUnit.SECONDS.toNanos(1);
private static final byte FLAG_NONE = 0x0;
private static final byte FLAG_PADDED = 0x8;
private static final byte FLAG_END_STREAM = 0x1;
private static final byte TYPE_DATA = 0x0;
private MockServerTransportListener mockTransportListener = new MockServerTransportListener();
private ServerTransportListener transportListener
= mock(ServerTransportListener.class, delegatesTo(mockTransportListener));
private OkHttpServerTransport serverTransport;
private final ExecutorService threadPool = Executors.newCachedThreadPool();
private final SocketPair socketPair = SocketPair.create(threadPool);
private final FrameWriter clientFrameWriter
= new Http2().newWriter(Okio.buffer(Okio.sink(socketPair.getClientOutputStream())), true);
private final BufferedSink clientWriterSink = Okio.buffer(
Okio.sink(socketPair.getClientOutputStream()));
private final FrameWriter clientFrameWriter = new Http2().newWriter(clientWriterSink, true);
private final FrameReader clientFrameReader
= new Http2().newReader(Okio.buffer(Okio.source(socketPair.getClientInputStream())), true);
private final FrameReader.Handler clientFramesRead = mock(FrameReader.Handler.class);
@ -135,7 +141,8 @@ public class OkHttpServerTransportTest {
Buffer buf = new Buffer();
buf.write(in.getBuffer(), length);
clientDataFrames.data(outDone, streamId, buf);
})).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt());
})).when(clientFramesRead).data(anyBoolean(), anyInt(), any(BufferedSource.class), anyInt(),
anyInt());
}
@After
@ -379,7 +386,8 @@ public class OkHttpServerTransportTest {
Buffer responseMessageFrame = createMessageFrame("Howdy client");
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead)
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()));
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()),
eq((int) responseMessageFrame.size()));
verify(clientDataFrames).data(false, 1, responseMessageFrame);
List<Header> responseTrailers = Arrays.asList(
@ -440,7 +448,8 @@ public class OkHttpServerTransportTest {
Buffer responseMessageFrame = createMessageFrame("Howdy client");
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead)
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()));
.data(eq(false), eq(1), any(BufferedSource.class), eq((int) responseMessageFrame.size()),
eq((int) responseMessageFrame.size()));
verify(clientDataFrames).data(false, 1, responseMessageFrame);
pingPong();
assertThat(serverTransport.getActiveStreams().length).isEqualTo(1);
@ -975,7 +984,8 @@ public class OkHttpServerTransportTest {
Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(0, 1));
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(false), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()),
eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(false, 1, responseDataFrame);
clientFrameWriter.windowUpdate(1, 1000);
@ -984,7 +994,8 @@ public class OkHttpServerTransportTest {
responseDataFrame = new Buffer().writeUtf8(errorDescription.substring(1));
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(true), eq(1), any(BufferedSource.class), eq((int) responseDataFrame.size()),
eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(true, 1, responseDataFrame);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
@ -993,6 +1004,71 @@ public class OkHttpServerTransportTest {
shutdownAndTerminate(/*lastStreamId=*/ 1);
}
@Test
public void windowUpdate() throws Exception {
serverBuilder.flowControlWindow(100);
initTransport();
handshake();
List<Header> headers = Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything"));
clientFrameWriter.headers(1, new ArrayList<>(headers));
clientFrameWriter.headers(3, new ArrayList<>(headers));
String message = "Hello Server Pad Me!"; // length = 20, add buffer length = 5
writeDataDirectly(clientWriterSink, FLAG_NONE, 1, message, 0);
pingPong();
MockStreamListener streamListener = mockTransportListener.newStreams.pop();
MockStreamListener streamListener2 = mockTransportListener.newStreams.pop();
assertThat(streamListener.stream.getAuthority()).isEqualTo("example.com:80");
assertThat(streamListener.method).isEqualTo("com.example/SimpleService.doit");
assertThat(streamListener.headers.get(
Metadata.Key.of("Some-Metadata", Metadata.ASCII_STRING_MARSHALLER)))
.isEqualTo("this could be anything");
streamListener.stream.request(1);
pingPong();
assertThat(streamListener.messages.pop()).isEqualTo("Hello Server Pad Me!");
streamListener.stream.writeHeaders(metadata("User-Data", "best data"));
streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
List<Header> responseHeaders = Arrays.asList(
new Header(":status", "200"),
CONTENT_TYPE_HEADER,
new Header("user-data", "best data"));
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead)
.headers(false, false, 1, -1, responseHeaders, HeadersMode.HTTP_20_HEADERS);
writeDataDirectly(clientWriterSink, FLAG_PADDED, 1, message, 10);
writeDataDirectly(clientWriterSink, FLAG_PADDED | FLAG_END_STREAM, 3, message, 40);
clientFrameWriter.flush();
int expectedConsumed = message.length() + 5;
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).windowUpdate(0, expectedConsumed * 2 + 10);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).windowUpdate(0, expectedConsumed + 40);
streamListener.stream.request(2);
streamListener2.stream.request(1);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).windowUpdate(1, expectedConsumed * 2 + 10);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).windowUpdate(3, expectedConsumed + 40);
writeDataDirectly(clientWriterSink, FLAG_PADDED | FLAG_END_STREAM, 1, message, 100);
clientFrameWriter.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).rstStream(eq(1), eq(ErrorCode.FLOW_CONTROL_ERROR));
clientFrameWriter.rstStream(3, ErrorCode.CANCEL);
pingPong();
shutdownAndTerminate(/*lastStreamId=*/ 3);
}
@Test
public void dataForStream0_failsWithGoAway() throws Exception {
initTransport();
@ -1223,6 +1299,32 @@ public class OkHttpServerTransportTest {
return buffer;
}
private void writeDataDirectly(BufferedSink sink, int flag, int streamId, String message,
int paddingLength) throws IOException {
Buffer buffer = createMessageFrame(message);
int bufferLengthWithPadding = (int) buffer.size();
if ((flag & FLAG_PADDED) != 0) {
bufferLengthWithPadding += paddingLength;
}
writeLength(sink, bufferLengthWithPadding);
sink.writeByte(TYPE_DATA);
sink.writeByte(flag & 0xff);
sink.writeInt(streamId & 0x7fffffff);
if ((flag & FLAG_PADDED) != 0) {
sink.writeByte((short)(paddingLength - 1));
char[] value = new char[paddingLength - 1];
Arrays.fill(value, '!');
buffer.write(new String(value).getBytes(UTF_8));
}
sink.write(buffer, buffer.size());
}
private void writeLength(BufferedSink sink, int length) throws IOException {
sink.writeByte((length >>> 16 ) & 0xff);
sink.writeByte((length >>> 8 ) & 0xff);
sink.writeByte(length & 0xff);
}
private Metadata metadata(String... keysAndValues) {
Metadata metadata = new Metadata();
assertThat(keysAndValues.length % 2).isEqualTo(0);
@ -1279,7 +1381,8 @@ public class OkHttpServerTransportTest {
Buffer responseDataFrame = new Buffer().writeUtf8(errorDescription);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).data(
eq(true), eq(streamId), any(BufferedSource.class), eq((int) responseDataFrame.size()));
eq(true), eq(streamId), any(BufferedSource.class),
eq((int) responseDataFrame.size()), eq((int) responseDataFrame.size()));
verify(clientDataFrames).data(true, streamId, responseDataFrame);
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();

View File

@ -32,7 +32,7 @@ public interface FrameReader extends Closeable {
boolean nextFrame(Handler handler) throws IOException;
interface Handler {
void data(boolean inFinished, int streamId, BufferedSource source, int length)
void data(boolean inFinished, int streamId, BufferedSource source, int length, int paddedLength)
throws IOException;
/**

View File

@ -220,7 +220,7 @@ public final class Http2 implements Variant {
return hpackReader.getAndResetHeaderList();
}
private void readData(Handler handler, int length, byte flags, int streamId)
private void readData(Handler handler, int paddedLength, byte flags, int streamId)
throws IOException {
// TODO: checkState open or half-closed (local) or raise STREAM_CLOSED
boolean inFinished = (flags & FLAG_END_STREAM) != 0;
@ -230,10 +230,10 @@ public final class Http2 implements Variant {
}
short padding = (flags & FLAG_PADDED) != 0 ? (short) (source.readByte() & 0xff) : 0;
length = lengthWithoutPadding(length, flags, padding);
int length = lengthWithoutPadding(paddedLength, flags, padding);
// FIXME: pass padding length to handler because it should be included for flow control
handler.data(inFinished, streamId, source, length);
handler.data(inFinished, streamId, source, length, paddedLength);
source.skip(padding);
}

View File

@ -0,0 +1,98 @@
/*
* Copyright (C) 2023 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp.internal.framed;
import static io.grpc.okhttp.internal.framed.Http2.FLAG_NONE;
import static io.grpc.okhttp.internal.framed.Http2.FLAG_PADDED;
import static io.grpc.okhttp.internal.framed.Http2.TYPE_DATA;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import okio.Buffer;
import okio.BufferedSink;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@RunWith(JUnit4.class)
public class Http2Test {
@Rule
public final MockitoRule mocks = MockitoJUnit.rule();
private FrameReader http2FrameReader;
@Mock
private FrameReader.Handler mockHandler;
private final int STREAM_ID = 6;
@Test
public void dataFrameNoPadding() throws IOException {
Buffer bufferIn = createData(FLAG_NONE, 3239, 0 );
http2FrameReader = new Http2.Reader(bufferIn, 100, true);
http2FrameReader.nextFrame(mockHandler);
verify(mockHandler).data(eq(false), eq(STREAM_ID), eq(bufferIn), eq(3239), eq(3239));
assertEquals(3239, bufferIn.size());
}
@Test
public void dataFrameOneLengthPadding() throws IOException {
Buffer bufferIn = createData(FLAG_PADDED, 1876, 0);
http2FrameReader = new Http2.Reader(bufferIn, 100, true);
http2FrameReader.nextFrame(mockHandler);
verify(mockHandler).data(eq(false), eq(STREAM_ID), eq(bufferIn), eq(1875), eq(1876));
assertEquals(1876, bufferIn.size());
}
@Test
public void dataFramePadding() throws IOException {
Buffer bufferIn = createData(FLAG_PADDED, 2037, 125);
http2FrameReader = new Http2.Reader(bufferIn, 100, true);
http2FrameReader.nextFrame(mockHandler);
verify(mockHandler).data(eq(false), eq(STREAM_ID), eq(bufferIn), eq(2037 - 126), eq(2037));
assertEquals(2037 - 125, bufferIn.size());
}
private Buffer createData(int flag, int length, int paddingLength) throws IOException {
Buffer sink = new Buffer();
writeLength(sink, length);
sink.writeByte(TYPE_DATA);
sink.writeByte(flag);
sink.writeInt(STREAM_ID);
if ((flag & FLAG_PADDED) != 0) {
sink.writeByte((short)paddingLength);
}
char[] value = new char[length];
Arrays.fill(value, '!');
sink.write(new String(value).getBytes(StandardCharsets.UTF_8));
return sink;
}
private void writeLength(BufferedSink sink, int length) throws IOException {
sink.writeByte((length >>> 16 ) & 0xff);
sink.writeByte((length >>> 8 ) & 0xff);
sink.writeByte(length & 0xff);
}
}