alts: add ALTS to interop tests

This commit is contained in:
Jiangtao Li 2018-02-23 12:44:18 -08:00 committed by Carl Mastrangelo
parent 21541243f8
commit d8630f2521
16 changed files with 255 additions and 129 deletions

View File

@ -1,7 +1,7 @@
description = "gRPC: ALTS" description = "gRPC: ALTS"
sourceCompatibility = 1.8 sourceCompatibility = 1.7
targetCompatibility = 1.8 targetCompatibility = 1.7
buildscript { buildscript {
repositories { repositories {

View File

@ -150,7 +150,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
@Override @Override
public TransportCreationParamsFilter create( public TransportCreationParamsFilter create(
SocketAddress serverAddress, final SocketAddress serverAddress,
final String authority, final String authority,
final String userAgent, final String userAgent,
final ProxyParameters proxy) { final ProxyParameters proxy) {

View File

@ -53,7 +53,7 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
} }
/** Creates a negotiator used for ALTS. */ /** Creates a negotiator used for ALTS. */
public static AltsProtocolNegotiator create(TsiHandshakerFactory handshakerFactory) { public static AltsProtocolNegotiator create(final TsiHandshakerFactory handshakerFactory) {
return new AltsProtocolNegotiator() { return new AltsProtocolNegotiator() {
@Override @Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {

View File

@ -146,9 +146,9 @@ public final class InternalTsiFrameHandler extends ByteToMessageDecoder
} }
@Override @Override
public void flush(ChannelHandlerContext ctx) throws GeneralSecurityException { public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress"); checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
ProtectedPromise aggregatePromise = final ProtectedPromise aggregatePromise =
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size()); new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size()); List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
@ -168,7 +168,14 @@ public final class InternalTsiFrameHandler extends ByteToMessageDecoder
} }
protector.protectFlush( protector.protectFlush(
bufs, b -> ctx.writeAndFlush(b, aggregatePromise.newPromise()), ctx.alloc()); bufs,
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf b) {
ctx.writeAndFlush(b, aggregatePromise.newPromise());
}
},
ctx.alloc());
// We're done writing, start the flow of promise events. // We're done writing, start the flow of promise events.
@SuppressWarnings("unused") // go/futurereturn-lsc @SuppressWarnings("unused") // go/futurereturn-lsc

View File

@ -32,6 +32,7 @@ import io.grpc.alts.transportsecurity.TsiFrameProtector;
import io.grpc.alts.transportsecurity.TsiHandshaker; import io.grpc.alts.transportsecurity.TsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshakerFactory; import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
import io.grpc.alts.transportsecurity.TsiPeer; import io.grpc.alts.transportsecurity.TsiPeer;
import io.grpc.alts.transportsecurity.TsiPeer.Property;
import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
@ -89,7 +90,7 @@ public class AltsProtocolNegotiatorTest {
private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent; private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
private ChannelHandler handler; private ChannelHandler handler;
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.emptyList()); private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
private AltsAuthContext mockedAltsContext = private AltsAuthContext mockedAltsContext =
new AltsAuthContext( new AltsAuthContext(
HandshakerResult.newBuilder() HandshakerResult.newBuilder()
@ -220,10 +221,15 @@ public class AltsProtocolNegotiatorTest {
assertEquals(message, unprotectedData.toString(UTF_8)); assertEquals(message, unprotectedData.toString(UTF_8));
// Protect the same message at the server. // Protect the same message at the server.
AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>(); final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
serverProtector.protectFlush( serverProtector.protectFlush(
Collections.singletonList(unprotectedData), Collections.singletonList(unprotectedData),
b -> newlyProtectedData.set(b), new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
newlyProtectedData.set(buf);
}
},
channel.alloc()); channel.alloc());
// Read the protected message at the client and verify that it matches the original message. // Read the protected message at the client and verify that it matches the original message.
@ -250,7 +256,14 @@ public class AltsProtocolNegotiatorTest {
TsiFrameProtector serverProtector = TsiFrameProtector serverProtector =
serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc()); serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
serverProtector.protectFlush( serverProtector.protectFlush(
Collections.singletonList(unprotectedData), b -> channel.writeInbound(b), channel.alloc()); Collections.singletonList(unprotectedData),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
channel.writeInbound(buf);
}
},
channel.alloc());
channel.flushInbound(); channel.flushInbound();
// Read the protected message at the client and verify that it matches the original message. // Read the protected message at the client and verify that it matches the original message.

View File

@ -181,7 +181,16 @@ public class InternalNettyTsiHandshakerTest {
} }
private void doHandshake() throws GeneralSecurityException { private void doHandshake() throws GeneralSecurityException {
doHandshake(clientHandshaker, serverHandshaker, alloc, buf -> ref(buf)); doHandshake(
clientHandshaker,
serverHandshaker,
alloc,
new java.util.function.Function<ByteBuf, ByteBuf>() {
@Override
public ByteBuf apply(ByteBuf buf) {
return ref(buf);
}
});
} }
private ByteBuf ref(ByteBuf buf) { private ByteBuf ref(ByteBuf buf) {

View File

@ -23,6 +23,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.writeSlice;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import com.google.common.testing.GcFinalization; import com.google.common.testing.GcFinalization;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.util.ReferenceCounted; import io.netty.util.ReferenceCounted;
@ -45,6 +46,16 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes(); AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes();
private final List<ReferenceCounted> references = new ArrayList<ReferenceCounted>(); private final List<ReferenceCounted> references = new ArrayList<ReferenceCounted>();
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
@Before @Before
public void setUp() { public void setUp() {
@ -68,7 +79,7 @@ public class AltsTsiFrameProtectorTest {
FakeChannelCrypter crypter = new FakeChannelCrypter(); FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector = AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), this::ref); ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), ref);
in.writeIntLE(-1); in.writeIntLE(-1);
in.writeIntLE(6); in.writeIntLE(6);
try { try {
@ -90,7 +101,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = ByteBuf in =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE - 1); in.writeIntLE(FRAME_MIN_SIZE - 1);
in.writeIntLE(6); in.writeIntLE(6);
try { try {
@ -112,7 +123,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = ByteBuf in =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE( in.writeIntLE(
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes() AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
- AltsTsiFrameProtector.getHeaderLenFieldBytes() - AltsTsiFrameProtector.getHeaderLenFieldBytes()
@ -137,7 +148,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = ByteBuf in =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE); in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(5); in.writeIntLE(5);
try { try {
@ -159,7 +170,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = ByteBuf in =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE); in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(6); in.writeIntLE(6);
@ -176,7 +187,7 @@ public class AltsTsiFrameProtectorTest {
FakeChannelCrypter crypter = new FakeChannelCrypter(); FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector = AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf emptyBuf = getDirectBuffer(0, this::ref); ByteBuf emptyBuf = getDirectBuffer(0, ref);
unprotector.unprotect(emptyBuf, out, alloc); unprotector.unprotect(emptyBuf, out, alloc);
assertThat(emptyBuf.refCnt()).isEqualTo(1); assertThat(emptyBuf.refCnt()).isEqualTo(1);
@ -193,7 +204,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = ByteBuf in =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE( in.writeIntLE(
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes() AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
- AltsTsiFrameProtector.getHeaderLenFieldBytes()); - AltsTsiFrameProtector.getHeaderLenFieldBytes());
@ -214,7 +225,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = ByteBuf in =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE); in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(6); in.writeIntLE(6);
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1); ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
@ -238,7 +249,7 @@ public class AltsTsiFrameProtectorTest {
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = ByteBuf in =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref); AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
in.writeIntLE(FRAME_MIN_SIZE - 1); in.writeIntLE(FRAME_MIN_SIZE - 1);
in.writeIntLE(6); in.writeIntLE(6);
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1); ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
@ -267,13 +278,13 @@ public class AltsTsiFrameProtectorTest {
FakeChannelCrypter crypter = new FakeChannelCrypter(); FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector = AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref); ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf outFrame = ByteBuf outFrame =
getDirectBuffer( getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() AltsTsiFrameProtector.getHeaderBytes()
+ payloadBytes + payloadBytes
+ FakeChannelCrypter.getTagBytes(), + FakeChannelCrypter.getTagBytes(),
this::ref); ref);
outFrame.writeIntLE( outFrame.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes() AltsTsiFrameProtector.getHeaderTypeFieldBytes()
@ -305,12 +316,12 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.Unprotector unprotector = AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref); ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf outFrame = ByteBuf outFrame =
getDirectBuffer( getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes()) 2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes, + payloadBytes,
this::ref); ref);
outFrame.writeIntLE( outFrame.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes() AltsTsiFrameProtector.getHeaderTypeFieldBytes()
@ -353,13 +364,13 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.Unprotector unprotector = AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref); ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf protectedBuf = ByteBuf protectedBuf =
getDirectBuffer( getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes()) 2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes + payloadBytes
+ AltsTsiFrameProtector.getHeaderBytes(), + AltsTsiFrameProtector.getHeaderBytes(),
this::ref); ref);
protectedBuf.writeIntLE( protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes() AltsTsiFrameProtector.getHeaderTypeFieldBytes()
@ -421,13 +432,13 @@ public class AltsTsiFrameProtectorTest {
AltsTsiFrameProtector.Unprotector unprotector = AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc); new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref); ByteBuf plain = getRandom(payloadBytes, ref);
ByteBuf protectedBuf = ByteBuf protectedBuf =
getDirectBuffer( getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes()) 2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes + payloadBytes
+ AltsTsiFrameProtector.getHeaderBytes(), + AltsTsiFrameProtector.getHeaderBytes(),
this::ref); ref);
protectedBuf.writeIntLE( protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes() AltsTsiFrameProtector.getHeaderTypeFieldBytes()

View File

@ -22,6 +22,7 @@ import com.google.common.testing.GcFinalization;
import io.grpc.alts.Handshaker.HandshakeProtocol; import io.grpc.alts.Handshaker.HandshakeProtocol;
import io.grpc.alts.Handshaker.HandshakerReq; import io.grpc.alts.Handshaker.HandshakerReq;
import io.grpc.alts.Handshaker.HandshakerResp; import io.grpc.alts.Handshaker.HandshakerResp;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.grpc.alts.transportsecurity.TsiTest.Handshakers; import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted; import io.netty.util.ReferenceCounted;
@ -45,6 +46,16 @@ public class AltsTsiTest {
private final List<ReferenceCounted> references = new ArrayList<>(); private final List<ReferenceCounted> references = new ArrayList<>();
private AltsHandshakerClient client; private AltsHandshakerClient client;
private AltsHandshakerClient server; private AltsHandshakerClient server;
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
@ -101,47 +112,47 @@ public class AltsTsiTest {
@Test @Test
public void pingPong() throws GeneralSecurityException { public void pingPong() throws GeneralSecurityException {
TsiTest.pingPongTest(newHandshakers(), this::ref); TsiTest.pingPongTest(newHandshakers(), ref);
} }
@Test @Test
public void pingPongExactFrameSize() throws GeneralSecurityException { public void pingPongExactFrameSize() throws GeneralSecurityException {
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref); TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
} }
@Test @Test
public void pingPongSmallBuffer() throws GeneralSecurityException { public void pingPongSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref); TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
} }
@Test @Test
public void pingPongSmallFrame() throws GeneralSecurityException { public void pingPongSmallFrame() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref); TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
} }
@Test @Test
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException { public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref); TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
} }
@Test @Test
public void corruptedCounter() throws GeneralSecurityException { public void corruptedCounter() throws GeneralSecurityException {
TsiTest.corruptedCounterTest(newHandshakers(), this::ref); TsiTest.corruptedCounterTest(newHandshakers(), ref);
} }
@Test @Test
public void corruptedCiphertext() throws GeneralSecurityException { public void corruptedCiphertext() throws GeneralSecurityException {
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref); TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
} }
@Test @Test
public void corruptedTag() throws GeneralSecurityException { public void corruptedTag() throws GeneralSecurityException {
TsiTest.corruptedTagTest(newHandshakers(), this::ref); TsiTest.corruptedTagTest(newHandshakers(), ref);
} }
@Test @Test
public void reflectedCiphertext() throws GeneralSecurityException { public void reflectedCiphertext() throws GeneralSecurityException {
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref); TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
} }
private static class MockAltsHandshakerStub extends AltsHandshakerStub { private static class MockAltsHandshakerStub extends AltsHandshakerStub {
@ -184,11 +195,4 @@ public class AltsTsiTest {
@Override @Override
public void close() {} public void close() {}
} }
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
} }

