mirror of https://github.com/grpc/grpc-java.git
alts: add ALTS to interop tests
This commit is contained in:
parent
21541243f8
commit
d8630f2521
|
|
@ -1,7 +1,7 @@
|
||||||
description = "gRPC: ALTS"
|
description = "gRPC: ALTS"
|
||||||
|
|
||||||
sourceCompatibility = 1.8
|
sourceCompatibility = 1.7
|
||||||
targetCompatibility = 1.8
|
targetCompatibility = 1.7
|
||||||
|
|
||||||
buildscript {
|
buildscript {
|
||||||
repositories {
|
repositories {
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,7 @@ public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChann
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TransportCreationParamsFilter create(
|
public TransportCreationParamsFilter create(
|
||||||
SocketAddress serverAddress,
|
final SocketAddress serverAddress,
|
||||||
final String authority,
|
final String authority,
|
||||||
final String userAgent,
|
final String userAgent,
|
||||||
final ProxyParameters proxy) {
|
final ProxyParameters proxy) {
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Creates a negotiator used for ALTS. */
|
/** Creates a negotiator used for ALTS. */
|
||||||
public static AltsProtocolNegotiator create(TsiHandshakerFactory handshakerFactory) {
|
public static AltsProtocolNegotiator create(final TsiHandshakerFactory handshakerFactory) {
|
||||||
return new AltsProtocolNegotiator() {
|
return new AltsProtocolNegotiator() {
|
||||||
@Override
|
@Override
|
||||||
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||||
|
|
|
||||||
|
|
@ -146,9 +146,9 @@ public final class InternalTsiFrameHandler extends ByteToMessageDecoder
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void flush(ChannelHandlerContext ctx) throws GeneralSecurityException {
|
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
|
||||||
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
||||||
ProtectedPromise aggregatePromise =
|
final ProtectedPromise aggregatePromise =
|
||||||
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
||||||
|
|
||||||
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
|
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
|
||||||
|
|
@ -168,7 +168,14 @@ public final class InternalTsiFrameHandler extends ByteToMessageDecoder
|
||||||
}
|
}
|
||||||
|
|
||||||
protector.protectFlush(
|
protector.protectFlush(
|
||||||
bufs, b -> ctx.writeAndFlush(b, aggregatePromise.newPromise()), ctx.alloc());
|
bufs,
|
||||||
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf b) {
|
||||||
|
ctx.writeAndFlush(b, aggregatePromise.newPromise());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ctx.alloc());
|
||||||
|
|
||||||
// We're done writing, start the flow of promise events.
|
// We're done writing, start the flow of promise events.
|
||||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ import io.grpc.alts.transportsecurity.TsiFrameProtector;
|
||||||
import io.grpc.alts.transportsecurity.TsiHandshaker;
|
import io.grpc.alts.transportsecurity.TsiHandshaker;
|
||||||
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
|
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
|
||||||
import io.grpc.alts.transportsecurity.TsiPeer;
|
import io.grpc.alts.transportsecurity.TsiPeer;
|
||||||
|
import io.grpc.alts.transportsecurity.TsiPeer.Property;
|
||||||
import io.grpc.netty.GrpcHttp2ConnectionHandler;
|
import io.grpc.netty.GrpcHttp2ConnectionHandler;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufAllocator;
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
|
|
@ -89,7 +90,7 @@ public class AltsProtocolNegotiatorTest {
|
||||||
private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
|
private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
|
||||||
private ChannelHandler handler;
|
private ChannelHandler handler;
|
||||||
|
|
||||||
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.emptyList());
|
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
|
||||||
private AltsAuthContext mockedAltsContext =
|
private AltsAuthContext mockedAltsContext =
|
||||||
new AltsAuthContext(
|
new AltsAuthContext(
|
||||||
HandshakerResult.newBuilder()
|
HandshakerResult.newBuilder()
|
||||||
|
|
@ -220,10 +221,15 @@ public class AltsProtocolNegotiatorTest {
|
||||||
assertEquals(message, unprotectedData.toString(UTF_8));
|
assertEquals(message, unprotectedData.toString(UTF_8));
|
||||||
|
|
||||||
// Protect the same message at the server.
|
// Protect the same message at the server.
|
||||||
AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
|
final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
|
||||||
serverProtector.protectFlush(
|
serverProtector.protectFlush(
|
||||||
Collections.singletonList(unprotectedData),
|
Collections.singletonList(unprotectedData),
|
||||||
b -> newlyProtectedData.set(b),
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf buf) {
|
||||||
|
newlyProtectedData.set(buf);
|
||||||
|
}
|
||||||
|
},
|
||||||
channel.alloc());
|
channel.alloc());
|
||||||
|
|
||||||
// Read the protected message at the client and verify that it matches the original message.
|
// Read the protected message at the client and verify that it matches the original message.
|
||||||
|
|
@ -250,7 +256,14 @@ public class AltsProtocolNegotiatorTest {
|
||||||
TsiFrameProtector serverProtector =
|
TsiFrameProtector serverProtector =
|
||||||
serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
|
serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
|
||||||
serverProtector.protectFlush(
|
serverProtector.protectFlush(
|
||||||
Collections.singletonList(unprotectedData), b -> channel.writeInbound(b), channel.alloc());
|
Collections.singletonList(unprotectedData),
|
||||||
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf buf) {
|
||||||
|
channel.writeInbound(buf);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
channel.alloc());
|
||||||
channel.flushInbound();
|
channel.flushInbound();
|
||||||
|
|
||||||
// Read the protected message at the client and verify that it matches the original message.
|
// Read the protected message at the client and verify that it matches the original message.
|
||||||
|
|
|
||||||
|
|
@ -181,7 +181,16 @@ public class InternalNettyTsiHandshakerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void doHandshake() throws GeneralSecurityException {
|
private void doHandshake() throws GeneralSecurityException {
|
||||||
doHandshake(clientHandshaker, serverHandshaker, alloc, buf -> ref(buf));
|
doHandshake(
|
||||||
|
clientHandshaker,
|
||||||
|
serverHandshaker,
|
||||||
|
alloc,
|
||||||
|
new java.util.function.Function<ByteBuf, ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public ByteBuf apply(ByteBuf buf) {
|
||||||
|
return ref(buf);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private ByteBuf ref(ByteBuf buf) {
|
private ByteBuf ref(ByteBuf buf) {
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.writeSlice;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
import com.google.common.testing.GcFinalization;
|
import com.google.common.testing.GcFinalization;
|
||||||
|
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufAllocator;
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import io.netty.util.ReferenceCounted;
|
import io.netty.util.ReferenceCounted;
|
||||||
|
|
@ -45,6 +46,16 @@ public class AltsTsiFrameProtectorTest {
|
||||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes();
|
AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes();
|
||||||
|
|
||||||
private final List<ReferenceCounted> references = new ArrayList<ReferenceCounted>();
|
private final List<ReferenceCounted> references = new ArrayList<ReferenceCounted>();
|
||||||
|
private final RegisterRef ref =
|
||||||
|
new RegisterRef() {
|
||||||
|
@Override
|
||||||
|
public ByteBuf register(ByteBuf buf) {
|
||||||
|
if (buf != null) {
|
||||||
|
references.add(buf);
|
||||||
|
}
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
|
|
@ -68,7 +79,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||||
AltsTsiFrameProtector.Unprotector unprotector =
|
AltsTsiFrameProtector.Unprotector unprotector =
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), this::ref);
|
ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), ref);
|
||||||
in.writeIntLE(-1);
|
in.writeIntLE(-1);
|
||||||
in.writeIntLE(6);
|
in.writeIntLE(6);
|
||||||
try {
|
try {
|
||||||
|
|
@ -90,7 +101,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in =
|
ByteBuf in =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
|
||||||
in.writeIntLE(FRAME_MIN_SIZE - 1);
|
in.writeIntLE(FRAME_MIN_SIZE - 1);
|
||||||
in.writeIntLE(6);
|
in.writeIntLE(6);
|
||||||
try {
|
try {
|
||||||
|
|
@ -112,7 +123,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in =
|
ByteBuf in =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
|
||||||
in.writeIntLE(
|
in.writeIntLE(
|
||||||
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
|
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
|
||||||
- AltsTsiFrameProtector.getHeaderLenFieldBytes()
|
- AltsTsiFrameProtector.getHeaderLenFieldBytes()
|
||||||
|
|
@ -137,7 +148,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in =
|
ByteBuf in =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
|
||||||
in.writeIntLE(FRAME_MIN_SIZE);
|
in.writeIntLE(FRAME_MIN_SIZE);
|
||||||
in.writeIntLE(5);
|
in.writeIntLE(5);
|
||||||
try {
|
try {
|
||||||
|
|
@ -159,7 +170,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in =
|
ByteBuf in =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
|
||||||
in.writeIntLE(FRAME_MIN_SIZE);
|
in.writeIntLE(FRAME_MIN_SIZE);
|
||||||
in.writeIntLE(6);
|
in.writeIntLE(6);
|
||||||
|
|
||||||
|
|
@ -176,7 +187,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||||
AltsTsiFrameProtector.Unprotector unprotector =
|
AltsTsiFrameProtector.Unprotector unprotector =
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf emptyBuf = getDirectBuffer(0, this::ref);
|
ByteBuf emptyBuf = getDirectBuffer(0, ref);
|
||||||
unprotector.unprotect(emptyBuf, out, alloc);
|
unprotector.unprotect(emptyBuf, out, alloc);
|
||||||
|
|
||||||
assertThat(emptyBuf.refCnt()).isEqualTo(1);
|
assertThat(emptyBuf.refCnt()).isEqualTo(1);
|
||||||
|
|
@ -193,7 +204,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in =
|
ByteBuf in =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
|
||||||
in.writeIntLE(
|
in.writeIntLE(
|
||||||
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
|
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
|
||||||
- AltsTsiFrameProtector.getHeaderLenFieldBytes());
|
- AltsTsiFrameProtector.getHeaderLenFieldBytes());
|
||||||
|
|
@ -214,7 +225,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in =
|
ByteBuf in =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
|
||||||
in.writeIntLE(FRAME_MIN_SIZE);
|
in.writeIntLE(FRAME_MIN_SIZE);
|
||||||
in.writeIntLE(6);
|
in.writeIntLE(6);
|
||||||
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
|
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
|
||||||
|
|
@ -238,7 +249,7 @@ public class AltsTsiFrameProtectorTest {
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf in =
|
ByteBuf in =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), ref);
|
||||||
in.writeIntLE(FRAME_MIN_SIZE - 1);
|
in.writeIntLE(FRAME_MIN_SIZE - 1);
|
||||||
in.writeIntLE(6);
|
in.writeIntLE(6);
|
||||||
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
|
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
|
||||||
|
|
@ -267,13 +278,13 @@ public class AltsTsiFrameProtectorTest {
|
||||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||||
AltsTsiFrameProtector.Unprotector unprotector =
|
AltsTsiFrameProtector.Unprotector unprotector =
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
ByteBuf plain = getRandom(payloadBytes, ref);
|
||||||
ByteBuf outFrame =
|
ByteBuf outFrame =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
AltsTsiFrameProtector.getHeaderBytes()
|
AltsTsiFrameProtector.getHeaderBytes()
|
||||||
+ payloadBytes
|
+ payloadBytes
|
||||||
+ FakeChannelCrypter.getTagBytes(),
|
+ FakeChannelCrypter.getTagBytes(),
|
||||||
this::ref);
|
ref);
|
||||||
|
|
||||||
outFrame.writeIntLE(
|
outFrame.writeIntLE(
|
||||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||||
|
|
@ -305,12 +316,12 @@ public class AltsTsiFrameProtectorTest {
|
||||||
AltsTsiFrameProtector.Unprotector unprotector =
|
AltsTsiFrameProtector.Unprotector unprotector =
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
|
|
||||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
ByteBuf plain = getRandom(payloadBytes, ref);
|
||||||
ByteBuf outFrame =
|
ByteBuf outFrame =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
||||||
+ payloadBytes,
|
+ payloadBytes,
|
||||||
this::ref);
|
ref);
|
||||||
|
|
||||||
outFrame.writeIntLE(
|
outFrame.writeIntLE(
|
||||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||||
|
|
@ -353,13 +364,13 @@ public class AltsTsiFrameProtectorTest {
|
||||||
AltsTsiFrameProtector.Unprotector unprotector =
|
AltsTsiFrameProtector.Unprotector unprotector =
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
|
|
||||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
ByteBuf plain = getRandom(payloadBytes, ref);
|
||||||
ByteBuf protectedBuf =
|
ByteBuf protectedBuf =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
||||||
+ payloadBytes
|
+ payloadBytes
|
||||||
+ AltsTsiFrameProtector.getHeaderBytes(),
|
+ AltsTsiFrameProtector.getHeaderBytes(),
|
||||||
this::ref);
|
ref);
|
||||||
|
|
||||||
protectedBuf.writeIntLE(
|
protectedBuf.writeIntLE(
|
||||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||||
|
|
@ -421,13 +432,13 @@ public class AltsTsiFrameProtectorTest {
|
||||||
AltsTsiFrameProtector.Unprotector unprotector =
|
AltsTsiFrameProtector.Unprotector unprotector =
|
||||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||||
|
|
||||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
ByteBuf plain = getRandom(payloadBytes, ref);
|
||||||
ByteBuf protectedBuf =
|
ByteBuf protectedBuf =
|
||||||
getDirectBuffer(
|
getDirectBuffer(
|
||||||
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
||||||
+ payloadBytes
|
+ payloadBytes
|
||||||
+ AltsTsiFrameProtector.getHeaderBytes(),
|
+ AltsTsiFrameProtector.getHeaderBytes(),
|
||||||
this::ref);
|
ref);
|
||||||
|
|
||||||
protectedBuf.writeIntLE(
|
protectedBuf.writeIntLE(
|
||||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import com.google.common.testing.GcFinalization;
|
||||||
import io.grpc.alts.Handshaker.HandshakeProtocol;
|
import io.grpc.alts.Handshaker.HandshakeProtocol;
|
||||||
import io.grpc.alts.Handshaker.HandshakerReq;
|
import io.grpc.alts.Handshaker.HandshakerReq;
|
||||||
import io.grpc.alts.Handshaker.HandshakerResp;
|
import io.grpc.alts.Handshaker.HandshakerResp;
|
||||||
|
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
|
||||||
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
|
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.util.ReferenceCounted;
|
import io.netty.util.ReferenceCounted;
|
||||||
|
|
@ -45,6 +46,16 @@ public class AltsTsiTest {
|
||||||
private final List<ReferenceCounted> references = new ArrayList<>();
|
private final List<ReferenceCounted> references = new ArrayList<>();
|
||||||
private AltsHandshakerClient client;
|
private AltsHandshakerClient client;
|
||||||
private AltsHandshakerClient server;
|
private AltsHandshakerClient server;
|
||||||
|
private final RegisterRef ref =
|
||||||
|
new RegisterRef() {
|
||||||
|
@Override
|
||||||
|
public ByteBuf register(ByteBuf buf) {
|
||||||
|
if (buf != null) {
|
||||||
|
references.add(buf);
|
||||||
|
}
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
|
|
@ -101,47 +112,47 @@ public class AltsTsiTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPong() throws GeneralSecurityException {
|
public void pingPong() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongTest(newHandshakers(), this::ref);
|
TsiTest.pingPongTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongExactFrameSize() throws GeneralSecurityException {
|
public void pingPongExactFrameSize() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
|
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongSmallBuffer() throws GeneralSecurityException {
|
public void pingPongSmallBuffer() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
|
TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongSmallFrame() throws GeneralSecurityException {
|
public void pingPongSmallFrame() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
|
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
|
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
|
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void corruptedCounter() throws GeneralSecurityException {
|
public void corruptedCounter() throws GeneralSecurityException {
|
||||||
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
|
TsiTest.corruptedCounterTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void corruptedCiphertext() throws GeneralSecurityException {
|
public void corruptedCiphertext() throws GeneralSecurityException {
|
||||||
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
|
TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void corruptedTag() throws GeneralSecurityException {
|
public void corruptedTag() throws GeneralSecurityException {
|
||||||
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
|
TsiTest.corruptedTagTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void reflectedCiphertext() throws GeneralSecurityException {
|
public void reflectedCiphertext() throws GeneralSecurityException {
|
||||||
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
|
TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class MockAltsHandshakerStub extends AltsHandshakerStub {
|
private static class MockAltsHandshakerStub extends AltsHandshakerStub {
|
||||||
|
|
@ -184,11 +195,4 @@ public class AltsTsiTest {
|
||||||
@Override
|
@Override
|
||||||
public void close() {}
|
public void close() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
private ByteBuf ref(ByteBuf buf) {
|
|
||||||
if (buf != null) {
|
|
||||||
references.add(buf);
|
|
||||||
}
|
|
||||||
return buf;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom;
|
||||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.util.ReferenceCounted;
|
import io.netty.util.ReferenceCounted;
|
||||||
|
|
@ -39,6 +40,16 @@ public abstract class ChannelCrypterNettyTestBase {
|
||||||
protected final List<ReferenceCounted> references = new ArrayList<>();
|
protected final List<ReferenceCounted> references = new ArrayList<>();
|
||||||
public ChannelCrypterNetty client;
|
public ChannelCrypterNetty client;
|
||||||
public ChannelCrypterNetty server;
|
public ChannelCrypterNetty server;
|
||||||
|
private final RegisterRef ref =
|
||||||
|
new RegisterRef() {
|
||||||
|
@Override
|
||||||
|
public ByteBuf register(ByteBuf buf) {
|
||||||
|
if (buf != null) {
|
||||||
|
references.add(buf);
|
||||||
|
}
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
static final class FrameEncrypt {
|
static final class FrameEncrypt {
|
||||||
List<ByteBuf> plain;
|
List<ByteBuf> plain;
|
||||||
|
|
@ -54,10 +65,10 @@ public abstract class ChannelCrypterNettyTestBase {
|
||||||
FrameEncrypt createFrameEncrypt(String message) {
|
FrameEncrypt createFrameEncrypt(String message) {
|
||||||
byte[] messageBytes = message.getBytes(UTF_8);
|
byte[] messageBytes = message.getBytes(UTF_8);
|
||||||
FrameEncrypt frame = new FrameEncrypt();
|
FrameEncrypt frame = new FrameEncrypt();
|
||||||
ByteBuf plain = getDirectBuffer(messageBytes.length, this::ref);
|
ByteBuf plain = getDirectBuffer(messageBytes.length, ref);
|
||||||
plain.writeBytes(messageBytes);
|
plain.writeBytes(messageBytes);
|
||||||
frame.plain = Collections.singletonList(plain);
|
frame.plain = Collections.singletonList(plain);
|
||||||
frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
|
frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
|
||||||
return frame;
|
return frame;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -68,7 +79,7 @@ public abstract class ChannelCrypterNettyTestBase {
|
||||||
frameDecrypt.ciphertext =
|
frameDecrypt.ciphertext =
|
||||||
Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen));
|
Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen));
|
||||||
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
|
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
|
||||||
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
|
frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
|
||||||
return frameDecrypt;
|
return frameDecrypt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -87,9 +98,9 @@ public abstract class ChannelCrypterNettyTestBase {
|
||||||
@Test
|
@Test
|
||||||
public void encryptDecryptLarge() throws GeneralSecurityException {
|
public void encryptDecryptLarge() throws GeneralSecurityException {
|
||||||
FrameEncrypt frameEncrypt = new FrameEncrypt();
|
FrameEncrypt frameEncrypt = new FrameEncrypt();
|
||||||
ByteBuf plain = getRandom(17 * 1024, this::ref);
|
ByteBuf plain = getRandom(17 * 1024, ref);
|
||||||
frameEncrypt.plain = Collections.singletonList(plain);
|
frameEncrypt.plain = Collections.singletonList(plain);
|
||||||
frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), this::ref);
|
frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), ref);
|
||||||
|
|
||||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||||
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
|
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
|
||||||
|
|
@ -120,13 +131,13 @@ public abstract class ChannelCrypterNettyTestBase {
|
||||||
int lastLen = 2;
|
int lastLen = 2;
|
||||||
byte[] messageBytes = message.getBytes(UTF_8);
|
byte[] messageBytes = message.getBytes(UTF_8);
|
||||||
FrameEncrypt frameEncrypt = new FrameEncrypt();
|
FrameEncrypt frameEncrypt = new FrameEncrypt();
|
||||||
ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, this::ref);
|
ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, ref);
|
||||||
ByteBuf plain2 = getDirectBuffer(lastLen, this::ref);
|
ByteBuf plain2 = getDirectBuffer(lastLen, ref);
|
||||||
plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen);
|
plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen);
|
||||||
plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen);
|
plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen);
|
||||||
ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2);
|
ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2);
|
||||||
frameEncrypt.plain = Collections.singletonList(plain);
|
frameEncrypt.plain = Collections.singletonList(plain);
|
||||||
frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
|
frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
|
||||||
|
|
||||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||||
|
|
||||||
|
|
@ -134,14 +145,14 @@ public abstract class ChannelCrypterNettyTestBase {
|
||||||
FrameDecrypt frameDecrypt = new FrameDecrypt();
|
FrameDecrypt frameDecrypt = new FrameDecrypt();
|
||||||
ByteBuf out = frameEncrypt.out;
|
ByteBuf out = frameEncrypt.out;
|
||||||
int outLen = out.readableBytes();
|
int outLen = out.readableBytes();
|
||||||
ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, this::ref);
|
ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, ref);
|
||||||
ByteBuf cipher2 = getDirectBuffer(lastLen, this::ref);
|
ByteBuf cipher2 = getDirectBuffer(lastLen, ref);
|
||||||
cipher1.writeBytes(out, 0, outLen - lastLen - tagLen);
|
cipher1.writeBytes(out, 0, outLen - lastLen - tagLen);
|
||||||
cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen);
|
cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen);
|
||||||
ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2);
|
ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2);
|
||||||
frameDecrypt.ciphertext = Collections.singletonList(cipher);
|
frameDecrypt.ciphertext = Collections.singletonList(cipher);
|
||||||
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
|
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
|
||||||
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
|
frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
|
||||||
|
|
||||||
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
||||||
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
|
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
|
||||||
|
|
@ -212,11 +223,4 @@ public abstract class ChannelCrypterNettyTestBase {
|
||||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
|
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private ByteBuf ref(ByteBuf buf) {
|
|
||||||
if (buf != null) {
|
|
||||||
references.add(buf);
|
|
||||||
}
|
|
||||||
return buf;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -52,9 +52,10 @@ public final class FakeChannelCrypter implements ChannelCrypterNetty {
|
||||||
for (ByteBuf buf : ciphertext) {
|
for (ByteBuf buf : ciphertext) {
|
||||||
out.writeBytes(buf);
|
out.writeBytes(buf);
|
||||||
}
|
}
|
||||||
boolean tagValid = tag.forEachByte((byte value) -> value == TAG_BYTE) == -1;
|
while (tag.isReadable()) {
|
||||||
if (!tagValid) {
|
if (tag.readByte() != TAG_BYTE) {
|
||||||
throw new AEADBadTagException("Tag mismatch!");
|
throw new AEADBadTagException("Tag mismatch!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ package io.grpc.alts.transportsecurity;
|
||||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||||
|
|
||||||
import com.google.common.base.Preconditions;
|
import com.google.common.base.Preconditions;
|
||||||
|
import io.grpc.alts.transportsecurity.TsiPeer.Property;
|
||||||
import io.netty.buffer.ByteBufAllocator;
|
import io.netty.buffer.ByteBufAllocator;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.security.GeneralSecurityException;
|
import java.security.GeneralSecurityException;
|
||||||
|
|
@ -203,7 +204,7 @@ public class FakeTsiHandshaker implements TsiHandshaker {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TsiPeer extractPeer() {
|
public TsiPeer extractPeer() {
|
||||||
return new TsiPeer(Collections.emptyList());
|
return new TsiPeer(Collections.<Property<?>>emptyList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.Assert.assertFalse;
|
||||||
|
|
||||||
import com.google.common.testing.GcFinalization;
|
import com.google.common.testing.GcFinalization;
|
||||||
|
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
|
||||||
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
|
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.util.ReferenceCounted;
|
import io.netty.util.ReferenceCounted;
|
||||||
|
|
@ -44,6 +45,16 @@ public class FakeTsiTest {
|
||||||
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
|
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
|
||||||
|
|
||||||
private final List<ReferenceCounted> references = new ArrayList<>();
|
private final List<ReferenceCounted> references = new ArrayList<>();
|
||||||
|
private final RegisterRef ref =
|
||||||
|
new RegisterRef() {
|
||||||
|
@Override
|
||||||
|
public ByteBuf register(ByteBuf buf) {
|
||||||
|
if (buf != null) {
|
||||||
|
references.add(buf);
|
||||||
|
}
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
private static Handshakers newHandshakers() {
|
private static Handshakers newHandshakers() {
|
||||||
TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient();
|
TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient();
|
||||||
|
|
@ -157,53 +168,46 @@ public class FakeTsiTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPong() throws GeneralSecurityException {
|
public void pingPong() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongTest(newHandshakers(), this::ref);
|
TsiTest.pingPongTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongExactFrameSize() throws GeneralSecurityException {
|
public void pingPongExactFrameSize() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
|
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongSmallBuffer() throws GeneralSecurityException {
|
public void pingPongSmallBuffer() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
|
TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongSmallFrame() throws GeneralSecurityException {
|
public void pingPongSmallFrame() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
|
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
|
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
|
||||||
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
|
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void corruptedCounter() throws GeneralSecurityException {
|
public void corruptedCounter() throws GeneralSecurityException {
|
||||||
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
|
TsiTest.corruptedCounterTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void corruptedCiphertext() throws GeneralSecurityException {
|
public void corruptedCiphertext() throws GeneralSecurityException {
|
||||||
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
|
TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void corruptedTag() throws GeneralSecurityException {
|
public void corruptedTag() throws GeneralSecurityException {
|
||||||
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
|
TsiTest.corruptedTagTest(newHandshakers(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void reflectedCiphertext() throws GeneralSecurityException {
|
public void reflectedCiphertext() throws GeneralSecurityException {
|
||||||
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
|
TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
|
||||||
}
|
|
||||||
|
|
||||||
private ByteBuf ref(ByteBuf buf) {
|
|
||||||
if (buf != null) {
|
|
||||||
references.add(buf);
|
|
||||||
}
|
|
||||||
return buf;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -123,10 +123,18 @@ public final class TsiTest {
|
||||||
throws GeneralSecurityException {
|
throws GeneralSecurityException {
|
||||||
|
|
||||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||||
List<ByteBuf> protectOut = new ArrayList<>();
|
final List<ByteBuf> protectOut = new ArrayList<>();
|
||||||
List<Object> unprotectOut = new ArrayList<>();
|
List<Object> unprotectOut = new ArrayList<>();
|
||||||
|
|
||||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
sender.protectFlush(
|
||||||
|
Collections.singletonList(plaintextBuffer),
|
||||||
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf buf) {
|
||||||
|
protectOut.add(buf);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
alloc);
|
||||||
assertThat(protectOut.size()).isEqualTo(1);
|
assertThat(protectOut.size()).isEqualTo(1);
|
||||||
|
|
||||||
ByteBuf protect = ref.register(protectOut.get(0));
|
ByteBuf protect = ref.register(protectOut.get(0));
|
||||||
|
|
@ -249,10 +257,18 @@ public final class TsiTest {
|
||||||
|
|
||||||
String message = "hello world";
|
String message = "hello world";
|
||||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||||
List<ByteBuf> protectOut = new ArrayList<>();
|
final List<ByteBuf> protectOut = new ArrayList<>();
|
||||||
List<Object> unprotectOut = new ArrayList<>();
|
List<Object> unprotectOut = new ArrayList<>();
|
||||||
|
|
||||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
sender.protectFlush(
|
||||||
|
Collections.singletonList(plaintextBuffer),
|
||||||
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf buf) {
|
||||||
|
protectOut.add(buf);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
alloc);
|
||||||
assertThat(protectOut.size()).isEqualTo(1);
|
assertThat(protectOut.size()).isEqualTo(1);
|
||||||
|
|
||||||
ByteBuf protect = ref.register(protectOut.get(0));
|
ByteBuf protect = ref.register(protectOut.get(0));
|
||||||
|
|
@ -282,10 +298,18 @@ public final class TsiTest {
|
||||||
|
|
||||||
String message = "hello world";
|
String message = "hello world";
|
||||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||||
List<ByteBuf> protectOut = new ArrayList<>();
|
final List<ByteBuf> protectOut = new ArrayList<>();
|
||||||
List<Object> unprotectOut = new ArrayList<>();
|
List<Object> unprotectOut = new ArrayList<>();
|
||||||
|
|
||||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
sender.protectFlush(
|
||||||
|
Collections.singletonList(plaintextBuffer),
|
||||||
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf buf) {
|
||||||
|
protectOut.add(buf);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
alloc);
|
||||||
assertThat(protectOut.size()).isEqualTo(1);
|
assertThat(protectOut.size()).isEqualTo(1);
|
||||||
|
|
||||||
ByteBuf protect = ref.register(protectOut.get(0));
|
ByteBuf protect = ref.register(protectOut.get(0));
|
||||||
|
|
@ -313,10 +337,18 @@ public final class TsiTest {
|
||||||
|
|
||||||
String message = "hello world";
|
String message = "hello world";
|
||||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||||
List<ByteBuf> protectOut = new ArrayList<>();
|
final List<ByteBuf> protectOut = new ArrayList<>();
|
||||||
List<Object> unprotectOut = new ArrayList<>();
|
List<Object> unprotectOut = new ArrayList<>();
|
||||||
|
|
||||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
sender.protectFlush(
|
||||||
|
Collections.singletonList(plaintextBuffer),
|
||||||
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf buf) {
|
||||||
|
protectOut.add(buf);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
alloc);
|
||||||
assertThat(protectOut.size()).isEqualTo(1);
|
assertThat(protectOut.size()).isEqualTo(1);
|
||||||
|
|
||||||
ByteBuf protect = ref.register(protectOut.get(0));
|
ByteBuf protect = ref.register(protectOut.get(0));
|
||||||
|
|
@ -344,10 +376,18 @@ public final class TsiTest {
|
||||||
|
|
||||||
String message = "hello world";
|
String message = "hello world";
|
||||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||||
List<ByteBuf> protectOut = new ArrayList<>();
|
final List<ByteBuf> protectOut = new ArrayList<>();
|
||||||
List<Object> unprotectOut = new ArrayList<>();
|
List<Object> unprotectOut = new ArrayList<>();
|
||||||
|
|
||||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
sender.protectFlush(
|
||||||
|
Collections.singletonList(plaintextBuffer),
|
||||||
|
new java.util.function.Consumer<ByteBuf>() {
|
||||||
|
@Override
|
||||||
|
public void accept(ByteBuf buf) {
|
||||||
|
protectOut.add(buf);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
alloc);
|
||||||
assertThat(protectOut.size()).isEqualTo(1);
|
assertThat(protectOut.size()).isEqualTo(1);
|
||||||
|
|
||||||
ByteBuf protect = ref.register(protectOut.get(0));
|
ByteBuf protect = ref.register(protectOut.get(0));
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,8 @@ buildscript {
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
compile project(':grpc-auth'),
|
compile project(':grpc-alts'),
|
||||||
|
project(':grpc-auth'),
|
||||||
project(':grpc-core'),
|
project(':grpc-core'),
|
||||||
project(':grpc-netty'),
|
project(':grpc-netty'),
|
||||||
project(':grpc-okhttp'),
|
project(':grpc-okhttp'),
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ package io.grpc.testing.integration;
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.io.Files;
|
import com.google.common.io.Files;
|
||||||
import io.grpc.ManagedChannel;
|
import io.grpc.ManagedChannel;
|
||||||
|
import io.grpc.alts.AltsChannelBuilder;
|
||||||
import io.grpc.internal.AbstractManagedChannelImplBuilder;
|
import io.grpc.internal.AbstractManagedChannelImplBuilder;
|
||||||
import io.grpc.internal.GrpcUtil;
|
import io.grpc.internal.GrpcUtil;
|
||||||
import io.grpc.internal.testing.TestUtils;
|
import io.grpc.internal.testing.TestUtils;
|
||||||
|
|
@ -76,6 +77,7 @@ public class TestServiceClient {
|
||||||
private int serverPort = 8080;
|
private int serverPort = 8080;
|
||||||
private String testCase = "empty_unary";
|
private String testCase = "empty_unary";
|
||||||
private boolean useTls = true;
|
private boolean useTls = true;
|
||||||
|
private boolean useAlts = false;
|
||||||
private boolean useTestCa;
|
private boolean useTestCa;
|
||||||
private boolean useOkHttp;
|
private boolean useOkHttp;
|
||||||
private String defaultServiceAccount;
|
private String defaultServiceAccount;
|
||||||
|
|
@ -116,6 +118,8 @@ public class TestServiceClient {
|
||||||
testCase = value;
|
testCase = value;
|
||||||
} else if ("use_tls".equals(key)) {
|
} else if ("use_tls".equals(key)) {
|
||||||
useTls = Boolean.parseBoolean(value);
|
useTls = Boolean.parseBoolean(value);
|
||||||
|
} else if ("use_alts".equals(key)) {
|
||||||
|
useAlts = Boolean.parseBoolean(value);
|
||||||
} else if ("use_test_ca".equals(key)) {
|
} else if ("use_test_ca".equals(key)) {
|
||||||
useTestCa = Boolean.parseBoolean(value);
|
useTestCa = Boolean.parseBoolean(value);
|
||||||
} else if ("use_okhttp".equals(key)) {
|
} else if ("use_okhttp".equals(key)) {
|
||||||
|
|
@ -140,6 +144,9 @@ public class TestServiceClient {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (useAlts) {
|
||||||
|
useTls = false;
|
||||||
|
}
|
||||||
if (usage) {
|
if (usage) {
|
||||||
TestServiceClient c = new TestServiceClient();
|
TestServiceClient c = new TestServiceClient();
|
||||||
System.out.println(
|
System.out.println(
|
||||||
|
|
@ -153,6 +160,8 @@ public class TestServiceClient {
|
||||||
+ "\n Valid options:"
|
+ "\n Valid options:"
|
||||||
+ validTestCasesHelpText()
|
+ validTestCasesHelpText()
|
||||||
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
|
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
|
||||||
|
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
|
||||||
|
+ "\n Default " + c.useTls
|
||||||
+ "\n --use_test_ca=true|false Whether to trust our fake CA. Requires --use_tls=true "
|
+ "\n --use_test_ca=true|false Whether to trust our fake CA. Requires --use_tls=true "
|
||||||
+ "\n to have effect. Default " + c.useTestCa
|
+ "\n to have effect. Default " + c.useTestCa
|
||||||
+ "\n --use_okhttp=true|false Whether to use OkHttp instead of Netty. Default "
|
+ "\n --use_okhttp=true|false Whether to use OkHttp instead of Netty. Default "
|
||||||
|
|
@ -317,6 +326,9 @@ public class TestServiceClient {
|
||||||
private class Tester extends AbstractInteropTest {
|
private class Tester extends AbstractInteropTest {
|
||||||
@Override
|
@Override
|
||||||
protected ManagedChannel createChannel() {
|
protected ManagedChannel createChannel() {
|
||||||
|
if (useAlts) {
|
||||||
|
return AltsChannelBuilder.forAddress(serverHost, serverPort).build();
|
||||||
|
}
|
||||||
AbstractManagedChannelImplBuilder<?> builder;
|
AbstractManagedChannelImplBuilder<?> builder;
|
||||||
if (!useOkHttp) {
|
if (!useOkHttp) {
|
||||||
SslContext sslContext = null;
|
SslContext sslContext = null;
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.util.concurrent.MoreExecutors;
|
import com.google.common.util.concurrent.MoreExecutors;
|
||||||
import io.grpc.Server;
|
import io.grpc.Server;
|
||||||
import io.grpc.ServerInterceptors;
|
import io.grpc.ServerInterceptors;
|
||||||
|
import io.grpc.alts.AltsServerBuilder;
|
||||||
import io.grpc.internal.testing.TestUtils;
|
import io.grpc.internal.testing.TestUtils;
|
||||||
import io.grpc.netty.GrpcSslContexts;
|
import io.grpc.netty.GrpcSslContexts;
|
||||||
import io.grpc.netty.NettyServerBuilder;
|
import io.grpc.netty.NettyServerBuilder;
|
||||||
|
|
@ -28,34 +29,32 @@ import java.util.concurrent.Executors;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
/**
|
/** Server that manages startup/shutdown of a single {@code TestService}. */
|
||||||
* Server that manages startup/shutdown of a single {@code TestService}.
|
|
||||||
*/
|
|
||||||
public class TestServiceServer {
|
public class TestServiceServer {
|
||||||
/**
|
/** The main application allowing this server to be launched from the command line. */
|
||||||
* The main application allowing this server to be launched from the command line.
|
|
||||||
*/
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
final TestServiceServer server = new TestServiceServer();
|
final TestServiceServer server = new TestServiceServer();
|
||||||
server.parseArgs(args);
|
server.parseArgs(args);
|
||||||
if (server.useTls) {
|
if (server.useTls) {
|
||||||
System.out.println(
|
System.out.println(
|
||||||
"\nUsing fake CA for TLS certificate. Test clients should expect host\n"
|
"\nUsing fake CA for TLS certificate. Test clients should expect host\n"
|
||||||
+ "*.test.google.fr and our test CA. For the Java test client binary, use:\n"
|
+ "*.test.google.fr and our test CA. For the Java test client binary, use:\n"
|
||||||
+ "--server_host_override=foo.test.google.fr --use_test_ca=true\n");
|
+ "--server_host_override=foo.test.google.fr --use_test_ca=true\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
Runtime.getRuntime().addShutdownHook(new Thread() {
|
Runtime.getRuntime()
|
||||||
@Override
|
.addShutdownHook(
|
||||||
public void run() {
|
new Thread() {
|
||||||
try {
|
@Override
|
||||||
System.out.println("Shutting down");
|
public void run() {
|
||||||
server.stop();
|
try {
|
||||||
} catch (Exception e) {
|
System.out.println("Shutting down");
|
||||||
e.printStackTrace();
|
server.stop();
|
||||||
}
|
} catch (Exception e) {
|
||||||
}
|
e.printStackTrace();
|
||||||
});
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
server.start();
|
server.start();
|
||||||
System.out.println("Server started on port " + server.port);
|
System.out.println("Server started on port " + server.port);
|
||||||
server.blockUntilShutdown();
|
server.blockUntilShutdown();
|
||||||
|
|
@ -63,6 +62,7 @@ public class TestServiceServer {
|
||||||
|
|
||||||
private int port = 8080;
|
private int port = 8080;
|
||||||
private boolean useTls = true;
|
private boolean useTls = true;
|
||||||
|
private boolean useAlts = false;
|
||||||
|
|
||||||
private ScheduledExecutorService executor;
|
private ScheduledExecutorService executor;
|
||||||
private Server server;
|
private Server server;
|
||||||
|
|
@ -92,6 +92,8 @@ public class TestServiceServer {
|
||||||
port = Integer.parseInt(value);
|
port = Integer.parseInt(value);
|
||||||
} else if ("use_tls".equals(key)) {
|
} else if ("use_tls".equals(key)) {
|
||||||
useTls = Boolean.parseBoolean(value);
|
useTls = Boolean.parseBoolean(value);
|
||||||
|
} else if ("use_alts".equals(key)) {
|
||||||
|
useAlts = Boolean.parseBoolean(value);
|
||||||
} else if ("grpc_version".equals(key)) {
|
} else if ("grpc_version".equals(key)) {
|
||||||
if (!"2".equals(value)) {
|
if (!"2".equals(value)) {
|
||||||
System.err.println("Only grpc version 2 is supported");
|
System.err.println("Only grpc version 2 is supported");
|
||||||
|
|
@ -104,13 +106,18 @@ public class TestServiceServer {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (useAlts) {
|
||||||
|
useTls = false;
|
||||||
|
}
|
||||||
if (usage) {
|
if (usage) {
|
||||||
TestServiceServer s = new TestServiceServer();
|
TestServiceServer s = new TestServiceServer();
|
||||||
System.out.println(
|
System.out.println(
|
||||||
"Usage: [ARGS...]"
|
"Usage: [ARGS...]"
|
||||||
+ "\n"
|
+ "\n"
|
||||||
+ "\n --port=PORT Port to connect to. Default " + s.port
|
+ "\n --port=PORT Port to connect to. Default " + s.port
|
||||||
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
|
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
|
||||||
|
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
|
||||||
|
+ "\n Default " + s.useAlts
|
||||||
);
|
);
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
@ -120,17 +127,31 @@ public class TestServiceServer {
|
||||||
void start() throws Exception {
|
void start() throws Exception {
|
||||||
executor = Executors.newSingleThreadScheduledExecutor();
|
executor = Executors.newSingleThreadScheduledExecutor();
|
||||||
SslContext sslContext = null;
|
SslContext sslContext = null;
|
||||||
if (useTls) {
|
if (useAlts) {
|
||||||
sslContext = GrpcSslContexts.forServer(
|
server =
|
||||||
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key")).build();
|
AltsServerBuilder.forPort(port)
|
||||||
|
.addService(
|
||||||
|
ServerInterceptors.intercept(
|
||||||
|
new TestServiceImpl(executor), TestServiceImpl.interceptors()))
|
||||||
|
.build()
|
||||||
|
.start();
|
||||||
|
} else {
|
||||||
|
if (useTls) {
|
||||||
|
sslContext =
|
||||||
|
GrpcSslContexts.forServer(
|
||||||
|
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
server =
|
||||||
|
NettyServerBuilder.forPort(port)
|
||||||
|
.sslContext(sslContext)
|
||||||
|
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
|
||||||
|
.addService(
|
||||||
|
ServerInterceptors.intercept(
|
||||||
|
new TestServiceImpl(executor), TestServiceImpl.interceptors()))
|
||||||
|
.build()
|
||||||
|
.start();
|
||||||
}
|
}
|
||||||
server = NettyServerBuilder.forPort(port)
|
|
||||||
.sslContext(sslContext)
|
|
||||||
.maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
|
|
||||||
.addService(ServerInterceptors.intercept(
|
|
||||||
new TestServiceImpl(executor),
|
|
||||||
TestServiceImpl.interceptors()))
|
|
||||||
.build().start();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
|
|
@ -147,9 +168,7 @@ public class TestServiceServer {
|
||||||
return server.getPort();
|
return server.getPort();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** Await termination on the main thread since the grpc library uses daemon threads. */
|
||||||
* Await termination on the main thread since the grpc library uses daemon threads.
|
|
||||||
*/
|
|
||||||
private void blockUntilShutdown() throws InterruptedException {
|
private void blockUntilShutdown() throws InterruptedException {
|
||||||
if (server != null) {
|
if (server != null) {
|
||||||
server.awaitTermination();
|
server.awaitTermination();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue