Merge pull request #748 from vnorigoog/master

add Message Deframer to parse frames in the input messages
This commit is contained in:
Vasu Nori 2020-03-12 10:52:47 -07:00 committed by GitHub
commit dc25fa759a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 270 additions and 20 deletions

View File

@ -56,7 +56,6 @@
<orderEntry type="library" name="Maven: backport-util-concurrent:backport-util-concurrent:3.1" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-registry:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.codehaus.plexus:plexus-container-default:1.0-alpha-9-stable-1" level="project" />
<orderEntry type="library" name="Maven: junit:junit:3.8.1" level="project" />
<orderEntry type="library" name="Maven: classworlds:classworlds:1.1-alpha-2" level="project" />
<orderEntry type="library" name="Maven: org.apache.maven:maven-plugin-descriptor:2.2.1" level="project" />
<orderEntry type="library" name="Maven: org.eclipse.jetty:jetty-util:9.4.26.v20200117" level="project" />
@ -104,7 +103,6 @@
<orderEntry type="library" name="Maven: io.grpc:grpc-core:1.27.1" level="project" />
<orderEntry type="library" name="Maven: com.google.code.gson:gson:2.8.6" level="project" />
<orderEntry type="library" name="Maven: com.google.android:annotations:4.1.1.4" level="project" />
<orderEntry type="library" name="Maven: com.google.errorprone:error_prone_annotations:2.3.4" level="project" />
<orderEntry type="library" name="Maven: io.perfmark:perfmark-api:0.19.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf:1.27.1" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-api:1.27.1" level="project" />
@ -115,10 +113,16 @@
<orderEntry type="library" name="Maven: com.google.guava:guava:28.1-android" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:failureaccess:1.0.1" level="project" />
<orderEntry type="library" name="Maven: com.google.guava:listenablefuture:9999.0-empty-to-avoid-conflict-with-guava" level="project" />
<orderEntry type="library" name="Maven: org.checkerframework:checker-compat-qual:2.5.5" level="project" />
<orderEntry type="library" name="Maven: com.google.j2objc:j2objc-annotations:1.3" level="project" />
<orderEntry type="library" name="Maven: com.google.api.grpc:proto-google-common-protos:1.17.0" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-protobuf-lite:1.27.1" level="project" />
<orderEntry type="library" name="Maven: io.grpc:grpc-stub:1.27.1" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: junit:junit:4.13" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: org.hamcrest:hamcrest-core:1.3" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: com.google.truth:truth:1.0.1" level="project" />
<orderEntry type="library" name="Maven: org.checkerframework:checker-compat-qual:2.5.5" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: com.googlecode.java-diff-utils:diffutils:1.3.0" level="project" />
<orderEntry type="library" scope="TEST" name="Maven: com.google.auto.value:auto-value-annotations:1.6.3" level="project" />
<orderEntry type="library" name="Maven: com.google.errorprone:error_prone_annotations:2.3.3" level="project" />
</component>
</module>

View File