View File

@ -22,6 +22,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCounted; import io.netty.util.ReferenceCounted;
@ -39,6 +40,16 @@ public abstract class ChannelCrypterNettyTestBase {
protected final List<ReferenceCounted> references = new ArrayList<>(); protected final List<ReferenceCounted> references = new ArrayList<>();
public ChannelCrypterNetty client; public ChannelCrypterNetty client;
public ChannelCrypterNetty server; public ChannelCrypterNetty server;
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
static final class FrameEncrypt { static final class FrameEncrypt {
List<ByteBuf> plain; List<ByteBuf> plain;
@ -54,10 +65,10 @@ public abstract class ChannelCrypterNettyTestBase {
FrameEncrypt createFrameEncrypt(String message) { FrameEncrypt createFrameEncrypt(String message) {
byte[] messageBytes = message.getBytes(UTF_8); byte[] messageBytes = message.getBytes(UTF_8);
FrameEncrypt frame = new FrameEncrypt(); FrameEncrypt frame = new FrameEncrypt();
ByteBuf plain = getDirectBuffer(messageBytes.length, this::ref); ByteBuf plain = getDirectBuffer(messageBytes.length, ref);
plain.writeBytes(messageBytes); plain.writeBytes(messageBytes);
frame.plain = Collections.singletonList(plain); frame.plain = Collections.singletonList(plain);
frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref); frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
return frame; return frame;
} }
@ -68,7 +79,7 @@ public abstract class ChannelCrypterNettyTestBase {
frameDecrypt.ciphertext = frameDecrypt.ciphertext =
Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen)); Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen));
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen); frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref); frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
return frameDecrypt; return frameDecrypt;
} }
@ -87,9 +98,9 @@ public abstract class ChannelCrypterNettyTestBase {
@Test @Test
public void encryptDecryptLarge() throws GeneralSecurityException { public void encryptDecryptLarge() throws GeneralSecurityException {
FrameEncrypt frameEncrypt = new FrameEncrypt(); FrameEncrypt frameEncrypt = new FrameEncrypt();
ByteBuf plain = getRandom(17 * 1024, this::ref); ByteBuf plain = getRandom(17 * 1024, ref);
frameEncrypt.plain = Collections.singletonList(plain); frameEncrypt.plain = Collections.singletonList(plain);
frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), this::ref); frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), ref);
client.encrypt(frameEncrypt.out, frameEncrypt.plain); client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt); FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
@ -120,13 +131,13 @@ public abstract class ChannelCrypterNettyTestBase {
int lastLen = 2; int lastLen = 2;
byte[] messageBytes = message.getBytes(UTF_8); byte[] messageBytes = message.getBytes(UTF_8);
FrameEncrypt frameEncrypt = new FrameEncrypt(); FrameEncrypt frameEncrypt = new FrameEncrypt();
ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, this::ref); ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, ref);
ByteBuf plain2 = getDirectBuffer(lastLen, this::ref); ByteBuf plain2 = getDirectBuffer(lastLen, ref);
plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen); plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen);
plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen); plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen);
ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2); ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2);
frameEncrypt.plain = Collections.singletonList(plain); frameEncrypt.plain = Collections.singletonList(plain);
frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref); frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
client.encrypt(frameEncrypt.out, frameEncrypt.plain); client.encrypt(frameEncrypt.out, frameEncrypt.plain);
@ -134,14 +145,14 @@ public abstract class ChannelCrypterNettyTestBase {
FrameDecrypt frameDecrypt = new FrameDecrypt(); FrameDecrypt frameDecrypt = new FrameDecrypt();
ByteBuf out = frameEncrypt.out; ByteBuf out = frameEncrypt.out;
int outLen = out.readableBytes(); int outLen = out.readableBytes();
ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, this::ref); ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, ref);
ByteBuf cipher2 = getDirectBuffer(lastLen, this::ref); ByteBuf cipher2 = getDirectBuffer(lastLen, ref);
cipher1.writeBytes(out, 0, outLen - lastLen - tagLen); cipher1.writeBytes(out, 0, outLen - lastLen - tagLen);
cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen); cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen);
ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2); ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2);
frameDecrypt.ciphertext = Collections.singletonList(cipher); frameDecrypt.ciphertext = Collections.singletonList(cipher);
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen); frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref); frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes())) assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
@ -212,11 +223,4 @@ public abstract class ChannelCrypterNettyTestBase {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE); assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
} }
} }
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
} }

