diff --git a/alts/build.gradle b/alts/build.gradle index 6656390b99..d903a4b24c 100644 --- a/alts/build.gradle +++ b/alts/build.gradle @@ -1,7 +1,7 @@ description = "gRPC: ALTS" -sourceCompatibility = 1.8 -targetCompatibility = 1.8 +sourceCompatibility = 1.7 +targetCompatibility = 1.7 buildscript { repositories { diff --git a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java index 72a9668ccc..0ce15ac981 100644 --- a/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java +++ b/alts/src/main/java/io/grpc/alts/AltsChannelBuilder.java @@ -150,7 +150,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder bufs = new ArrayList<>(pendingUnprotectedWrites.size()); @@ -168,7 +168,14 @@ public final class InternalTsiFrameHandler extends ByteToMessageDecoder } protector.protectFlush( - bufs, b -> ctx.writeAndFlush(b, aggregatePromise.newPromise()), ctx.alloc()); + bufs, + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf b) { + ctx.writeAndFlush(b, aggregatePromise.newPromise()); + } + }, + ctx.alloc()); // We're done writing, start the flow of promise events. @SuppressWarnings("unused") // go/futurereturn-lsc diff --git a/alts/src/test/java/io/grpc/alts/AltsProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/AltsProtocolNegotiatorTest.java index 9f9e473c4b..e7a2afff48 100644 --- a/alts/src/test/java/io/grpc/alts/AltsProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/AltsProtocolNegotiatorTest.java @@ -32,6 +32,7 @@ import io.grpc.alts.transportsecurity.TsiFrameProtector; import io.grpc.alts.transportsecurity.TsiHandshaker; import io.grpc.alts.transportsecurity.TsiHandshakerFactory; import io.grpc.alts.transportsecurity.TsiPeer; +import io.grpc.alts.transportsecurity.TsiPeer.Property; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -89,7 +90,7 @@ public class AltsProtocolNegotiatorTest { private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent; private ChannelHandler handler; - private TsiPeer mockedTsiPeer = new TsiPeer(Collections.emptyList()); + private TsiPeer mockedTsiPeer = new TsiPeer(Collections.>emptyList()); private AltsAuthContext mockedAltsContext = new AltsAuthContext( HandshakerResult.newBuilder() @@ -220,10 +221,15 @@ public class AltsProtocolNegotiatorTest { assertEquals(message, unprotectedData.toString(UTF_8)); // Protect the same message at the server. - AtomicReference newlyProtectedData = new AtomicReference<>(); + final AtomicReference newlyProtectedData = new AtomicReference<>(); serverProtector.protectFlush( Collections.singletonList(unprotectedData), - b -> newlyProtectedData.set(b), + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf buf) { + newlyProtectedData.set(buf); + } + }, channel.alloc()); // Read the protected message at the client and verify that it matches the original message. @@ -250,7 +256,14 @@ public class AltsProtocolNegotiatorTest { TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc()); serverProtector.protectFlush( - Collections.singletonList(unprotectedData), b -> channel.writeInbound(b), channel.alloc()); + Collections.singletonList(unprotectedData), + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf buf) { + channel.writeInbound(buf); + } + }, + channel.alloc()); channel.flushInbound(); // Read the protected message at the client and verify that it matches the original message. diff --git a/alts/src/test/java/io/grpc/alts/InternalNettyTsiHandshakerTest.java b/alts/src/test/java/io/grpc/alts/InternalNettyTsiHandshakerTest.java index 51e235554a..ad437b0d2c 100644 --- a/alts/src/test/java/io/grpc/alts/InternalNettyTsiHandshakerTest.java +++ b/alts/src/test/java/io/grpc/alts/InternalNettyTsiHandshakerTest.java @@ -181,7 +181,16 @@ public class InternalNettyTsiHandshakerTest { } private void doHandshake() throws GeneralSecurityException { - doHandshake(clientHandshaker, serverHandshaker, alloc, buf -> ref(buf)); + doHandshake( + clientHandshaker, + serverHandshaker, + alloc, + new java.util.function.Function() { + @Override + public ByteBuf apply(ByteBuf buf) { + return ref(buf); + } + }); } private ByteBuf ref(ByteBuf buf) { diff --git a/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiFrameProtectorTest.java b/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiFrameProtectorTest.java index b37870587b..7f98299221 100644 --- a/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiFrameProtectorTest.java +++ b/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiFrameProtectorTest.java @@ -23,6 +23,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.writeSlice; import static org.junit.Assert.fail; import com.google.common.testing.GcFinalization; +import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.ReferenceCounted; @@ -45,6 +46,16 @@ public class AltsTsiFrameProtectorTest { AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes(); private final List references = new ArrayList(); + private final RegisterRef ref = + new RegisterRef() { + @Override + public ByteBuf register(ByteBuf buf) { + if (buf != null) { + references.add(buf); + } + return buf; + } + }; @Before public void setUp() { @@ -68,7 +79,7 @@ public class AltsTsiFrameProtectorTest { FakeChannelCrypter crypter = new FakeChannelCrypter(); AltsTsiFrameProtector.Unprotector unprotector = new AltsTsiFrameProtector.Unprotector(crypter, alloc); - ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), this::ref); + ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), ref); in.writeIntLE(-1); in.writeIntLE(6); try { @@ -90,7 +101,7 @@ public class AltsTsiFrameProtectorTest { new AltsTsiFrameProtector.Unprotector(crypter, alloc); ByteBuf in = getDirectBuffer( - AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); + AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE(FRAME_MIN_SIZE - 1); in.writeIntLE(6); try { @@ -112,7 +123,7 @@ public class AltsTsiFrameProtectorTest { new AltsTsiFrameProtector.Unprotector(crypter, alloc); ByteBuf in = getDirectBuffer( - AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); + AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE( AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes() - AltsTsiFrameProtector.getHeaderLenFieldBytes() @@ -137,7 +148,7 @@ public class AltsTsiFrameProtectorTest { new AltsTsiFrameProtector.Unprotector(crypter, alloc); ByteBuf in = getDirectBuffer( - AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); + AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE(FRAME_MIN_SIZE); in.writeIntLE(5); try { @@ -159,7 +170,7 @@ public class AltsTsiFrameProtectorTest { new AltsTsiFrameProtector.Unprotector(crypter, alloc); ByteBuf in = getDirectBuffer( - AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); + AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE(FRAME_MIN_SIZE); in.writeIntLE(6); @@ -176,7 +187,7 @@ public class AltsTsiFrameProtectorTest { FakeChannelCrypter crypter = new FakeChannelCrypter(); AltsTsiFrameProtector.Unprotector unprotector = new AltsTsiFrameProtector.Unprotector(crypter, alloc); - ByteBuf emptyBuf = getDirectBuffer(0, this::ref); + ByteBuf emptyBuf = getDirectBuffer(0, ref); unprotector.unprotect(emptyBuf, out, alloc); assertThat(emptyBuf.refCnt()).isEqualTo(1); @@ -193,7 +204,7 @@ public class AltsTsiFrameProtectorTest { new AltsTsiFrameProtector.Unprotector(crypter, alloc); ByteBuf in = getDirectBuffer( - AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); + AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE( AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes() - AltsTsiFrameProtector.getHeaderLenFieldBytes()); @@ -214,7 +225,7 @@ public class AltsTsiFrameProtectorTest { new AltsTsiFrameProtector.Unprotector(crypter, alloc); ByteBuf in = getDirectBuffer( - AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); + AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE(FRAME_MIN_SIZE); in.writeIntLE(6); ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1); @@ -238,7 +249,7 @@ public class AltsTsiFrameProtectorTest { new AltsTsiFrameProtector.Unprotector(crypter, alloc); ByteBuf in = getDirectBuffer( - AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); + AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref); in.writeIntLE(FRAME_MIN_SIZE - 1); in.writeIntLE(6); ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1); @@ -267,13 +278,13 @@ public class AltsTsiFrameProtectorTest { FakeChannelCrypter crypter = new FakeChannelCrypter(); AltsTsiFrameProtector.Unprotector unprotector = new AltsTsiFrameProtector.Unprotector(crypter, alloc); - ByteBuf plain = getRandom(payloadBytes, this::ref); + ByteBuf plain = getRandom(payloadBytes, ref); ByteBuf outFrame = getDirectBuffer( AltsTsiFrameProtector.getHeaderBytes() + payloadBytes + FakeChannelCrypter.getTagBytes(), - this::ref); + ref); outFrame.writeIntLE( AltsTsiFrameProtector.getHeaderTypeFieldBytes() @@ -305,12 +316,12 @@ public class AltsTsiFrameProtectorTest { AltsTsiFrameProtector.Unprotector unprotector = new AltsTsiFrameProtector.Unprotector(crypter, alloc); - ByteBuf plain = getRandom(payloadBytes, this::ref); + ByteBuf plain = getRandom(payloadBytes, ref); ByteBuf outFrame = getDirectBuffer( 2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes()) + payloadBytes, - this::ref); + ref); outFrame.writeIntLE( AltsTsiFrameProtector.getHeaderTypeFieldBytes() @@ -353,13 +364,13 @@ public class AltsTsiFrameProtectorTest { AltsTsiFrameProtector.Unprotector unprotector = new AltsTsiFrameProtector.Unprotector(crypter, alloc); - ByteBuf plain = getRandom(payloadBytes, this::ref); + ByteBuf plain = getRandom(payloadBytes, ref); ByteBuf protectedBuf = getDirectBuffer( 2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes()) + payloadBytes + AltsTsiFrameProtector.getHeaderBytes(), - this::ref); + ref); protectedBuf.writeIntLE( AltsTsiFrameProtector.getHeaderTypeFieldBytes() @@ -421,13 +432,13 @@ public class AltsTsiFrameProtectorTest { AltsTsiFrameProtector.Unprotector unprotector = new AltsTsiFrameProtector.Unprotector(crypter, alloc); - ByteBuf plain = getRandom(payloadBytes, this::ref); + ByteBuf plain = getRandom(payloadBytes, ref); ByteBuf protectedBuf = getDirectBuffer( 2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes()) + payloadBytes + AltsTsiFrameProtector.getHeaderBytes(), - this::ref); + ref); protectedBuf.writeIntLE( AltsTsiFrameProtector.getHeaderTypeFieldBytes() diff --git a/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiTest.java b/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiTest.java index cee22863fe..2ed8e155e5 100644 --- a/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiTest.java +++ b/alts/src/test/java/io/grpc/alts/transportsecurity/AltsTsiTest.java @@ -22,6 +22,7 @@ import com.google.common.testing.GcFinalization; import io.grpc.alts.Handshaker.HandshakeProtocol; import io.grpc.alts.Handshaker.HandshakerReq; import io.grpc.alts.Handshaker.HandshakerResp; +import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef; import io.grpc.alts.transportsecurity.TsiTest.Handshakers; import io.netty.buffer.ByteBuf; import io.netty.util.ReferenceCounted; @@ -45,6 +46,16 @@ public class AltsTsiTest { private final List references = new ArrayList<>(); private AltsHandshakerClient client; private AltsHandshakerClient server; + private final RegisterRef ref = + new RegisterRef() { + @Override + public ByteBuf register(ByteBuf buf) { + if (buf != null) { + references.add(buf); + } + return buf; + } + }; @Before public void setUp() throws Exception { @@ -101,47 +112,47 @@ public class AltsTsiTest { @Test public void pingPong() throws GeneralSecurityException { - TsiTest.pingPongTest(newHandshakers(), this::ref); + TsiTest.pingPongTest(newHandshakers(), ref); } @Test public void pingPongExactFrameSize() throws GeneralSecurityException { - TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref); + TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref); } @Test public void pingPongSmallBuffer() throws GeneralSecurityException { - TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref); + TsiTest.pingPongSmallBufferTest(newHandshakers(), ref); } @Test public void pingPongSmallFrame() throws GeneralSecurityException { - TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref); + TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref); } @Test public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException { - TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref); + TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref); } @Test public void corruptedCounter() throws GeneralSecurityException { - TsiTest.corruptedCounterTest(newHandshakers(), this::ref); + TsiTest.corruptedCounterTest(newHandshakers(), ref); } @Test public void corruptedCiphertext() throws GeneralSecurityException { - TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref); + TsiTest.corruptedCiphertextTest(newHandshakers(), ref); } @Test public void corruptedTag() throws GeneralSecurityException { - TsiTest.corruptedTagTest(newHandshakers(), this::ref); + TsiTest.corruptedTagTest(newHandshakers(), ref); } @Test public void reflectedCiphertext() throws GeneralSecurityException { - TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref); + TsiTest.reflectedCiphertextTest(newHandshakers(), ref); } private static class MockAltsHandshakerStub extends AltsHandshakerStub { @@ -184,11 +195,4 @@ public class AltsTsiTest { @Override public void close() {} } - - private ByteBuf ref(ByteBuf buf) { - if (buf != null) { - references.add(buf); - } - return buf; - } } diff --git a/alts/src/test/java/io/grpc/alts/transportsecurity/ChannelCrypterNettyTestBase.java b/alts/src/test/java/io/grpc/alts/transportsecurity/ChannelCrypterNettyTestBase.java index 8ca0f77d9c..25c1c6c05e 100644 --- a/alts/src/test/java/io/grpc/alts/transportsecurity/ChannelCrypterNettyTestBase.java +++ b/alts/src/test/java/io/grpc/alts/transportsecurity/ChannelCrypterNettyTestBase.java @@ -22,6 +22,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.fail; +import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCounted; @@ -39,6 +40,16 @@ public abstract class ChannelCrypterNettyTestBase { protected final List references = new ArrayList<>(); public ChannelCrypterNetty client; public ChannelCrypterNetty server; + private final RegisterRef ref = + new RegisterRef() { + @Override + public ByteBuf register(ByteBuf buf) { + if (buf != null) { + references.add(buf); + } + return buf; + } + }; static final class FrameEncrypt { List plain; @@ -54,10 +65,10 @@ public abstract class ChannelCrypterNettyTestBase { FrameEncrypt createFrameEncrypt(String message) { byte[] messageBytes = message.getBytes(UTF_8); FrameEncrypt frame = new FrameEncrypt(); - ByteBuf plain = getDirectBuffer(messageBytes.length, this::ref); + ByteBuf plain = getDirectBuffer(messageBytes.length, ref); plain.writeBytes(messageBytes); frame.plain = Collections.singletonList(plain); - frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref); + frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref); return frame; } @@ -68,7 +79,7 @@ public abstract class ChannelCrypterNettyTestBase { frameDecrypt.ciphertext = Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen)); frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen); - frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref); + frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref); return frameDecrypt; } @@ -87,9 +98,9 @@ public abstract class ChannelCrypterNettyTestBase { @Test public void encryptDecryptLarge() throws GeneralSecurityException { FrameEncrypt frameEncrypt = new FrameEncrypt(); - ByteBuf plain = getRandom(17 * 1024, this::ref); + ByteBuf plain = getRandom(17 * 1024, ref); frameEncrypt.plain = Collections.singletonList(plain); - frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), this::ref); + frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), ref); client.encrypt(frameEncrypt.out, frameEncrypt.plain); FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt); @@ -120,13 +131,13 @@ public abstract class ChannelCrypterNettyTestBase { int lastLen = 2; byte[] messageBytes = message.getBytes(UTF_8); FrameEncrypt frameEncrypt = new FrameEncrypt(); - ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, this::ref); - ByteBuf plain2 = getDirectBuffer(lastLen, this::ref); + ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, ref); + ByteBuf plain2 = getDirectBuffer(lastLen, ref); plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen); plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen); ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2); frameEncrypt.plain = Collections.singletonList(plain); - frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref); + frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref); client.encrypt(frameEncrypt.out, frameEncrypt.plain); @@ -134,14 +145,14 @@ public abstract class ChannelCrypterNettyTestBase { FrameDecrypt frameDecrypt = new FrameDecrypt(); ByteBuf out = frameEncrypt.out; int outLen = out.readableBytes(); - ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, this::ref); - ByteBuf cipher2 = getDirectBuffer(lastLen, this::ref); + ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, ref); + ByteBuf cipher2 = getDirectBuffer(lastLen, ref); cipher1.writeBytes(out, 0, outLen - lastLen - tagLen); cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen); ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2); frameDecrypt.ciphertext = Collections.singletonList(cipher); frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen); - frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref); + frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref); server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes())) @@ -212,11 +223,4 @@ public abstract class ChannelCrypterNettyTestBase { assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE); } } - - private ByteBuf ref(ByteBuf buf) { - if (buf != null) { - references.add(buf); - } - return buf; - } } diff --git a/alts/src/test/java/io/grpc/alts/transportsecurity/FakeChannelCrypter.java b/alts/src/test/java/io/grpc/alts/transportsecurity/FakeChannelCrypter.java index 7f18ed46bf..229146cec6 100644 --- a/alts/src/test/java/io/grpc/alts/transportsecurity/FakeChannelCrypter.java +++ b/alts/src/test/java/io/grpc/alts/transportsecurity/FakeChannelCrypter.java @@ -52,9 +52,10 @@ public final class FakeChannelCrypter implements ChannelCrypterNetty { for (ByteBuf buf : ciphertext) { out.writeBytes(buf); } - boolean tagValid = tag.forEachByte((byte value) -> value == TAG_BYTE) == -1; - if (!tagValid) { - throw new AEADBadTagException("Tag mismatch!"); + while (tag.isReadable()) { + if (tag.readByte() != TAG_BYTE) { + throw new AEADBadTagException("Tag mismatch!"); + } } } diff --git a/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiHandshaker.java b/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiHandshaker.java index 5e1ec1cc6d..490481103b 100644 --- a/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiHandshaker.java +++ b/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiHandshaker.java @@ -19,6 +19,7 @@ package io.grpc.alts.transportsecurity; import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Preconditions; +import io.grpc.alts.transportsecurity.TsiPeer.Property; import io.netty.buffer.ByteBufAllocator; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; @@ -203,7 +204,7 @@ public class FakeTsiHandshaker implements TsiHandshaker { @Override public TsiPeer extractPeer() { - return new TsiPeer(Collections.emptyList()); + return new TsiPeer(Collections.>emptyList()); } @Override diff --git a/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiTest.java b/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiTest.java index 3e75ca2171..5a10182082 100644 --- a/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiTest.java +++ b/alts/src/test/java/io/grpc/alts/transportsecurity/FakeTsiTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import com.google.common.testing.GcFinalization; +import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef; import io.grpc.alts.transportsecurity.TsiTest.Handshakers; import io.netty.buffer.ByteBuf; import io.netty.util.ReferenceCounted; @@ -44,6 +45,16 @@ public class FakeTsiTest { FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes(); private final List references = new ArrayList<>(); + private final RegisterRef ref = + new RegisterRef() { + @Override + public ByteBuf register(ByteBuf buf) { + if (buf != null) { + references.add(buf); + } + return buf; + } + }; private static Handshakers newHandshakers() { TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient(); @@ -157,53 +168,46 @@ public class FakeTsiTest { @Test public void pingPong() throws GeneralSecurityException { - TsiTest.pingPongTest(newHandshakers(), this::ref); + TsiTest.pingPongTest(newHandshakers(), ref); } @Test public void pingPongExactFrameSize() throws GeneralSecurityException { - TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref); + TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref); } @Test public void pingPongSmallBuffer() throws GeneralSecurityException { - TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref); + TsiTest.pingPongSmallBufferTest(newHandshakers(), ref); } @Test public void pingPongSmallFrame() throws GeneralSecurityException { - TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref); + TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref); } @Test public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException { - TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref); + TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref); } @Test public void corruptedCounter() throws GeneralSecurityException { - TsiTest.corruptedCounterTest(newHandshakers(), this::ref); + TsiTest.corruptedCounterTest(newHandshakers(), ref); } @Test public void corruptedCiphertext() throws GeneralSecurityException { - TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref); + TsiTest.corruptedCiphertextTest(newHandshakers(), ref); } @Test public void corruptedTag() throws GeneralSecurityException { - TsiTest.corruptedTagTest(newHandshakers(), this::ref); + TsiTest.corruptedTagTest(newHandshakers(), ref); } @Test public void reflectedCiphertext() throws GeneralSecurityException { - TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref); - } - - private ByteBuf ref(ByteBuf buf) { - if (buf != null) { - references.add(buf); - } - return buf; + TsiTest.reflectedCiphertextTest(newHandshakers(), ref); } } diff --git a/alts/src/test/java/io/grpc/alts/transportsecurity/TsiTest.java b/alts/src/test/java/io/grpc/alts/transportsecurity/TsiTest.java index 01133f931e..0cfbd72473 100644 --- a/alts/src/test/java/io/grpc/alts/transportsecurity/TsiTest.java +++ b/alts/src/test/java/io/grpc/alts/transportsecurity/TsiTest.java @@ -123,10 +123,18 @@ public final class TsiTest { throws GeneralSecurityException { ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); - List protectOut = new ArrayList<>(); + final List protectOut = new ArrayList<>(); List unprotectOut = new ArrayList<>(); - sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); + sender.protectFlush( + Collections.singletonList(plaintextBuffer), + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf buf) { + protectOut.add(buf); + } + }, + alloc); assertThat(protectOut.size()).isEqualTo(1); ByteBuf protect = ref.register(protectOut.get(0)); @@ -249,10 +257,18 @@ public final class TsiTest { String message = "hello world"; ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); - List protectOut = new ArrayList<>(); + final List protectOut = new ArrayList<>(); List unprotectOut = new ArrayList<>(); - sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); + sender.protectFlush( + Collections.singletonList(plaintextBuffer), + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf buf) { + protectOut.add(buf); + } + }, + alloc); assertThat(protectOut.size()).isEqualTo(1); ByteBuf protect = ref.register(protectOut.get(0)); @@ -282,10 +298,18 @@ public final class TsiTest { String message = "hello world"; ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); - List protectOut = new ArrayList<>(); + final List protectOut = new ArrayList<>(); List unprotectOut = new ArrayList<>(); - sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); + sender.protectFlush( + Collections.singletonList(plaintextBuffer), + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf buf) { + protectOut.add(buf); + } + }, + alloc); assertThat(protectOut.size()).isEqualTo(1); ByteBuf protect = ref.register(protectOut.get(0)); @@ -313,10 +337,18 @@ public final class TsiTest { String message = "hello world"; ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); - List protectOut = new ArrayList<>(); + final List protectOut = new ArrayList<>(); List unprotectOut = new ArrayList<>(); - sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); + sender.protectFlush( + Collections.singletonList(plaintextBuffer), + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf buf) { + protectOut.add(buf); + } + }, + alloc); assertThat(protectOut.size()).isEqualTo(1); ByteBuf protect = ref.register(protectOut.get(0)); @@ -344,10 +376,18 @@ public final class TsiTest { String message = "hello world"; ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); - List protectOut = new ArrayList<>(); + final List protectOut = new ArrayList<>(); List unprotectOut = new ArrayList<>(); - sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); + sender.protectFlush( + Collections.singletonList(plaintextBuffer), + new java.util.function.Consumer() { + @Override + public void accept(ByteBuf buf) { + protectOut.add(buf); + } + }, + alloc); assertThat(protectOut.size()).isEqualTo(1); ByteBuf protect = ref.register(protectOut.get(0)); diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index babe4344d5..3f87b76f41 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -14,7 +14,8 @@ buildscript { } dependencies { - compile project(':grpc-auth'), + compile project(':grpc-alts'), + project(':grpc-auth'), project(':grpc-core'), project(':grpc-netty'), project(':grpc-okhttp'), diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index bc21fc3205..1a5d4ad797 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -19,6 +19,7 @@ package io.grpc.testing.integration; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Files; import io.grpc.ManagedChannel; +import io.grpc.alts.AltsChannelBuilder; import io.grpc.internal.AbstractManagedChannelImplBuilder; import io.grpc.internal.GrpcUtil; import io.grpc.internal.testing.TestUtils; @@ -76,6 +77,7 @@ public class TestServiceClient { private int serverPort = 8080; private String testCase = "empty_unary"; private boolean useTls = true; + private boolean useAlts = false; private boolean useTestCa; private boolean useOkHttp; private String defaultServiceAccount; @@ -116,6 +118,8 @@ public class TestServiceClient { testCase = value; } else if ("use_tls".equals(key)) { useTls = Boolean.parseBoolean(value); + } else if ("use_alts".equals(key)) { + useAlts = Boolean.parseBoolean(value); } else if ("use_test_ca".equals(key)) { useTestCa = Boolean.parseBoolean(value); } else if ("use_okhttp".equals(key)) { @@ -140,6 +144,9 @@ public class TestServiceClient { break; } } + if (useAlts) { + useTls = false; + } if (usage) { TestServiceClient c = new TestServiceClient(); System.out.println( @@ -153,6 +160,8 @@ public class TestServiceClient { + "\n Valid options:" + validTestCasesHelpText() + "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls + + "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS." + + "\n Default " + c.useTls + "\n --use_test_ca=true|false Whether to trust our fake CA. Requires --use_tls=true " + "\n to have effect. Default " + c.useTestCa + "\n --use_okhttp=true|false Whether to use OkHttp instead of Netty. Default " @@ -317,6 +326,9 @@ public class TestServiceClient { private class Tester extends AbstractInteropTest { @Override protected ManagedChannel createChannel() { + if (useAlts) { + return AltsChannelBuilder.forAddress(serverHost, serverPort).build(); + } AbstractManagedChannelImplBuilder builder; if (!useOkHttp) { SslContext sslContext = null; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java index 48dd37dbc9..bc683837b2 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java @@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.MoreExecutors; import io.grpc.Server; import io.grpc.ServerInterceptors; +import io.grpc.alts.AltsServerBuilder; import io.grpc.internal.testing.TestUtils; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyServerBuilder; @@ -28,34 +29,32 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -/** - * Server that manages startup/shutdown of a single {@code TestService}. - */ +/** Server that manages startup/shutdown of a single {@code TestService}. */ public class TestServiceServer { - /** - * The main application allowing this server to be launched from the command line. - */ + /** The main application allowing this server to be launched from the command line. */ public static void main(String[] args) throws Exception { final TestServiceServer server = new TestServiceServer(); server.parseArgs(args); if (server.useTls) { System.out.println( "\nUsing fake CA for TLS certificate. Test clients should expect host\n" - + "*.test.google.fr and our test CA. For the Java test client binary, use:\n" - + "--server_host_override=foo.test.google.fr --use_test_ca=true\n"); + + "*.test.google.fr and our test CA. For the Java test client binary, use:\n" + + "--server_host_override=foo.test.google.fr --use_test_ca=true\n"); } - Runtime.getRuntime().addShutdownHook(new Thread() { - @Override - public void run() { - try { - System.out.println("Shutting down"); - server.stop(); - } catch (Exception e) { - e.printStackTrace(); - } - } - }); + Runtime.getRuntime() + .addShutdownHook( + new Thread() { + @Override + public void run() { + try { + System.out.println("Shutting down"); + server.stop(); + } catch (Exception e) { + e.printStackTrace(); + } + } + }); server.start(); System.out.println("Server started on port " + server.port); server.blockUntilShutdown(); @@ -63,6 +62,7 @@ public class TestServiceServer { private int port = 8080; private boolean useTls = true; + private boolean useAlts = false; private ScheduledExecutorService executor; private Server server; @@ -92,6 +92,8 @@ public class TestServiceServer { port = Integer.parseInt(value); } else if ("use_tls".equals(key)) { useTls = Boolean.parseBoolean(value); + } else if ("use_alts".equals(key)) { + useAlts = Boolean.parseBoolean(value); } else if ("grpc_version".equals(key)) { if (!"2".equals(value)) { System.err.println("Only grpc version 2 is supported"); @@ -104,13 +106,18 @@ public class TestServiceServer { break; } } + if (useAlts) { + useTls = false; + } if (usage) { TestServiceServer s = new TestServiceServer(); System.out.println( "Usage: [ARGS...]" - + "\n" - + "\n --port=PORT Port to connect to. Default " + s.port - + "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls + + "\n" + + "\n --port=PORT Port to connect to. Default " + s.port + + "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls + + "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS." + + "\n Default " + s.useAlts ); System.exit(1); } @@ -120,17 +127,31 @@ public class TestServiceServer { void start() throws Exception { executor = Executors.newSingleThreadScheduledExecutor(); SslContext sslContext = null; - if (useTls) { - sslContext = GrpcSslContexts.forServer( - TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")).build(); + if (useAlts) { + server = + AltsServerBuilder.forPort(port) + .addService( + ServerInterceptors.intercept( + new TestServiceImpl(executor), TestServiceImpl.interceptors())) + .build() + .start(); + } else { + if (useTls) { + sslContext = + GrpcSslContexts.forServer( + TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")) + .build(); + } + server = + NettyServerBuilder.forPort(port) + .sslContext(sslContext) + .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) + .addService( + ServerInterceptors.intercept( + new TestServiceImpl(executor), TestServiceImpl.interceptors())) + .build() + .start(); } - server = NettyServerBuilder.forPort(port) - .sslContext(sslContext) - .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) - .addService(ServerInterceptors.intercept( - new TestServiceImpl(executor), - TestServiceImpl.interceptors())) - .build().start(); } @VisibleForTesting @@ -147,9 +168,7 @@ public class TestServiceServer { return server.getPort(); } - /** - * Await termination on the main thread since the grpc library uses daemon threads. - */ + /** Await termination on the main thread since the grpc library uses daemon threads. */ private void blockUntilShutdown() throws InterruptedException { if (server != null) { server.awaitTermination();