@ -47,6 +47,18 @@
<artifactId>grpc-stub</artifactId>
<version>1.27.1</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.truth</groupId>
<artifactId>truth</artifactId>
<version>1.0.1</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -0,0 +1,103 @@
package com.google.grpcweb;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import org.apache.commons.io.IOUtils;
/**
* Reads frames from the input bytes and returns a single message.
*/
class MessageDeframer {
private static final Logger LOGGER = Logger.getLogger(MessageDeframer.class.getName());
static final byte DATA_BYTE = (byte) 0x00;
// TODO: Max value of Java int is = (2GB - 1). Max POST size can be 2GB or even 4GB.
// fix this code to be able to handle upto 4GB input size.
private int mLength = 0;
private int mReadSoFar = 0;
private int mNumFrames = 0;
private ArrayList<byte[]> mFrames = new ArrayList<>();
private byte[] mMsg = null;
byte[] getMessageBytes() { return mMsg;}
int getLength() { return mLength;}
/** Reads the bytes from the given InputStream and populates bytes in {@link #mMsg}
*/
boolean processInput(InputStream in) {
byte[] inBytes;
try {
inBytes = IOUtils.toByteArray(in);
} catch (IOException e) {
e.printStackTrace();
LOGGER.warning("invalid input");
return false;
}
if (inBytes.length < 5) {
LOGGER.warning("invalid input. Expected minimum of 5 bytes");
return false;
}
while (getNextFrameBytes(inBytes)) {}
// Make sure the input has a valid message
if (mFrames.isEmpty()) {
LOGGER.warning("input has NO frames!");
return false;
}
// common case is only one frame.
if (mFrames.size() == 1) {
mMsg = mFrames.get(0);
} else {
// concatenate all frames into one byte array
mMsg = new byte[mLength];
int offset = 0;
for (byte[] f : mFrames) {
System.arraycopy(f, 0, mMsg, offset, f.length);
offset += f.length;
}
mFrames = null;
}
return true;
}
/** returns true if the next frame is a DATA frame */
private boolean getNextFrameBytes(byte[] inBytes) {
// Firstbyte should be 0x00 (for this to be a DATA frame)
int firstByteValue = inBytes[mReadSoFar] | DATA_BYTE;
if (firstByteValue != 0) {
LOGGER.fine("done with DATA bytes");
return false;
}
// Next 4 bytes = length of the bytes array starting after the 4 bytes.
int offset = mReadSoFar + 1;
int len = ByteBuffer.wrap(inBytes, offset, 4).getInt();
if (len <= 0) {
LOGGER.warning(String.format("invalid length value: %d", len));
return false;
}
// Make sure we have enough bytes in the inputstream
int expectedNumBytes = len + 5 + mReadSoFar;
if (inBytes.length < expectedNumBytes) {
LOGGER.warning(String.format("input doesn't have enough bytes. expected: %d, found %d",
expectedNumBytes, inBytes.length));
return false;
}
// Read "len" bytes into message
mLength += len;
offset += 4;
byte[] inputBytes = Arrays.copyOfRange(inBytes, offset, len + offset);
mFrames.add(inputBytes);
mReadSoFar += (len + 5);
mNumFrames++;
// we have more frames to process, if there are bytes unprocessed
return inBytes.length > mReadSoFar;
}
}

View File

@ -0,0 +1,20 @@
package com.google.grpcweb;
/**
* Creates frames from the input bytes.
*/
class MessageFramer {
static final byte DATA_BYTE = (byte) 0x00;
// TODO: handle more than single frame; i.e., input byte array size > (2GB - 1)
byte[] getPrefix(byte[] in) {
int len = in.length;
return new byte[] {
DATA_BYTE,
(byte) ((len >> 24) & 0xff),
(byte) ((len >> 16) & 0xff),
(byte) ((len >> 8) & 0xff),
(byte) ((len >> 0) & 0xff),
};
}
}

View File