View File

@ -52,9 +52,10 @@ public final class FakeChannelCrypter implements ChannelCrypterNetty {
for (ByteBuf buf : ciphertext) { for (ByteBuf buf : ciphertext) {
out.writeBytes(buf); out.writeBytes(buf);
} }
boolean tagValid = tag.forEachByte((byte value) -> value == TAG_BYTE) == -1; while (tag.isReadable()) {
if (!tagValid) { if (tag.readByte() != TAG_BYTE) {
throw new AEADBadTagException("Tag mismatch!"); throw new AEADBadTagException("Tag mismatch!");
}
} }
} }

View File

@ -19,6 +19,7 @@ package io.grpc.alts.transportsecurity;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.grpc.alts.transportsecurity.TsiPeer.Property;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.GeneralSecurityException; import java.security.GeneralSecurityException;
@ -203,7 +204,7 @@ public class FakeTsiHandshaker implements TsiHandshaker {
@Override @Override
public TsiPeer extractPeer() { public TsiPeer extractPeer() {
return new TsiPeer(Collections.emptyList()); return new TsiPeer(Collections.<Property<?>>emptyList());
} }
@Override @Override

View File

@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import com.google.common.testing.GcFinalization; import com.google.common.testing.GcFinalization;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.grpc.alts.transportsecurity.TsiTest.Handshakers; import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted; import io.netty.util.ReferenceCounted;
@ -44,6 +45,16 @@ public class FakeTsiTest {
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes(); FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
private final List<ReferenceCounted> references = new ArrayList<>(); private final List<ReferenceCounted> references = new ArrayList<>();
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
private static Handshakers newHandshakers() { private static Handshakers newHandshakers() {
TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient(); TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient();
@ -157,53 +168,46 @@ public class FakeTsiTest {
@Test @Test
public void pingPong() throws GeneralSecurityException { public void pingPong() throws GeneralSecurityException {
TsiTest.pingPongTest(newHandshakers(), this::ref); TsiTest.pingPongTest(newHandshakers(), ref);
} }
@Test @Test
public void pingPongExactFrameSize() throws GeneralSecurityException { public void pingPongExactFrameSize() throws GeneralSecurityException {
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref); TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
} }
@Test @Test
public void pingPongSmallBuffer() throws GeneralSecurityException { public void pingPongSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref); TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
} }
@Test @Test
public void pingPongSmallFrame() throws GeneralSecurityException { public void pingPongSmallFrame() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref); TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
} }
@Test @Test
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException { public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref); TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
} }
@Test @Test
public void corruptedCounter() throws GeneralSecurityException { public void corruptedCounter() throws GeneralSecurityException {
TsiTest.corruptedCounterTest(newHandshakers(), this::ref); TsiTest.corruptedCounterTest(newHandshakers(), ref);
} }
@Test @Test
public void corruptedCiphertext() throws GeneralSecurityException { public void corruptedCiphertext() throws GeneralSecurityException {
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref); TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
} }
@Test @Test
public void corruptedTag() throws GeneralSecurityException { public void corruptedTag() throws GeneralSecurityException {
TsiTest.corruptedTagTest(newHandshakers(), this::ref); TsiTest.corruptedTagTest(newHandshakers(), ref);
} }
@Test @Test
public void reflectedCiphertext() throws GeneralSecurityException { public void reflectedCiphertext() throws GeneralSecurityException {
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref); TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
} }
} }

