alts: add ALTS to interop tests

This commit is contained in:
Jiangtao Li 2018-02-23 12:44:18 -08:00 committed by Carl Mastrangelo
parent 21541243f8
commit d8630f2521
16 changed files with 255 additions and 129 deletions

View File

@ -1,7 +1,7 @@
description = "gRPC: ALTS"
sourceCompatibility = 1.8
targetCompatibility = 1.8
sourceCompatibility = 1.7
targetCompatibility = 1.7
buildscript {
repositories {

View File

@ -150,7 +150,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
@Override
public TransportCreationParamsFilter create(
SocketAddress serverAddress,
final SocketAddress serverAddress,
final String authority,
final String userAgent,
final ProxyParameters proxy) {

View File

@ -53,7 +53,7 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
}
/** Creates a negotiator used for ALTS. */
public static AltsProtocolNegotiator create(TsiHandshakerFactory handshakerFactory) {
public static AltsProtocolNegotiator create(final TsiHandshakerFactory handshakerFactory) {
return new AltsProtocolNegotiator() {
@Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {

View File

@ -146,9 +146,9 @@ public final class InternalTsiFrameHandler extends ByteToMessageDecoder
}
@Override
public void flush(ChannelHandlerContext ctx) throws GeneralSecurityException {
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
ProtectedPromise aggregatePromise =
final ProtectedPromise aggregatePromise =
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
@ -168,7 +168,14 @@ public final class InternalTsiFrameHandler extends ByteToMessageDecoder
}
protector.protectFlush(
bufs, b -> ctx.writeAndFlush(b, aggregatePromise.newPromise()), ctx.alloc());
bufs,
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf b) {
ctx.writeAndFlush(b, aggregatePromise.newPromise());
}
},
ctx.alloc());
// We're done writing, start the flow of promise events.
@SuppressWarnings("unused") // go/futurereturn-lsc

View File

@ -32,6 +32,7 @@ import io.grpc.alts.transportsecurity.TsiFrameProtector;
import io.grpc.alts.transportsecurity.TsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
import io.grpc.alts.transportsecurity.TsiPeer;
import io.grpc.alts.transportsecurity.TsiPeer.Property;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
@ -89,7 +90,7 @@ public class AltsProtocolNegotiatorTest {
private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
private ChannelHandler handler;
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.emptyList());
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
private AltsAuthContext mockedAltsContext =
new AltsAuthContext(
HandshakerResult.newBuilder()
@ -220,10 +221,15 @@ public class AltsProtocolNegotiatorTest {
assertEquals(message, unprotectedData.toString(UTF_8));
// Protect the same message at the server.
AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
serverProtector.protectFlush(
Collections.singletonList(unprotectedData),
b -> newlyProtectedData.set(b),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
newlyProtectedData.set(buf);
}
},
channel.alloc());
// Read the protected message at the client and verify that it matches the original message.
@ -250,7 +256,14 @@ public class AltsProtocolNegotiatorTest {
TsiFrameProtector serverProtector =
serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
serverProtector.protectFlush(
Collections.singletonList(unprotectedData), b -> channel.writeInbound(b), channel.alloc());
Collections.singletonList(unprotectedData),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
channel.writeInbound(buf);
}
},
channel.alloc());
channel.flushInbound();
// Read the protected message at the client and verify that it matches the original message.

View File

@ -181,7 +181,16 @@ public class InternalNettyTsiHandshakerTest {
}
private void doHandshake() throws GeneralSecurityException {
doHandshake(clientHandshaker, serverHandshaker, alloc, buf -> ref(buf));
doHandshake(
clientHandshaker,
serverHandshaker,
alloc,
new java.util.function.Function<ByteBuf, ByteBuf>() {
@Override
public ByteBuf apply(ByteBuf buf) {
return ref(buf);
}
});
}
private ByteBuf ref(ByteBuf buf) {

View File

@ -23,6 +23,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.writeSlice;
import static org.junit.Assert.fail;
import com.google.common.testing.GcFinalization;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.ReferenceCounted;
@ -45,6 +46,16 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes();
private final List<ReferenceCounted> references = new ArrayList<ReferenceCounted>();
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
@Before
public void setUp() {
@ -68,7 +79,7 @@ public class AltsTsiFrameProtectorTest {
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), this::ref);
ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), ref);
in.writeIntLE(-1);
in.writeIntLE(6);
try {
@ -90,7 +101,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE - 1);
in.writeIntLE(6);
try {
@ -112,7 +123,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
- AltsTsiFrameProtector.getHeaderLenFieldBytes()
@ -137,7 +148,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(5);
try {
@ -159,7 +170,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(6);
@ -176,7 +187,7 @@ public class AltsTsiFrameProtectorTest {
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf emptyBuf = getDirectBuffer(0, this::ref);
ByteBuf emptyBuf = getDirectBuffer(0, ref);
unprotector.unprotect(emptyBuf, out, alloc);
assertThat(emptyBuf.refCnt()).isEqualTo(1);
@ -193,7 +204,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
- AltsTsiFrameProtector.getHeaderLenFieldBytes());
@ -214,7 +225,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(6);
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
@ -238,7 +249,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE - 1);
in.writeIntLE(6);
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
@ -267,13 +278,13 @@ public class AltsTsiFrameProtectorTest {
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf outFrame =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes()
+ payloadBytes
+ FakeChannelCrypter.getTagBytes(),
this::ref);
ref);
outFrame.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
@ -305,12 +316,12 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf outFrame =
getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes,
this::ref);
ref);
outFrame.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
@ -353,13 +364,13 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf protectedBuf =
getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes
+ AltsTsiFrameProtector.getHeaderBytes(),
this::ref);
ref);
protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
@ -421,13 +432,13 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf protectedBuf =
getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes
+ AltsTsiFrameProtector.getHeaderBytes(),
this::ref);
ref);
protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()