@ -64,10 +64,9 @@ class MessageHandler {
Object inputObj;
try {
// PUNT use MessageDeframer. ignore the first 5 bytes
inputObj = parseFromObj.invoke(null, Arrays.copyOfRange(in, 5, in.length));
inputObj = parseFromObj.invoke(null, in);
} catch (InvocationTargetException | IllegalAccessException e) {
throw new IllegalArgumentException("Couldn't invoke parseFrom " + e.getMessage());
throw new IllegalArgumentException(e);
}
if (inputObj == null || !inputArgClass.isInstance(inputObj)) {
@ -93,8 +92,11 @@ class MessageHandler {
private Object handleRpcInvocationForProtoContentType(HttpServletRequest req,
Object stub, Method rpcMethod) throws IOException {
ServletInputStream in = req.getInputStream();
byte[] inBytes = IOUtils.toByteArray(in);
Object inObj = getInputProtobufObj(rpcMethod, inBytes);
MessageDeframer deframer = new MessageDeframer();
if (!deframer.processInput(in)) {
return null;
}
Object inObj = getInputProtobufObj(rpcMethod, deframer.getMessageBytes());
Class returnClassType = rpcMethod.getReturnType();
LOGGER.fine("returnClassType is : " + returnClassType.getName());

View File

@ -23,11 +23,11 @@ class RequestHandler extends AbstractHandler {
public void handle(
String target, Request baseRequest, HttpServletRequest req, HttpServletResponse response)
throws IOException, ServletException {
mFactory.getDebugInfo().printRequest(req);
response.setContentType("text/html; charset=utf-8");
baseRequest.setHandled(true);
try {
mFactory.getDebugInfo().printRequest(req);
// From the request, get the rpc-method name and class name and then get their corresponding
// concrete objects.

View File

@ -10,17 +10,8 @@ class SendResponse {
*/
void writeResponse(HttpServletResponse response, byte[] out) throws IOException {
ServletOutputStream oStream = response.getOutputStream();
oStream.write((byte) 0x00);
oStream.write(lengthToBytes(out.length));
byte[] prefix = new MessageFramer().getPrefix(out);
oStream.write(prefix);
oStream.write(out);
}
private byte[] lengthToBytes(int len) {
return new byte[] {
(byte) ((len >> 24) & 0xff),
(byte) ((len >> 16) & 0xff),
(byte) ((len >> 8) & 0xff),
(byte) ((len >> 0) & 0xff),
};
}
}

View File

@ -0,0 +1,85 @@
package com.google.grpcweb;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class MessageDeframerTest {
private MessageDeframer testInstance;
@Before
public void setUp() {
testInstance = new MessageDeframer();
assertNotNull(testInstance);
}
@Test
public void testProcessInput_Singleframe() throws IOException {
String source = "This is the source of my input stream";
byte[] str = stringToFrame(source);
InputStream in = new ByteArrayInputStream(str);
assertTrue(testInstance.processInput(in));
byte[] result = testInstance.getMessageBytes();
assertTrue(source.equals(new String(result)));
assertEquals(source.length(), testInstance.getLength());
}
@Test
public void testProcessInput_Manyframes() throws IOException {
ArrayList<String> inputSrcs = new ArrayList<>();
// Create 10 frames
for (int i = 0; i < 10; i++) {
inputSrcs.add("this is string# " + i);
}
ByteArrayOutputStream combined = new ByteArrayOutputStream();
String concatenatedInputSrc = "";
for (String s : inputSrcs) {
combined.write(stringToFrame(s));
concatenatedInputSrc += s;
}
byte[] combinedBytes = combined.toByteArray();
assertEquals(concatenatedInputSrc.length() + inputSrcs.size() * 5,
combinedBytes.length);
InputStream in = new ByteArrayInputStream(combinedBytes);
assertTrue(testInstance.processInput(in));
byte[] result = testInstance.getMessageBytes();
assertTrue(concatenatedInputSrc.equals(new String(result)));
assertEquals(concatenatedInputSrc.length(), testInstance.getLength());
}
@Test
public void testProcessInput_NoDataframeInInput() throws IOException {
String source = "This is the source of my input stream";
byte[] str = stringToFrame(source);
// change the first byte to a non-DATA frame
str[0] = (byte) 0x80;
InputStream in = new ByteArrayInputStream(str);
assertFalse(testInstance.processInput(in));
assertNull(testInstance.getMessageBytes());
assertEquals(0, testInstance.getLength());
}
private byte[] stringToFrame(String source) throws IOException {
ByteArrayOutputStream output = new ByteArrayOutputStream();
output.write(MessageDeframer.DATA_BYTE);
output.write(ByteBuffer.allocate(4).putInt(source.length()).array());
output.write(source.getBytes());
return output.toByteArray();
}
}

View File

@ -0,0 +1,33 @@
package com.google.grpcweb;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class MessageFramerTest {
private MessageFramer testInstance;
@Before
public void setUp() {
testInstance = new MessageFramer();
assertNotNull(testInstance);
}
@Test
public void testProcessInput_Singleframe() throws IOException {
String source = "This is the source of my input stream";
byte[] bytes = source.getBytes();
byte[] prefix = testInstance.getPrefix(bytes);
assertEquals(5, prefix.length);
int len = ByteBuffer.wrap(prefix, 1, 4).getInt();
assertEquals(source.length(), len);
}
}