View File

@ -123,10 +123,18 @@ public final class TsiTest {
throws GeneralSecurityException { throws GeneralSecurityException {
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>(); final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>(); List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1); assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0)); ByteBuf protect = ref.register(protectOut.get(0));
@ -249,10 +257,18 @@ public final class TsiTest {
String message = "hello world"; String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>(); final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>(); List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1); assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0)); ByteBuf protect = ref.register(protectOut.get(0));
@ -282,10 +298,18 @@ public final class TsiTest {
String message = "hello world"; String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>(); final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>(); List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1); assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0)); ByteBuf protect = ref.register(protectOut.get(0));
@ -313,10 +337,18 @@ public final class TsiTest {
String message = "hello world"; String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>(); final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>(); List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1); assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0)); ByteBuf protect = ref.register(protectOut.get(0));
@ -344,10 +376,18 @@ public final class TsiTest {
String message = "hello world"; String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8)); ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>(); final List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>(); List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc); sender.protectFlush(
Collections.singletonList(plaintextBuffer),
new java.util.function.Consumer<ByteBuf>() {
@Override
public void accept(ByteBuf buf) {
protectOut.add(buf);
}
},
alloc);
assertThat(protectOut.size()).isEqualTo(1); assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0)); ByteBuf protect = ref.register(protectOut.get(0));

