diff --git a/core/src/main/java/com/google/net/stubby/newtransport/CompressionFramer.java b/core/src/main/java/com/google/net/stubby/newtransport/CompressionFramer.java new file mode 100644 index 0000000000..82b4943b8f --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/CompressionFramer.java @@ -0,0 +1,336 @@ +package com.google.net.stubby.newtransport; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.io.ByteStreams; +import com.google.net.stubby.DeferredInputStream; +import com.google.net.stubby.newtransport.Framer.Sink; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.zip.Deflater; + +/** + * Compression framer for HTTP/2 transport frames, for use in both compression and + * non-compression scenarios. Receives message-stream as input. It is able to change compression + * configuration on-the-fly, but will not actually begin using the new configuration until the next + * full frame. + */ +class CompressionFramer { + /** + * Compression level to indicate using this class's default level. Note that this value is + * allowed to conflict with Deflate.DEFAULT_COMPRESSION, in which case this class's default + * prevails. + */ + public static final int DEFAULT_COMPRESSION_LEVEL = -1; + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + /** + * Size of the GRPC compression frame header which consists of: + * 1 byte for the compression type, + * 3 bytes for the length of the compression frame. + */ + @VisibleForTesting + static final int HEADER_LENGTH = 4; + /** + * Number of frame bytes to reserve to allow for zlib overhead. This does not include data-length + * dependent overheads and compression latency (delay between providing data to zlib and output of + * the compressed data). + * + *

References: + * deflate framing: http://www.gzip.org/zlib/rfc-deflate.html + * (note that bit-packing is little-endian (section 3.1.1) whereas description of sequences + * is big-endian, so bits appear reversed), + * zlib framing: http://tools.ietf.org/html/rfc1950, + * details on flush behavior: http://www.zlib.net/manual.html + */ + @VisibleForTesting + static final int MARGIN + = 5 /* deflate current block overhead, assuming no compression: + block type (1) + len (2) + nlen (2) */ + + 5 /* deflate flush; adds an empty block after current: + 00 (not end; no compression) 00 00 (len) FF FF (nlen) */ + + 5 /* deflate flush; some versions of zlib output two empty blocks on some flushes */ + + 5 /* deflate finish; adds empty block to mark end, since we commonly flush before finish: + 03 (end; fixed Huffman + 5 bits of end of block) 00 (last 3 bits + padding), + or if compression level is 0: 01 (end; no compression) 00 00 (len) FF FF (nlen) */ + + 2 /* zlib header; CMF (1) + FLG (1) */ + 4 /* zlib ADLER32 (4) */ + + 5 /* additional safety for good measure */; + + private static final Logger log = Logger.getLogger(CompressionFramer.class.getName()); + + private final Sink sink; + /** + * Bytes of frame being constructed. {@code position() == 0} when no frame in progress. + */ + private final ByteBuffer bytebuf; + /** Number of frame bytes it is acceptable to leave unused when compressing. */ + private final int sufficient; + private Deflater deflater; + /** Number of bytes written to deflater since last deflate sync. */ + private int writtenSinceSync; + /** Number of bytes read from deflater since last deflate sync. */ + private int readSinceSync; + /** + * Whether the current frame is actually being compressed. If {@code bytebuf.position() == 0}, + * then this value has no meaning. + */ + private boolean usingCompression; + /** + * Whether compression is requested. This does not imply we are compressing the current frame + * (see {@link #usingCompression}), or that we will even compress the next frame (see {@link + * #compressionUnsupported}). + */ + private boolean allowCompression; + /** Whether compression is possible with current configuration and platform. */ + private final boolean compressionUnsupported; + /** + * Compression level to set on the Deflater, where {@code DEFAULT_COMPRESSION_LEVEL} implies this + * class's default. + */ + private int compressionLevel = DEFAULT_COMPRESSION_LEVEL; + private final OutputStreamAdapter outputStreamAdapter = new OutputStreamAdapter(); + + /** + * Since compression tries to form full frames, if compression is working well then it will + * consecutively compress smaller amounts of input data in order to not exceed the frame size. For + * example, if the data is getting 50% compression and a maximum frame size of 128, then it will + * encode roughly 128 bytes which leaves 64, so we encode 64, 32, 16, 8, 4, 2, 1, 1. + * {@code sufficient} cuts off the long tail and says that at some point the frame is "good + * enough" to stop. Choosing a value of {@code 0} is not outrageous. + * + * @param maxFrameSize maximum number of bytes allowed for output frames + * @param allowCompression whether frames should be compressed + * @param sufficient number of frame bytes it is acceptable to leave unused when compressing + */ + public CompressionFramer(Sink sink, int maxFrameSize, boolean allowCompression, + int sufficient) { + this.sink = sink; + this.allowCompression = allowCompression; + int maxSufficient = maxFrameSize - HEADER_LENGTH - MARGIN + - 1 /* to force at least one byte of data */; + boolean compressionUnsupported = false; + if (maxSufficient < 0) { + compressionUnsupported = true; + log.log(Level.INFO, "Frame not large enough for compression"); + } else if (maxSufficient < sufficient) { + log.log(Level.INFO, "Compression sufficient reduced to {0} from {1} to fit in frame size {2}", + new Object[] {maxSufficient, sufficient, maxFrameSize}); + sufficient = maxSufficient; + } + this.sufficient = sufficient; + // TODO(user): Benchmark before switching to direct buffers + bytebuf = ByteBuffer.allocate(maxFrameSize); + if (!bytebuf.hasArray()) { + compressionUnsupported = true; + log.log(Level.INFO, "Byte buffer doesn't support array(), which is required for compression"); + } + this.compressionUnsupported = compressionUnsupported; + } + + /** + * Sets whether compression is encouraged. + */ + public void setAllowCompression(boolean allow) { + this.allowCompression = allow; + } + + /** + * Set the preferred compression level for when compression is enabled. + * + * @param level the preferred compression level (0-9), or {@code DEFAULT_COMPRESSION_LEVEL} to use + * this class's default + * @see java.util.zip.Deflater#setLevel + */ + public void setCompressionLevel(int level) { + Preconditions.checkArgument(level == DEFAULT_COMPRESSION_LEVEL + || (level >= Deflater.NO_COMPRESSION && level <= Deflater.BEST_COMPRESSION), + "invalid compression level"); + this.compressionLevel = level; + } + + /** + * Ensures state and buffers are initialized for writing data to a frame. Callers should be very + * aware this method may modify {@code usingCompression}. + */ + private void checkInitFrame() { + if (bytebuf.position() != 0) { + return; + } + bytebuf.position(HEADER_LENGTH); + usingCompression = compressionUnsupported ? false : allowCompression; + if (usingCompression) { + if (deflater == null) { + deflater = new Deflater(); + } else { + deflater.reset(); + } + deflater.setLevel(compressionLevel == DEFAULT_COMPRESSION_LEVEL + ? Deflater.DEFAULT_COMPRESSION : compressionLevel); + writtenSinceSync = 0; + readSinceSync = 0; + } + } + + /** Frame contents of {@code message}, flushing to {@code sink} as necessary. */ + public int write(InputStream message) throws IOException { + checkInitFrame(); + if (!usingCompression && bytebuf.hasArray()) { + if (bytebuf.remaining() == 0) { + commitToSink(false, false); + } + int available = message.available(); + if (available <= bytebuf.remaining()) { + // When InputStream is DeferredProtoInputStream, this is zero-copy because bytebuf is large + // enough for the proto to be serialized directly into it. + int read = ByteStreams.read(message, + bytebuf.array(), bytebuf.arrayOffset() + bytebuf.position(), bytebuf.remaining()); + bytebuf.position(bytebuf.position() + read); + if (read != available) { + throw new RuntimeException("message.available() did not follow our semantics of always " + + "returning the number of remaining bytes"); + } + return read; + } + } + if (message instanceof DeferredInputStream) { + return ((DeferredInputStream) message).flushTo(outputStreamAdapter); + } else { + // This could be optimized when compression is off, but we expect performance-critical code + // to provide a DeferredInputStream. + return (int) ByteStreams.copy(message, outputStreamAdapter); + } + } + + /** + * Frame contents of {@code b} between {@code off} (inclusive) and {@code off + len} (exclusive), + * flushing to {@code sink} as necessary. + */ + public void write(byte[] b, int off, int len) { + while (len > 0) { + checkInitFrame(); + if (!usingCompression) { + if (bytebuf.remaining() == 0) { + commitToSink(false, false); + continue; + } + int toWrite = Math.min(len, bytebuf.remaining()); + bytebuf.put(b, off, toWrite); + off += toWrite; + len -= toWrite; + } else { + if (bytebuf.remaining() <= MARGIN + sufficient) { + commitToSink(false, false); + continue; + } + // Amount of memory that is guaranteed not to be consumed, including in-flight data in zlib. + int safeCapacity = bytebuf.remaining() - MARGIN + - (writtenSinceSync - readSinceSync) - dataLengthDependentOverhead(writtenSinceSync); + if (safeCapacity <= 0) { + while (deflatePut(deflater, bytebuf, Deflater.SYNC_FLUSH) != 0) {} + writtenSinceSync = 0; + readSinceSync = 0; + continue; + } + int toWrite = Math.min(len, safeCapacity - dataLengthDependentOverhead(safeCapacity)); + deflater.setInput(b, off, toWrite); + writtenSinceSync += toWrite; + while (!deflater.needsInput()) { + readSinceSync += deflatePut(deflater, bytebuf, Deflater.NO_FLUSH); + } + // Clear internal references of byte[] b. + deflater.setInput(EMPTY_BYTE_ARRAY); + off += toWrite; + len -= toWrite; + } + } + } + + /** + * When data is uncompressable, there are 5B of overhead per deflate block, which is generally + * 16 KiB for zlib, but the format supports up to 32 KiB. One block's overhead is already + * accounted for in MARGIN. We use 1B/2KiB to circumvent dealing with rounding errors. Note that + * 1B/2KiB is not enough to support 8 KiB blocks due to rounding errors. + */ + private static int dataLengthDependentOverhead(int length) { + return length / 2048; + } + + private static int deflatePut(Deflater deflater, ByteBuffer bytebuf, int flush) { + if (bytebuf.remaining() == 0) { + throw new AssertionError("Compressed data exceeded frame size"); + } + int deflateBytes = deflater.deflate(bytebuf.array(), bytebuf.arrayOffset() + bytebuf.position(), + bytebuf.remaining(), flush); + bytebuf.position(bytebuf.position() + deflateBytes); + return deflateBytes; + } + + public void endOfMessage() { + if ((!usingCompression && bytebuf.remaining() == 0) + || (usingCompression && bytebuf.remaining() <= MARGIN + sufficient)) { + commitToSink(true, false); + } + } + + public void flush() { + if (bytebuf.position() == 0) { + return; + } + commitToSink(true, false); + } + + public void close() { + if (bytebuf.position() == 0) { + // No pending frame, so send an empty one. + bytebuf.flip(); + sink.deliverFrame(bytebuf, true); + bytebuf.clear(); + } else { + commitToSink(true, true); + } + } + + /** + * Writes compression frame to sink. It does not initialize the next frame, so {@link + * #checkInitFrame()} is necessary if other frames are to follow. + */ + private void commitToSink(boolean endOfMessage, boolean endOfStream) { + if (usingCompression) { + deflater.finish(); + while (!deflater.finished()) { + deflatePut(deflater, bytebuf, Deflater.NO_FLUSH); + } + if (endOfMessage) { + deflater.end(); + deflater = null; + } + } + int frameFlag = usingCompression + ? TransportFrameUtil.FLATE_FLAG : TransportFrameUtil.NO_COMPRESS_FLAG; + // Header = 1b flag | 3b length of GRPC frame + int header = (frameFlag << 24) | (bytebuf.position() - 4); + bytebuf.putInt(0, header); + bytebuf.flip(); + sink.deliverFrame(bytebuf, endOfStream); + bytebuf.clear(); + } + + private class OutputStreamAdapter extends OutputStream { + private final byte[] singleByte = new byte[1]; + + @Override + public void write(int b) { + singleByte[0] = (byte) b; + write(singleByte, 0, 1); + } + + @Override + public void write(byte[] b, int off, int len) { + CompressionFramer.this.write(b, off, len); + } + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/Deframer.java b/core/src/main/java/com/google/net/stubby/newtransport/Deframer.java new file mode 100644 index 0000000000..06b26e50a6 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/Deframer.java @@ -0,0 +1,162 @@ +package com.google.net.stubby.newtransport; + +import com.google.common.io.ByteStreams; +import com.google.net.stubby.GrpcFramingUtil; +import com.google.net.stubby.Operation; +import com.google.net.stubby.Status; +import com.google.net.stubby.transport.Transport; + +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * Base implementation that joins a sequence of framed GRPC data produced by a {@link Framer}, + * reconstructs their messages and hands them off to a receiving {@link Operation} + */ +public abstract class Deframer implements Framer.Sink { + + /** + * Unset frame length + */ + private static final int LENGTH_NOT_SET = -1; + + private final Framer target; + private boolean inFrame; + private byte currentFlags; + private int currentLength = LENGTH_NOT_SET; + + public Deframer(Framer target) { + this.target = target; + } + + @Override + public void deliverFrame(F frame, boolean endOfStream) { + int remaining = internalDeliverFrame(frame); + if (endOfStream && remaining > 0) { + target.writeStatus(new Status(Transport.Code.UNKNOWN, "EOF on incomplete frame")); + } + } + + /** + * Consume a frame of bytes provided by the transport. Note that transport framing is not + * aligned on GRPC frame boundaries so this code needs to do bounds checking and buffering + * across transport frame boundaries. + * + * @return the number of unconsumed bytes remaining in the buffer + */ + private int internalDeliverFrame(F frame) { + try { + frame = decompress(frame); + DataInputStream grpcStream = prefix(frame); + // Loop until no more GRPC frames can be fully decoded + while (true) { + if (!inFrame) { + // Not in frame so attempt to read flags + if (!ensure(grpcStream, GrpcFramingUtil.FRAME_TYPE_LENGTH)) { + return consolidate(); + } + currentFlags = grpcStream.readByte(); + inFrame = true; + } + if (currentLength == LENGTH_NOT_SET) { + // Read the frame length + if (!ensure(grpcStream, GrpcFramingUtil.FRAME_LENGTH)) { + return consolidate(); + } + currentLength = grpcStream.readInt(); + } + // Ensure that the entire frame length is available to read + InputStream framedChunk = ensureMessage(grpcStream, currentLength); + if (framedChunk == null) { + // Insufficient bytes available + return consolidate(); + } + if (GrpcFramingUtil.isPayloadFrame(currentFlags)) { + // Advance stream now, because target.addPayload() may not or may process the frame on + // another thread. + framedChunk = new ByteArrayInputStream(ByteStreams.toByteArray(framedChunk)); + try { + // Report payload to the receiving operation + target.writePayload(framedChunk, currentLength); + } finally { + currentLength = LENGTH_NOT_SET; + inFrame = false; + } + } else if (GrpcFramingUtil.isContextValueFrame(currentFlags)) { + // Not clear if using proto encoding here is of any benefit. + // Using ContextValue.parseFrom requires copying out of the framed chunk + // Writing a custom parser would have to do varint handling and potentially + // deal with out-of-order tags etc. + Transport.ContextValue contextValue = Transport.ContextValue.parseFrom(framedChunk); + try { + target.writeContext(contextValue.getKey(), + contextValue.getValue().newInput(), currentLength); + } finally { + currentLength = LENGTH_NOT_SET; + inFrame = false; + } + } else if (GrpcFramingUtil.isStatusFrame(currentFlags)) { + int status = framedChunk.read() << 8 | framedChunk.read(); + Transport.Code code = Transport.Code.valueOf(status); + // TODO(user): Resolve what to do with remainder of framedChunk + try { + if (code == null) { + // Log for unknown code + target.writeStatus( + new Status(Transport.Code.UNKNOWN, "Unknown status code " + status)); + } else { + target.writeStatus(new Status(code)); + } + } finally { + currentLength = LENGTH_NOT_SET; + inFrame = false; + } + } + if (grpcStream.available() == 0) { + // We've processed all the data so consolidate the underlying buffers + return consolidate(); + } + } + } catch (IOException ioe) { + Status status = new Status(Transport.Code.UNKNOWN, ioe); + target.writeStatus(status); + throw status.asRuntimeException(); + } + } + + /** + * Return a stream view over the current buffer prefixed to the input frame + */ + protected abstract DataInputStream prefix(F frame) throws IOException; + + /** + * Consolidate the underlying buffers and return the number of buffered bytes remaining + */ + protected abstract int consolidate() throws IOException; + + /** + * Decompress the raw frame buffer prior to prefixing it. + */ + protected abstract F decompress(F frame) throws IOException; + + /** + * Ensure that {@code len} bytes are available in the buffer and frame + */ + private boolean ensure(InputStream input, int len) throws IOException { + return (input.available() >= len); + } + + /** + * Return a message of {@code len} bytes than can be read from the buffer. If sufficient + * bytes are unavailable then buffer the available bytes and return null. + */ + private InputStream ensureMessage(InputStream input, int len) + throws IOException { + if (input.available() < len) { + return null; + } + return ByteStreams.limit(input, len); + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/Framer.java b/core/src/main/java/com/google/net/stubby/newtransport/Framer.java new file mode 100644 index 0000000000..125ae597fd --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/Framer.java @@ -0,0 +1,58 @@ +package com.google.net.stubby.newtransport; + +import com.google.net.stubby.Status; + +import java.io.InputStream; + +/** + * Implementations produce the GRPC byte sequence and then split it over multiple frames to be + * delivered via the transport layer which implements {@link Framer.Sink} + */ +public interface Framer { + + /** + * Sink implemented by the transport layer to receive frames and forward them to their + * destination + */ + public interface Sink { + /** + * Deliver a frame via the transport. + * @param frame the contents of the frame to deliver + * @param endOfStream whether the frame is the last one for the GRPC stream + */ + public void deliverFrame(T frame, boolean endOfStream); + } + + /** + * Write out a Context-Value message. {@code message} will be completely consumed. + * {@code message.available()} must return the number of remaining bytes to be read. + */ + public void writeContext(String type, InputStream message, int length); + + /** + * Write out a Payload message. {@code payload} will be completely consumed. + * {@code payload.available()} must return the number of remaining bytes to be read. + */ + public void writePayload(InputStream payload, int length); + + /** + * Write out a Status message. + */ + // TODO(user): change this signature when we actually start writing out the complete Status. + public void writeStatus(Status status); + + /** + * Flush any buffered data in the framer to the sink. + */ + public void flush(); + + /** + * Flushes and closes the framer and releases any buffers. + */ + public void close(); + + /** + * Closes the framer and releases any buffers, but does not flush. + */ + public void dispose(); +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/InputStreamDeframer.java b/core/src/main/java/com/google/net/stubby/newtransport/InputStreamDeframer.java new file mode 100644 index 0000000000..c0bdbae191 --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/InputStreamDeframer.java @@ -0,0 +1,149 @@ +package com.google.net.stubby.newtransport; + +import com.google.common.io.ByteStreams; + +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.zip.InflaterInputStream; + +/** + * Deframer that expects the input frames to be provided as {@link InputStream} instances + * which accurately report their size using {@link java.io.InputStream#available()}. + */ +public class InputStreamDeframer extends Deframer { + + private final PrefixingInputStream prefixingInputStream; + + public InputStreamDeframer(Framer target) { + super(target); + prefixingInputStream = new PrefixingInputStream(4096); + } + + /** + * Deframing a single input stream that contains multiple GRPC frames + */ + @Override + public void deliverFrame(InputStream frame, boolean endOfStream) { + super.deliverFrame(frame, endOfStream); + try { + if (frame.available() > 0) { + throw new AssertionError(); + } + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + + @Override + protected DataInputStream prefix(InputStream frame) throws IOException { + prefixingInputStream.consolidate(); + prefixingInputStream.prefix(frame); + return new DataInputStream(prefixingInputStream); + } + + @Override + protected int consolidate() throws IOException { + prefixingInputStream.consolidate(); + return prefixingInputStream.available(); + } + + @Override + protected InputStream decompress(InputStream frame) throws IOException { + int compressionType = frame.read(); + int frameLength = frame.read() << 16 | frame.read() << 8 | frame.read(); + InputStream raw = ByteStreams.limit(frame, frameLength); + if (TransportFrameUtil.isNotCompressed(compressionType)) { + return raw; + } else if (TransportFrameUtil.isFlateCompressed(compressionType)) { + return new InflaterInputStream(raw); + } + throw new IOException("Unknown compression type " + compressionType); + } + + /** + * InputStream that prefixes another input stream with a fixed buffer. + */ + private class PrefixingInputStream extends InputStream { + + private InputStream suffix; + private byte[] buffer; + private int bufferIndex; + private int maxRetainedBuffer; + + private PrefixingInputStream(int maxRetainedBuffer) { + // TODO(user): Implement support for this. + this.maxRetainedBuffer = maxRetainedBuffer; + } + + void prefix(InputStream suffix) { + this.suffix = suffix; + } + + void consolidate() throws IOException { + int remainingSuffix = suffix == null ? 0 : suffix.available(); + if (remainingSuffix == 0) { + // No suffix so clear + suffix = null; + return; + } + int bufferLength = buffer == null ? 0 : buffer.length; + int bytesInBuffer = bufferLength - bufferIndex; + // Shift existing bytes + if (bufferLength < bytesInBuffer + remainingSuffix) { + // Buffer too small, so create a new buffer before copying in the suffix + byte[] newBuffer = new byte[bytesInBuffer + remainingSuffix]; + if (bytesInBuffer > 0) { + System.arraycopy(buffer, bufferIndex, newBuffer, 0, bytesInBuffer); + } + buffer = newBuffer; + bufferIndex = 0; + } else { + // Enough space is in buffer, so shift the existing bytes to open up exactly enough bytes + // for the suffix at the end. + System.arraycopy(buffer, bufferIndex, buffer, bufferIndex - remainingSuffix, bytesInBuffer); + bufferIndex -= remainingSuffix; + } + // Write suffix to buffer + ByteStreams.readFully(suffix, buffer, buffer.length - remainingSuffix, remainingSuffix); + suffix = null; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int read = readFromBuffer(b, off, len); + if (suffix != null) { + read += suffix.read(b, off + read, len - read); + } + return read; + } + + private int readFromBuffer(byte[] b, int off, int len) { + if (buffer == null) { + return 0; + } + len = Math.min(buffer.length - bufferIndex, len); + System.arraycopy(buffer, bufferIndex, b, off, len); + bufferIndex += len; + return len; + } + + @Override + public int read() throws IOException { + if (buffer == null || bufferIndex == buffer.length) { + return suffix == null ? -1 : suffix.read(); + } + return buffer[bufferIndex++]; + } + + @Override + public int available() throws IOException { + int available = buffer != null ? buffer.length - bufferIndex : 0; + if (suffix != null) { + // FIXME(ejona): This is likely broken with compressed streams. + available += suffix.available(); + } + return available; + } + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/MessageFramer.java b/core/src/main/java/com/google/net/stubby/newtransport/MessageFramer.java new file mode 100644 index 0000000000..f74b83736f --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/MessageFramer.java @@ -0,0 +1,177 @@ +package com.google.net.stubby.newtransport; + +import com.google.net.stubby.GrpcFramingUtil; +import com.google.net.stubby.Status; +import com.google.net.stubby.transport.Transport; +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.WireFormat; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.Arrays; + +/** + * Default {@link Framer} implementation. + */ +public class MessageFramer implements Framer { + + /** + * Size of the GRPC message frame header which consists of + * 1 byte for the type (payload, context, status) + * 4 bytes for the length of the message + */ + private static final int MESSAGE_HEADER_SIZE = 5; + + /** + * UTF-8 charset which is used for key name encoding in context values + */ + private static final Charset UTF_8 = Charset.forName("UTF-8"); + + /** + * Precomputed protobuf tags for ContextValue + */ + private static final byte[] VALUE_TAG; + private static final byte[] KEY_TAG; + + + static { + // Initialize constants for serializing context-value in a protobuf compatible manner + try { + byte[] buf = new byte[8]; + CodedOutputStream coded = CodedOutputStream.newInstance(buf); + coded.writeTag(Transport.ContextValue.KEY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + coded.flush(); + KEY_TAG = Arrays.copyOf(buf, coded.getTotalBytesWritten()); + coded = CodedOutputStream.newInstance(buf); + coded.writeTag(Transport.ContextValue.VALUE_FIELD_NUMBER, + WireFormat.WIRETYPE_LENGTH_DELIMITED); + coded.flush(); + VALUE_TAG = Arrays.copyOf(buf, coded.getTotalBytesWritten()); + } catch (IOException ioe) { + // Unrecoverable + throw new RuntimeException(ioe); + } + } + + private CompressionFramer framer; + private final ByteBuffer scratch = ByteBuffer.allocate(16); + + public MessageFramer(Sink sink, int maxFrameSize) { + // TODO(user): maxFrameSize should probably come from a 'Platform' class + framer = new CompressionFramer(sink, maxFrameSize, false, maxFrameSize / 16); + } + + /** + * Sets whether compression is encouraged. + */ + public void setAllowCompression(boolean enable) { + framer.setAllowCompression(enable); + } + + /** + * Set the preferred compression level for when compression is enabled. + * @param level the preferred compression level, or {@code -1} to use the framing default + * @see java.util.zip.Deflater#setLevel + */ + public void setCompressionLevel(int level) { + framer.setCompressionLevel(level); + } + + @Override + public void writePayload(InputStream message, int messageLength) { + try { + scratch.clear(); + scratch.put(GrpcFramingUtil.PAYLOAD_FRAME); + scratch.putInt(messageLength); + framer.write(scratch.array(), 0, scratch.position()); + if (messageLength != framer.write(message)) { + throw new RuntimeException("Message length was inaccurate"); + } + framer.endOfMessage(); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + + + @Override + public void writeContext(String key, InputStream message, int messageLen) { + try { + scratch.clear(); + scratch.put(GrpcFramingUtil.CONTEXT_VALUE_FRAME); + byte[] keyBytes = key.getBytes(UTF_8); + int lenKeyPrefix = KEY_TAG.length + + CodedOutputStream.computeRawVarint32Size(keyBytes.length); + int lenValPrefix = VALUE_TAG.length + CodedOutputStream.computeRawVarint32Size(messageLen); + int totalLen = lenKeyPrefix + keyBytes.length + lenValPrefix + messageLen; + scratch.putInt(totalLen); + framer.write(scratch.array(), 0, scratch.position()); + + // Write key + scratch.clear(); + scratch.put(KEY_TAG); + writeRawVarInt32(keyBytes.length, scratch); + framer.write(scratch.array(), 0, scratch.position()); + framer.write(keyBytes, 0, keyBytes.length); + + // Write value + scratch.clear(); + scratch.put(VALUE_TAG); + writeRawVarInt32(messageLen, scratch); + framer.write(scratch.array(), 0, scratch.position()); + if (messageLen != framer.write(message)) { + throw new RuntimeException("Message length was inaccurate"); + } + framer.endOfMessage(); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + + @Override + public void writeStatus(Status status) { + short code = (short) status.getCode().ordinal(); + scratch.clear(); + scratch.put(GrpcFramingUtil.STATUS_FRAME); + int length = 2; + scratch.putInt(length); + scratch.putShort(code); + framer.write(scratch.array(), 0, scratch.position()); + framer.endOfMessage(); + } + + @Override + public void flush() { + framer.flush(); + } + + @Override + public void close() { + // TODO(user): Returning buffer to a pool would go here + framer.close(); + framer = null; + } + + @Override + public void dispose() { + // TODO(user): Returning buffer to a pool would go here + framer = null; + } + + /** + * Write a raw VarInt32 to the buffer + */ + private static void writeRawVarInt32(int value, ByteBuffer bytebuf) { + while (true) { + if ((value & ~0x7F) == 0) { + bytebuf.put((byte) value); + return; + } else { + bytebuf.put((byte) ((value & 0x7F) | 0x80)); + value >>>= 7; + } + } + } +} diff --git a/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java new file mode 100644 index 0000000000..e543961d9b --- /dev/null +++ b/core/src/main/java/com/google/net/stubby/newtransport/TransportFrameUtil.java @@ -0,0 +1,23 @@ +package com.google.net.stubby.newtransport; + +/** + * Utility functions for transport layer framing. + * + * Within a given transport frame we reserve the first byte to indicate the + * type of compression used for the contents of the transport frame. + */ +public class TransportFrameUtil { + + // Compression modes (lowest order 3 bits of frame flags) + public static final byte NO_COMPRESS_FLAG = 0x0; + public static final byte FLATE_FLAG = 0x1; + public static final byte COMPRESSION_FLAG_MASK = 0x7; + + public static boolean isNotCompressed(int b) { + return ((b & COMPRESSION_FLAG_MASK) == NO_COMPRESS_FLAG); + } + + public static boolean isFlateCompressed(int b) { + return ((b & COMPRESSION_FLAG_MASK) == FLATE_FLAG); + } +} diff --git a/core/src/test/java/com/google/net/stubby/newtransport/CompressionFramerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/CompressionFramerTest.java new file mode 100644 index 0000000000..b93875abec --- /dev/null +++ b/core/src/test/java/com/google/net/stubby/newtransport/CompressionFramerTest.java @@ -0,0 +1,99 @@ +package com.google.net.stubby.newtransport; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.Lists; +import com.google.common.io.ByteStreams; +import com.google.common.primitives.Bytes; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Random; +import java.util.zip.Deflater; +import java.util.zip.InflaterInputStream; + +/** Unit tests for {@link CompressionFramer}. */ +@RunWith(JUnit4.class) +public class CompressionFramerTest { + private int maxFrameSize = 1024; + private int sufficient = 8; + private CapturingSink sink = new CapturingSink(); + private CompressionFramer framer = new CompressionFramer(sink, maxFrameSize, true, sufficient); + + @Test + public void testGoodCompression() { + byte[] payload = new byte[1000]; + framer.setCompressionLevel(Deflater.BEST_COMPRESSION); + framer.write(payload, 0, payload.length); + framer.endOfMessage(); + framer.flush(); + + assertEquals(1, sink.frames.size()); + byte[] frame = sink.frames.get(0); + assertEquals(TransportFrameUtil.FLATE_FLAG, frame[0]); + assertTrue(decodeFrameLength(frame) < 30); + assertArrayEquals(payload, decompress(frame)); + } + + @Test + public void testPoorCompression() { + byte[] payload = new byte[3 * maxFrameSize / 2]; + new Random(1).nextBytes(payload); + framer.setCompressionLevel(Deflater.DEFAULT_COMPRESSION); + framer.write(payload, 0, payload.length); + framer.endOfMessage(); + framer.flush(); + + assertEquals(2, sink.frames.size()); + assertEquals(TransportFrameUtil.FLATE_FLAG, sink.frames.get(0)[0]); + assertEquals(TransportFrameUtil.FLATE_FLAG, sink.frames.get(1)[0]); + assertTrue(decodeFrameLength(sink.frames.get(0)) <= maxFrameSize); + assertTrue(decodeFrameLength(sink.frames.get(0)) + >= maxFrameSize - CompressionFramer.HEADER_LENGTH - CompressionFramer.MARGIN - sufficient); + assertArrayEquals(payload, decompress(sink.frames)); + } + + private static int decodeFrameLength(byte[] frame) { + return ((frame[1] & 0xFF) << 16) + | ((frame[2] & 0xFF) << 8) + | (frame[3] & 0xFF); + } + + private static byte[] decompress(byte[] frame) { + try { + return ByteStreams.toByteArray(new InflaterInputStream(new ByteArrayInputStream(frame, + CompressionFramer.HEADER_LENGTH, frame.length - CompressionFramer.HEADER_LENGTH))); + } catch (IOException ex) { + throw new AssertionError(); + } + } + + private static byte[] decompress(List frames) { + byte[][] bytes = new byte[frames.size()][]; + for (int i = 0; i < frames.size(); i++) { + bytes[i] = decompress(frames.get(i)); + } + return Bytes.concat(bytes); + } + + private static class CapturingSink implements Framer.Sink { + public final List frames = Lists.newArrayList(); + + @Override + public void deliverFrame(ByteBuffer frame, boolean endOfMessage) { + byte[] frameBytes = new byte[frame.remaining()]; + frame.get(frameBytes); + assertEquals(frameBytes.length - CompressionFramer.HEADER_LENGTH, + decodeFrameLength(frameBytes)); + frames.add(frameBytes); + } + } +} diff --git a/core/src/test/java/com/google/net/stubby/newtransport/MessageFramerTest.java b/core/src/test/java/com/google/net/stubby/newtransport/MessageFramerTest.java new file mode 100644 index 0000000000..11730a7c12 --- /dev/null +++ b/core/src/test/java/com/google/net/stubby/newtransport/MessageFramerTest.java @@ -0,0 +1,127 @@ +package com.google.net.stubby.newtransport; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.io.ByteBuffers; +import com.google.common.primitives.Bytes; +import com.google.net.stubby.GrpcFramingUtil; +import com.google.net.stubby.Status; +import com.google.net.stubby.transport.Transport; +import com.google.protobuf.ByteString; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayInputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; + +/** + * Tests for {@link MessageFramer} + */ +@RunWith(JUnit4.class) +public class MessageFramerTest { + + public static final int TRANSPORT_FRAME_SIZE = 57; + + @Test + public void testPayload() throws Exception { + CapturingSink sink = new CapturingSink(); + MessageFramer framer = new MessageFramer(sink, TRANSPORT_FRAME_SIZE); + byte[] payload = new byte[]{11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + byte[] unframedStream = + Bytes.concat( + new byte[]{GrpcFramingUtil.PAYLOAD_FRAME}, + new byte[]{0, 0, 0, (byte) payload.length}, + payload); + for (int i = 0; i < 1000; i++) { + framer.writePayload(new ByteArrayInputStream(payload), payload.length); + if ((i + 1) % 13 == 0) { + // Test flushing periodically + framer.flush(); + } + } + framer.flush(); + assertEquals(sink.deframedStream.length, unframedStream.length * 1000); + for (int i = 0; i < 1000; i++) { + assertArrayEquals(unframedStream, + Arrays.copyOfRange(sink.deframedStream, i * unframedStream.length, + (i + 1) * unframedStream.length)); + } + } + + @Test + public void testContext() throws Exception { + CapturingSink sink = new CapturingSink(); + MessageFramer framer = new MessageFramer(sink, TRANSPORT_FRAME_SIZE); + byte[] payload = new byte[]{11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + byte[] contextValue = Transport.ContextValue.newBuilder() + .setKey("somekey") + .setValue(ByteString.copyFrom(payload)) + .build().toByteArray(); + byte[] unframedStream = + Bytes.concat( + new byte[]{GrpcFramingUtil.CONTEXT_VALUE_FRAME}, + new byte[]{0, 0, + (byte) (contextValue.length >> 8 & 0xff), + (byte) (contextValue.length & 0xff)}, + contextValue); + for (int i = 0; i < 1000; i++) { + framer.writeContext("somekey", new ByteArrayInputStream(payload), payload.length); + if ((i + 1) % 13 == 0) { + framer.flush(); + } + } + framer.flush(); + assertEquals(unframedStream.length * 1000, sink.deframedStream.length); + for (int i = 0; i < 1000; i++) { + assertArrayEquals(unframedStream, + Arrays.copyOfRange(sink.deframedStream, i * unframedStream.length, + (i + 1) * unframedStream.length)); + } + } + + @Test + public void testStatus() throws Exception { + CapturingSink sink = new CapturingSink(); + MessageFramer framer = new MessageFramer(sink, TRANSPORT_FRAME_SIZE); + byte[] unframedStream = Bytes.concat( + new byte[]{GrpcFramingUtil.STATUS_FRAME}, + new byte[]{0, 0, 0, 2}, // Len is 2 bytes + new byte[]{0, 13}); // Internal==13 + for (int i = 0; i < 1000; i++) { + framer.writeStatus(new Status(Transport.Code.INTERNAL)); + if ((i + 1) % 13 == 0) { + framer.flush(); + } + } + framer.flush(); + assertEquals(sink.deframedStream.length, unframedStream.length * 1000); + for (int i = 0; i < 1000; i++) { + assertArrayEquals(unframedStream, + Arrays.copyOfRange(sink.deframedStream, i * unframedStream.length, + (i + 1) * unframedStream.length)); + } + } + + static class CapturingSink implements Framer.Sink { + + byte[] deframedStream = new byte[0]; + + @Override + public void deliverFrame(ByteBuffer frame, boolean endOfMessage) { + assertTrue(frame.remaining() <= TRANSPORT_FRAME_SIZE); + // Frame must contain compression flag & 24 bit length + int header = frame.getInt(); + byte flag = (byte) (header >>> 24); + int length = header & 0xFFFFFF; + assertTrue(TransportFrameUtil.isNotCompressed(flag)); + assertEquals(frame.remaining(), length); + // Frame must exceed dictated transport frame size + deframedStream = Bytes.concat(deframedStream, ByteBuffers.extractBytes(frame)); + } + } +}