View File

@ -22,6 +22,7 @@ import com.google.common.testing.GcFinalization;
import io.grpc.alts.Handshaker.HandshakeProtocol;
import io.grpc.alts.Handshaker.HandshakerReq;
import io.grpc.alts.Handshaker.HandshakerResp;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
@ -45,6 +46,16 @@ public class AltsTsiTest {
private final List<ReferenceCounted> references = new ArrayList<>();
private AltsHandshakerClient client;
private AltsHandshakerClient server;
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
@Before
public void setUp() throws Exception {
@ -101,47 +112,47 @@ public class AltsTsiTest {
@Test
public void pingPong() throws GeneralSecurityException {
TsiTest.pingPongTest(newHandshakers(), this::ref);
TsiTest.pingPongTest(newHandshakers(), ref);
}
@Test
public void pingPongExactFrameSize() throws GeneralSecurityException {
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
}
@Test
public void pingPongSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
}
@Test
public void pingPongSmallFrame() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
}
@Test
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
}
@Test
public void corruptedCounter() throws GeneralSecurityException {
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
TsiTest.corruptedCounterTest(newHandshakers(), ref);
}
@Test
public void corruptedCiphertext() throws GeneralSecurityException {
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
}
@Test
public void corruptedTag() throws GeneralSecurityException {
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
TsiTest.corruptedTagTest(newHandshakers(), ref);
}
@Test
public void reflectedCiphertext() throws GeneralSecurityException {
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
}
private static class MockAltsHandshakerStub extends AltsHandshakerStub {
@ -184,11 +195,4 @@ public class AltsTsiTest {
@Override
public void close() {}
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
}

View File

@ -22,6 +22,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.fail;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCounted;
@ -39,6 +40,16 @@ public abstract class ChannelCrypterNettyTestBase {
protected final List<ReferenceCounted> references = new ArrayList<>();
public ChannelCrypterNetty client;
public ChannelCrypterNetty server;
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
static final class FrameEncrypt {
List<ByteBuf> plain;
@ -54,10 +65,10 @@ public abstract class ChannelCrypterNettyTestBase {
FrameEncrypt createFrameEncrypt(String message) {
byte[] messageBytes = message.getBytes(UTF_8);
FrameEncrypt frame = new FrameEncrypt();
ByteBuf plain = getDirectBuffer(messageBytes.length, this::ref);
ByteBuf plain = getDirectBuffer(messageBytes.length, ref);
plain.writeBytes(messageBytes);
frame.plain = Collections.singletonList(plain);
frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
return frame;
}
@ -68,7 +79,7 @@ public abstract class ChannelCrypterNettyTestBase {
frameDecrypt.ciphertext =
Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen));
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
return frameDecrypt;
}
@ -87,9 +98,9 @@ public abstract class ChannelCrypterNettyTestBase {
@Test
public void encryptDecryptLarge() throws GeneralSecurityException {
FrameEncrypt frameEncrypt = new FrameEncrypt();
ByteBuf plain = getRandom(17 * 1024, this::ref);
ByteBuf plain = getRandom(17 * 1024, ref);
frameEncrypt.plain = Collections.singletonList(plain);
frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), this::ref);
frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), ref);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
@ -120,13 +131,13 @@ public abstract class ChannelCrypterNettyTestBase {
int lastLen = 2;
byte[] messageBytes = message.getBytes(UTF_8);
FrameEncrypt frameEncrypt = new FrameEncrypt();
ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, this::ref);
ByteBuf plain2 = getDirectBuffer(lastLen, this::ref);
ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, ref);
ByteBuf plain2 = getDirectBuffer(lastLen, ref);
plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen);
plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen);
ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2);
frameEncrypt.plain = Collections.singletonList(plain);
frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
@ -134,14 +145,14 @@ public abstract class ChannelCrypterNettyTestBase {
FrameDecrypt frameDecrypt = new FrameDecrypt();
ByteBuf out = frameEncrypt.out;
int outLen = out.readableBytes();
ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, this::ref);
ByteBuf cipher2 = getDirectBuffer(lastLen, this::ref);
ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, ref);
ByteBuf cipher2 = getDirectBuffer(lastLen, ref);
cipher1.writeBytes(out, 0, outLen - lastLen - tagLen);
cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen);
ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2);
frameDecrypt.ciphertext = Collections.singletonList(cipher);
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
@ -212,11 +223,4 @@ public abstract class ChannelCrypterNettyTestBase {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
}
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
}