View File

@ -14,7 +14,8 @@ buildscript {
} }
dependencies { dependencies {
compile project(':grpc-auth'), compile project(':grpc-alts'),
project(':grpc-auth'),
project(':grpc-core'), project(':grpc-core'),
project(':grpc-netty'), project(':grpc-netty'),
project(':grpc-okhttp'), project(':grpc-okhttp'),

View File

@ -19,6 +19,7 @@ package io.grpc.testing.integration;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Files; import com.google.common.io.Files;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.alts.AltsChannelBuilder;
import io.grpc.internal.AbstractManagedChannelImplBuilder; import io.grpc.internal.AbstractManagedChannelImplBuilder;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.TestUtils; import io.grpc.internal.testing.TestUtils;
@ -76,6 +77,7 @@ public class TestServiceClient {
private int serverPort = 8080; private int serverPort = 8080;
private String testCase = "empty_unary"; private String testCase = "empty_unary";
private boolean useTls = true; private boolean useTls = true;
private boolean useAlts = false;
private boolean useTestCa; private boolean useTestCa;
private boolean useOkHttp; private boolean useOkHttp;
private String defaultServiceAccount; private String defaultServiceAccount;
@ -116,6 +118,8 @@ public class TestServiceClient {
testCase = value; testCase = value;
} else if ("use_tls".equals(key)) { } else if ("use_tls".equals(key)) {
useTls = Boolean.parseBoolean(value); useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
} else if ("use_test_ca".equals(key)) { } else if ("use_test_ca".equals(key)) {
useTestCa = Boolean.parseBoolean(value); useTestCa = Boolean.parseBoolean(value);
} else if ("use_okhttp".equals(key)) { } else if ("use_okhttp".equals(key)) {
@ -140,6 +144,9 @@ public class TestServiceClient {
break; break;
} }
} }
if (useAlts) {
useTls = false;
}
if (usage) { if (usage) {
TestServiceClient c = new TestServiceClient(); TestServiceClient c = new TestServiceClient();
System.out.println( System.out.println(
@ -153,6 +160,8 @@ public class TestServiceClient {
+ "\n Valid options:" + "\n Valid options:"
+ validTestCasesHelpText() + validTestCasesHelpText()
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls + "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + c.useTls
+ "\n --use_test_ca=true|false Whether to trust our fake CA. Requires --use_tls=true " + "\n --use_test_ca=true|false Whether to trust our fake CA. Requires --use_tls=true "
+ "\n to have effect. Default " + c.useTestCa + "\n to have effect. Default " + c.useTestCa
+ "\n --use_okhttp=true|false Whether to use OkHttp instead of Netty. Default " + "\n --use_okhttp=true|false Whether to use OkHttp instead of Netty. Default "
@ -317,6 +326,9 @@ public class TestServiceClient {
private class Tester extends AbstractInteropTest { private class Tester extends AbstractInteropTest {
@Override @Override
protected ManagedChannel createChannel() { protected ManagedChannel createChannel() {
if (useAlts) {
return AltsChannelBuilder.forAddress(serverHost, serverPort).build();
}
AbstractManagedChannelImplBuilder<?> builder; AbstractManagedChannelImplBuilder<?> builder;
if (!useOkHttp) { if (!useOkHttp) {
SslContext sslContext = null; SslContext sslContext = null;

View File

@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import io.grpc.Server; import io.grpc.Server;
import io.grpc.ServerInterceptors; import io.grpc.ServerInterceptors;
import io.grpc.alts.AltsServerBuilder;
import io.grpc.internal.testing.TestUtils; import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyServerBuilder; import io.grpc.netty.NettyServerBuilder;
@ -28,34 +29,32 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
/** /** Server that manages startup/shutdown of a single {@code TestService}. */
* Server that manages startup/shutdown of a single {@code TestService}.
*/
public class TestServiceServer { public class TestServiceServer {
/** /** The main application allowing this server to be launched from the command line. */
* The main application allowing this server to be launched from the command line.
*/
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
final TestServiceServer server = new TestServiceServer(); final TestServiceServer server = new TestServiceServer();
server.parseArgs(args); server.parseArgs(args);
if (server.useTls) { if (server.useTls) {
System.out.println( System.out.println(
"\nUsing fake CA for TLS certificate. Test clients should expect host\n" "\nUsing fake CA for TLS certificate. Test clients should expect host\n"
+ "*.test.google.fr and our test CA. For the Java test client binary, use:\n" + "*.test.google.fr and our test CA. For the Java test client binary, use:\n"
+ "--server_host_override=foo.test.google.fr --use_test_ca=true\n"); + "--server_host_override=foo.test.google.fr --use_test_ca=true\n");
} }
Runtime.getRuntime().addShutdownHook(new Thread() { Runtime.getRuntime()
@Override .addShutdownHook(
public void run() { new Thread() {
try { @Override
System.out.println("Shutting down"); public void run() {
server.stop(); try {
} catch (Exception e) { System.out.println("Shutting down");
e.printStackTrace(); server.stop();
} } catch (Exception e) {
} e.printStackTrace();
}); }
}
});
server.start(); server.start();
System.out.println("Server started on port " + server.port); System.out.println("Server started on port " + server.port);
server.blockUntilShutdown(); server.blockUntilShutdown();
@ -63,6 +62,7 @@ public class TestServiceServer {
private int port = 8080; private int port = 8080;
private boolean useTls = true; private boolean useTls = true;
private boolean useAlts = false;
private ScheduledExecutorService executor; private ScheduledExecutorService executor;
private Server server; private Server server;
@ -92,6 +92,8 @@ public class TestServiceServer {
port = Integer.parseInt(value); port = Integer.parseInt(value);
} else if ("use_tls".equals(key)) { } else if ("use_tls".equals(key)) {
useTls = Boolean.parseBoolean(value); useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
} else if ("grpc_version".equals(key)) { } else if ("grpc_version".equals(key)) {
if (!"2".equals(value)) { if (!"2".equals(value)) {
System.err.println("Only grpc version 2 is supported"); System.err.println("Only grpc version 2 is supported");
@ -104,13 +106,18 @@ public class TestServiceServer {
break; break;
} }
} }
if (useAlts) {
useTls = false;
}
if (usage) { if (usage) {
TestServiceServer s = new TestServiceServer(); TestServiceServer s = new TestServiceServer();
System.out.println( System.out.println(
"Usage: [ARGS...]" "Usage: [ARGS...]"
+ "\n" + "\n"
+ "\n --port=PORT Port to connect to. Default " + s.port + "\n --port=PORT Port to connect to. Default " + s.port
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls + "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + s.useAlts
); );
System.exit(1); System.exit(1);
} }
@ -120,17 +127,31 @@ public class TestServiceServer {
void start() throws Exception { void start() throws Exception {
executor = Executors.newSingleThreadScheduledExecutor(); executor = Executors.newSingleThreadScheduledExecutor();
SslContext sslContext = null; SslContext sslContext = null;
if (useTls) { if (useAlts) {
sslContext = GrpcSslContexts.forServer( server =
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")).build(); AltsServerBuilder.forPort(port)
.addService(
ServerInterceptors.intercept(
new TestServiceImpl(executor), TestServiceImpl.interceptors()))
.build()
.start();
} else {
if (useTls) {
sslContext =
GrpcSslContexts.forServer(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
.build();
}
server =
NettyServerBuilder.forPort(port)
.sslContext(sslContext)
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.addService(
ServerInterceptors.intercept(
new TestServiceImpl(executor), TestServiceImpl.interceptors()))
.build()
.start();
} }
server = NettyServerBuilder.forPort(port)
.sslContext(sslContext)
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.addService(ServerInterceptors.intercept(
new TestServiceImpl(executor),
TestServiceImpl.interceptors()))
.build().start();
} }
@VisibleForTesting @VisibleForTesting
@ -147,9 +168,7 @@ public class TestServiceServer {
return server.getPort(); return server.getPort();
} }
/** /** Await termination on the main thread since the grpc library uses daemon threads. */
* Await termination on the main thread since the grpc library uses daemon threads.
*/
private void blockUntilShutdown() throws InterruptedException { private void blockUntilShutdown() throws InterruptedException {
if (server != null) { if (server != null) {
server.awaitTermination(); server.awaitTermination();