diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerClient.java b/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerClient.java index e910396378..083ad05678 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerClient.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsHandshakerClient.java @@ -82,6 +82,7 @@ class AltsHandshakerClient { startClientReq.addTargetIdentitiesBuilder().setServiceAccount(serviceAccount); } } + startClientReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize()); req.setClientStart(startClientReq); } @@ -97,6 +98,7 @@ class AltsHandshakerClient { if (handshakerOptions.getRpcProtocolVersions() != null) { startServerReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions()); } + startServerReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize()); req.setServerStart(startServerReq); } diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsTsiFrameProtector.java b/alts/src/main/java/io/grpc/alts/internal/AltsTsiFrameProtector.java index 23e1dc9e5a..67d6637a13 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsTsiFrameProtector.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsTsiFrameProtector.java @@ -33,9 +33,10 @@ public final class AltsTsiFrameProtector implements TsiFrameProtector { private static final int HEADER_TYPE_FIELD_BYTES = 4; private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES; private static final int HEADER_TYPE_DEFAULT = 6; - // Total frame size including full header and tag. - private static final int MAX_ALLOWED_FRAME_BYTES = 16 * 1024; - private static final int LIMIT_MAX_ALLOWED_FRAME_BYTES = 1024 * 1024; + private static final int LIMIT_MAX_ALLOWED_FRAME_SIZE = 1024 * 1024; + // Frame size negotiation extends frame size range to [MIN_FRAME_SIZE, MAX_FRAME_SIZE]. + private static final int MIN_FRAME_SIZE = 16 * 1024; + private static final int MAX_FRAME_SIZE = 128 * 1024; private final Protector protector; private final Unprotector unprotector; @@ -44,7 +45,7 @@ public final class AltsTsiFrameProtector implements TsiFrameProtector { public AltsTsiFrameProtector( int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) { checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength()); - maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_BYTES, maxProtectedFrameBytes); + maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_SIZE, maxProtectedFrameBytes); protector = new Protector(maxProtectedFrameBytes, crypter); unprotector = new Unprotector(crypter, alloc); } @@ -65,12 +66,16 @@ public final class AltsTsiFrameProtector implements TsiFrameProtector { return HEADER_TYPE_DEFAULT; } - public static int getMaxAllowedFrameBytes() { - return MAX_ALLOWED_FRAME_BYTES; + static int getLimitMaxAllowedFrameSize() { + return LIMIT_MAX_ALLOWED_FRAME_SIZE; } - static int getLimitMaxAllowedFrameBytes() { - return LIMIT_MAX_ALLOWED_FRAME_BYTES; + public static int getMinFrameSize() { + return MIN_FRAME_SIZE; + } + + public static int getMaxFrameSize() { + return MAX_FRAME_SIZE; } @Override @@ -262,7 +267,7 @@ public final class AltsTsiFrameProtector implements TsiFrameProtector { checkArgument( requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small"); checkArgument( - requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_BYTES - HEADER_BYTES, + requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_SIZE - HEADER_BYTES, "Invalid header field: frame size too large"); int frameType = header.readIntLE(); checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type"); diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java index 3cd639ad5f..21824a068c 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsTsiHandshaker.java @@ -26,11 +26,15 @@ import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; /** * Negotiates a grpc channel key to be used by the TsiFrameProtector, using ALTs handshaker service. */ public final class AltsTsiHandshaker implements TsiHandshaker { + private static final Logger logger = Logger.getLogger(AltsTsiHandshaker.class.getName()); + public static final String TSI_SERVICE_ACCOUNT_PEER_PROPERTY = "service_account"; private final boolean isClient; @@ -178,6 +182,14 @@ public final class AltsTsiHandshaker implements TsiHandshaker { byte[] key = handshaker.getKey(); Preconditions.checkState(key.length == AltsChannelCrypter.getKeyLength(), "Bad key length."); + // Frame size negotiation is not performed if the peer does not send max frame size (e.g. peer + // is gRPC Go or peer uses an old binary). + int peerMaxFrameSize = handshaker.getResult().getMaxFrameSize(); + if (peerMaxFrameSize != 0) { + maxFrameSize = Math.min(peerMaxFrameSize, AltsTsiFrameProtector.getMaxFrameSize()); + maxFrameSize = Math.max(AltsTsiFrameProtector.getMinFrameSize(), maxFrameSize); + } + logger.log(Level.INFO, "Maximum frame size value is " + maxFrameSize); return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc); } @@ -190,7 +202,7 @@ public final class AltsTsiHandshaker implements TsiHandshaker { */ @Override public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) { - return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc); + return createFrameProtector(AltsTsiFrameProtector.getMinFrameSize(), alloc); } @Override diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerClientTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerClientTest.java index 82c8a682bc..d1bd5ffea0 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerClientTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsHandshakerClientTest.java @@ -106,6 +106,7 @@ public class AltsHandshakerClientTest { .setTargetName(TEST_TARGET_NAME) .addTargetIdentities( Identity.newBuilder().setServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)) + .setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize()) .build()) .build(); verify(mockStub).send(req); @@ -133,6 +134,22 @@ public class AltsHandshakerClientTest { ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); ByteBuffer outFrame = handshaker.startServerHandshake(inBytes); + HandshakerReq req = + HandshakerReq.newBuilder() + .setServerStart( + StartServerHandshakeReq.newBuilder() + .addApplicationProtocols(AltsHandshakerClient.getApplicationProtocol()) + .putHandshakeParameters( + HandshakeProtocol.ALTS.getNumber(), + ServerHandshakeParameters.newBuilder() + .addRecordProtocols(AltsHandshakerClient.getRecordProtocol()) + .build()) + .setInBytes(ByteString.copyFrom(ByteBuffer.allocate(IN_BYTES_SIZE))) + .setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize()) + .build()) + .build(); + verify(mockStub).send(req); + assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame()); assertFalse(handshaker.isFinished()); assertNull(handshaker.getResult()); diff --git a/alts/src/test/java/io/grpc/alts/internal/AltsTsiFrameProtectorTest.java b/alts/src/test/java/io/grpc/alts/internal/AltsTsiFrameProtectorTest.java index 4405a99eff..cbda9dbcdc 100644 --- a/alts/src/test/java/io/grpc/alts/internal/AltsTsiFrameProtectorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/AltsTsiFrameProtectorTest.java @@ -125,7 +125,7 @@ public class AltsTsiFrameProtectorTest { getDirectBuffer( AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE( - AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes() + AltsTsiFrameProtector.getLimitMaxAllowedFrameSize() - AltsTsiFrameProtector.getHeaderLenFieldBytes() + 1); in.writeIntLE(6); @@ -206,7 +206,7 @@ public class AltsTsiFrameProtectorTest { getDirectBuffer( AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE( - AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes() + AltsTsiFrameProtector.getLimitMaxAllowedFrameSize() - AltsTsiFrameProtector.getHeaderLenFieldBytes()); in.writeIntLE(6); diff --git a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java index a04bbfd07e..1483e4b08e 100644 --- a/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java +++ b/alts/src/test/java/io/grpc/alts/internal/FakeTsiHandshaker.java @@ -224,7 +224,7 @@ public class FakeTsiHandshaker implements TsiHandshaker { @Override public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) { - return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc); + return createFrameProtector(AltsTsiFrameProtector.getMinFrameSize(), alloc); } @Override