View File

@ -52,9 +52,10 @@ public final class FakeChannelCrypter implements ChannelCrypterNetty {
for (ByteBuf buf : ciphertext) {
out.writeBytes(buf);
}
boolean tagValid = tag.forEachByte((byte value) -> value == TAG_BYTE) == -1;
if (!tagValid) {
throw new AEADBadTagException("Tag mismatch!");
while (tag.isReadable()) {
if (tag.readByte() != TAG_BYTE) {
throw new AEADBadTagException("Tag mismatch!");
}
}
}

View File

@ -19,6 +19,7 @@ package io.grpc.alts.transportsecurity;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Preconditions;
import io.grpc.alts.transportsecurity.TsiPeer.Property;
import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
@ -203,7 +204,7 @@ public class FakeTsiHandshaker implements TsiHandshaker {
@Override
public TsiPeer extractPeer() {
return new TsiPeer(Collections.emptyList());
return new TsiPeer(Collections.<Property<?>>emptyList());
}
@Override

View File

@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import com.google.common.testing.GcFinalization;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
@ -44,6 +45,16 @@ public class FakeTsiTest {
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
private final List<ReferenceCounted> references = new ArrayList<>();
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
private static Handshakers newHandshakers() {
TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient();
@ -157,53 +168,46 @@ public class FakeTsiTest {
@Test
public void pingPong() throws GeneralSecurityException {
TsiTest.pingPongTest(newHandshakers(), this::ref);
TsiTest.pingPongTest(newHandshakers(), ref);
}
@Test
public void pingPongExactFrameSize() throws GeneralSecurityException {
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
}
@Test
public void pingPongSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
}
@Test
public void pingPongSmallFrame() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
}
@Test
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
}
@Test
public void corruptedCounter() throws GeneralSecurityException {
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
TsiTest.corruptedCounterTest(newHandshakers(), ref);
}
@Test
public void corruptedCiphertext() throws GeneralSecurityException {
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
}
@Test
public void corruptedTag() throws GeneralSecurityException {
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
TsiTest.corruptedTagTest(newHandshakers(), ref);
}
@Test
public void reflectedCiphertext() throws GeneralSecurityException {
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
}
}

View File

@ -123,10 +123,18 @@ public final class TsiTest {
throws GeneralSecurityException {
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
@ -249,10 +257,18 @@ public final class TsiTest {
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
@ -282,10 +298,18 @@ public final class TsiTest {
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
@ -313,10 +337,18 @@ public final class TsiTest {
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
@ -344,10 +376,18 @@ public final class TsiTest {
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));

View File

@ -14,7 +14,8 @@ buildscript {
}
dependencies {
compile project(':grpc-auth'),
compile project(':grpc-alts'),
project(':grpc-auth'),
project(':grpc-core'),
project(':grpc-netty'),
project(':grpc-okhttp'),

View File

@ -19,6 +19,7 @@ package io.grpc.testing.integration;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Files;
import io.grpc.ManagedChannel;
import io.grpc.alts.AltsChannelBuilder;
import io.grpc.internal.AbstractManagedChannelImplBuilder;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.TestUtils;
@ -76,6 +77,7 @@ public class TestServiceClient {
private int serverPort = 8080;
private String testCase = "empty_unary";
private boolean useTls = true;
private boolean useAlts = false;
private boolean useTestCa;
private boolean useOkHttp;
private String defaultServiceAccount;
@ -116,6 +118,8 @@ public class TestServiceClient {
testCase = value;
} else if ("use_tls".equals(key)) {
useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
} else if ("use_test_ca".equals(key)) {
useTestCa = Boolean.parseBoolean(value);
} else if ("use_okhttp".equals(key)) {
@ -140,6 +144,9 @@ public class TestServiceClient {
break;
}
}
if (useAlts) {
useTls = false;
}
if (usage) {
TestServiceClient c = new TestServiceClient();
System.out.println(
@ -153,6 +160,8 @@ public class TestServiceClient {
+ "\n Valid options:"
+ validTestCasesHelpText()
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + c.useTls
+ "\n --use_test_ca=true|false Whether to trust our fake CA. Requires --use_tls=true "
+ "\n to have effect. Default " + c.useTestCa
+ "\n --use_okhttp=true|false Whether to use OkHttp instead of Netty. Default "
@ -317,6 +326,9 @@ public class TestServiceClient {
private class Tester extends AbstractInteropTest {
@Override
protected ManagedChannel createChannel() {
if (useAlts) {
return AltsChannelBuilder.forAddress(serverHost, serverPort).build();
}
AbstractManagedChannelImplBuilder<?> builder;
if (!useOkHttp) {
SslContext sslContext = null;

View File

@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Server;
import io.grpc.ServerInterceptors;
import io.grpc.alts.AltsServerBuilder;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyServerBuilder;
@ -28,34 +29,32 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
/**
* Server that manages startup/shutdown of a single {@code TestService}.
*/
/** Server that manages startup/shutdown of a single {@code TestService}. */
public class TestServiceServer {
/**
* The main application allowing this server to be launched from the command line.
*/
/** The main application allowing this server to be launched from the command line. */
public static void main(String[] args) throws Exception {
final TestServiceServer server = new TestServiceServer();
server.parseArgs(args);
if (server.useTls) {
System.out.println(
"\nUsing fake CA for TLS certificate. Test clients should expect host\n"
+ "*.test.google.fr and our test CA. For the Java test client binary, use:\n"
+ "--server_host_override=foo.test.google.fr --use_test_ca=true\n");
+ "*.test.google.fr and our test CA. For the Java test client binary, use:\n"
+ "--server_host_override=foo.test.google.fr --use_test_ca=true\n");
}
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override
public void run() {
try {
System.out.println("Shutting down");
server.stop();
} catch (Exception e) {
e.printStackTrace();
}
}
});
Runtime.getRuntime()
.addShutdownHook(
new Thread() {
@Override
public void run() {
try {
System.out.println("Shutting down");
server.stop();
} catch (Exception e) {
e.printStackTrace();
}
}
});
server.start();
System.out.println("Server started on port " + server.port);
server.blockUntilShutdown();
@ -63,6 +62,7 @@ public class TestServiceServer {
private int port = 8080;
private boolean useTls = true;
private boolean useAlts = false;
private ScheduledExecutorService executor;
private Server server;
@ -92,6 +92,8 @@ public class TestServiceServer {
port = Integer.parseInt(value);
} else if ("use_tls".equals(key)) {
useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
} else if ("grpc_version".equals(key)) {
if (!"2".equals(value)) {
System.err.println("Only grpc version 2 is supported");
@ -104,13 +106,18 @@ public class TestServiceServer {
break;
}
}
if (useAlts) {
useTls = false;
}
if (usage) {
TestServiceServer s = new TestServiceServer();
System.out.println(
"Usage: [ARGS...]"
+ "\n"
+ "\n --port=PORT Port to connect to. Default " + s.port
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n"
+ "\n --port=PORT Port to connect to. Default " + s.port
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + s.useAlts
);
System.exit(1);
}
@ -120,17 +127,31 @@ public class TestServiceServer {
void start() throws Exception {
executor = Executors.newSingleThreadScheduledExecutor();
SslContext sslContext = null;
if (useTls) {
sslContext = GrpcSslContexts.forServer(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")).build();
if (useAlts) {
server =
AltsServerBuilder.forPort(port)
.addService(
ServerInterceptors.intercept(
new TestServiceImpl(executor), TestServiceImpl.interceptors()))
.build()
.start();
} else {
if (useTls) {
sslContext =
GrpcSslContexts.forServer(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
.build();
}
server =
NettyServerBuilder.forPort(port)
.sslContext(sslContext)
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.addService(
ServerInterceptors.intercept(
new TestServiceImpl(executor), TestServiceImpl.interceptors()))
.build()
.start();
}
server = NettyServerBuilder.forPort(port)
.sslContext(sslContext)
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.addService(ServerInterceptors.intercept(
new TestServiceImpl(executor),
TestServiceImpl.interceptors()))
.build().start();
}
@VisibleForTesting
@ -147,9 +168,7 @@ public class TestServiceServer {
return server.getPort();
}
/**
* Await termination on the main thread since the grpc library uses daemon threads.
*/
/** Await termination on the main thread since the grpc library uses daemon threads. */
private void blockUntilShutdown() throws InterruptedException {
if (server != null) {
server.awaitTermination();