alts: add gRPC ALTS

This commit is contained in:
Jiangtao Li 2018-02-15 09:28:00 -08:00 committed by Carl Mastrangelo
parent a45d07bcb5
commit e7f2f1dedd
70 changed files with 26379 additions and 1 deletions

82
alts/BUILD.bazel Normal file
View File

@ -0,0 +1,82 @@
load("//:java_grpc_library.bzl", "java_grpc_library")
java_library(
name = "alts_tsi",
srcs = glob([
"src/main/java/io/grpc/alts/transportsecurity/*.java",
]),
visibility = ["//visibility:public"],
deps = [
"//core",
"//core:internal",
"//stub",
"@com_google_code_findbugs_jsr305//jar",
"@com_google_guava_guava//jar",
"@com_google_protobuf//:protobuf_java",
"@com_google_protobuf//:protobuf_java_util",
"@io_netty_netty_buffer//jar",
"@io_netty_netty_common//jar",
":handshaker_java_proto",
":handshaker_java_grpc",
],
)
java_library(
name = "alts",
srcs = glob([
"src/main/java/io/grpc/alts/*.java",
]),
visibility = ["//visibility:public"],
deps = [
"//core",
"//core:internal",
"//netty",
"//stub",
"@com_google_code_findbugs_jsr305//jar",
"@com_google_guava_guava//jar",
"@com_google_protobuf//:protobuf_java",
"@com_google_protobuf//:protobuf_java_util",
"@io_netty_netty_buffer//jar",
"@io_netty_netty_codec//jar",
"@io_netty_netty_common//jar",
"@io_netty_netty_transport//jar",
"@org_apache_commons_commons_lang3//jar",
":alts_tsi",
":handshaker_java_proto",
":handshaker_java_grpc",
],
)
# bazel only accepts proto import with absolute path.
genrule(
name = "protobuf_imports",
srcs = glob(["src/main/proto/*.proto"]),
outs = [
"protobuf_out/altscontext.proto",
"protobuf_out/handshaker.proto",
"protobuf_out/transport_security_common.proto",
],
cmd = "for fname in $(SRCS); do " +
"sed 's,import \",import \"alts/protobuf_out/,g' $$fname > " +
"$(@D)/protobuf_out/$$(basename $$fname); done",
)
proto_library(
name = "handshaker_proto",
srcs = [
"protobuf_out/altscontext.proto",
"protobuf_out/handshaker.proto",
"protobuf_out/transport_security_common.proto",
],
)
java_proto_library(
name = "handshaker_java_proto",
deps = [":handshaker_proto"],
)
java_grpc_library(
name = "handshaker_java_grpc",
srcs = [":handshaker_proto"],
deps = [":handshaker_java_proto"],
)

41
alts/build.gradle Normal file
View File

@ -0,0 +1,41 @@
description = "gRPC: ALTS"
sourceCompatibility = 1.8
targetCompatibility = 1.8
buildscript {
repositories {
mavenCentral()
}
dependencies {
classpath libraries.protobuf_plugin
}
}
dependencies {
compile project(':grpc-core'),
project(':grpc-netty'),
project(':grpc-protobuf'),
project(':grpc-stub'),
libraries.lang,
libraries.protobuf
testCompile libraries.guava_testlib,
libraries.junit,
libraries.mockito,
libraries.truth
}
configureProtoCompilation()
[compileJava, compileTestJava].each() {
// ALTS retuns a lot of futures that we mostly don't care about.
// protobuf calls valueof. Will be fixed in next release (google/protobuf#4046)
it.options.compilerArgs += ["-Xlint:-deprecation", "-Xep:FutureReturnValueIgnored:OFF"]
}
idea {
module {
sourceDirs += file("${projectDir}/src/generated/main/grpc");
sourceDirs += file("${projectDir}/src/generated/main/java");
}
}

View File

@ -0,0 +1,278 @@
package io.grpc.alts;
import static io.grpc.MethodDescriptor.generateFullMethodName;
import static io.grpc.stub.ClientCalls.asyncBidiStreamingCall;
import static io.grpc.stub.ClientCalls.asyncClientStreamingCall;
import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
import static io.grpc.stub.ClientCalls.asyncUnaryCall;
import static io.grpc.stub.ClientCalls.blockingServerStreamingCall;
import static io.grpc.stub.ClientCalls.blockingUnaryCall;
import static io.grpc.stub.ClientCalls.futureUnaryCall;
import static io.grpc.stub.ServerCalls.asyncBidiStreamingCall;
import static io.grpc.stub.ServerCalls.asyncClientStreamingCall;
import static io.grpc.stub.ServerCalls.asyncServerStreamingCall;
import static io.grpc.stub.ServerCalls.asyncUnaryCall;
import static io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall;
import static io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall;
/**
*/
@javax.annotation.Generated(
value = "by gRPC proto compiler",
comments = "Source: handshaker.proto")
public final class HandshakerServiceGrpc {
private HandshakerServiceGrpc() {}
public static final String SERVICE_NAME = "grpc.gcp.HandshakerService";
// Static method descriptors that strictly reflect the proto.
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
@java.lang.Deprecated // Use {@link #getDoHandshakeMethod()} instead.
public static final io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
io.grpc.alts.Handshaker.HandshakerResp> METHOD_DO_HANDSHAKE = getDoHandshakeMethodHelper();
private static volatile io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethod;
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
public static io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethod() {
return getDoHandshakeMethodHelper();
}
private static io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethodHelper() {
io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq, io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethod;
if ((getDoHandshakeMethod = HandshakerServiceGrpc.getDoHandshakeMethod) == null) {
synchronized (HandshakerServiceGrpc.class) {
if ((getDoHandshakeMethod = HandshakerServiceGrpc.getDoHandshakeMethod) == null) {
HandshakerServiceGrpc.getDoHandshakeMethod = getDoHandshakeMethod =
io.grpc.MethodDescriptor.<io.grpc.alts.Handshaker.HandshakerReq, io.grpc.alts.Handshaker.HandshakerResp>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING)
.setFullMethodName(generateFullMethodName(
"grpc.gcp.HandshakerService", "DoHandshake"))
.setSampledToLocalTracing(true)
.setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
io.grpc.alts.Handshaker.HandshakerReq.getDefaultInstance()))
.setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
io.grpc.alts.Handshaker.HandshakerResp.getDefaultInstance()))
.setSchemaDescriptor(new HandshakerServiceMethodDescriptorSupplier("DoHandshake"))
.build();
}
}
}
return getDoHandshakeMethod;
}
/**
* Creates a new async stub that supports all call types for the service
*/
public static HandshakerServiceStub newStub(io.grpc.Channel channel) {
return new HandshakerServiceStub(channel);
}
/**
* Creates a new blocking-style stub that supports unary and streaming output calls on the service
*/
public static HandshakerServiceBlockingStub newBlockingStub(
io.grpc.Channel channel) {
return new HandshakerServiceBlockingStub(channel);
}
/**
* Creates a new ListenableFuture-style stub that supports unary calls on the service
*/
public static HandshakerServiceFutureStub newFutureStub(
io.grpc.Channel channel) {
return new HandshakerServiceFutureStub(channel);
}
/**
*/
public static abstract class HandshakerServiceImplBase implements io.grpc.BindableService {
/**
* <pre>
* Accepts a stream of handshaker request, returning a stream of handshaker
* response.
* </pre>
*/
public io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerReq> doHandshake(
io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerResp> responseObserver) {
return asyncUnimplementedStreamingCall(getDoHandshakeMethodHelper(), responseObserver);
}
@java.lang.Override public final io.grpc.ServerServiceDefinition bindService() {
return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor())
.addMethod(
getDoHandshakeMethodHelper(),
asyncBidiStreamingCall(
new MethodHandlers<
io.grpc.alts.Handshaker.HandshakerReq,
io.grpc.alts.Handshaker.HandshakerResp>(
this, METHODID_DO_HANDSHAKE)))
.build();
}
}
/**
*/
public static final class HandshakerServiceStub extends io.grpc.stub.AbstractStub<HandshakerServiceStub> {
private HandshakerServiceStub(io.grpc.Channel channel) {
super(channel);
}
private HandshakerServiceStub(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) {
super(channel, callOptions);
}
@java.lang.Override
protected HandshakerServiceStub build(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) {
return new HandshakerServiceStub(channel, callOptions);
}
/**
* <pre>
* Accepts a stream of handshaker request, returning a stream of handshaker
* response.
* </pre>
*/
public io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerReq> doHandshake(
io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerResp> responseObserver) {
return asyncBidiStreamingCall(
getChannel().newCall(getDoHandshakeMethodHelper(), getCallOptions()), responseObserver);
}
}
/**
*/
public static final class HandshakerServiceBlockingStub extends io.grpc.stub.AbstractStub<HandshakerServiceBlockingStub> {
private HandshakerServiceBlockingStub(io.grpc.Channel channel) {
super(channel);
}
private HandshakerServiceBlockingStub(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) {
super(channel, callOptions);
}
@java.lang.Override
protected HandshakerServiceBlockingStub build(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) {
return new HandshakerServiceBlockingStub(channel, callOptions);
}
}
/**
*/
public static final class HandshakerServiceFutureStub extends io.grpc.stub.AbstractStub<HandshakerServiceFutureStub> {
private HandshakerServiceFutureStub(io.grpc.Channel channel) {
super(channel);
}
private HandshakerServiceFutureStub(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) {
super(channel, callOptions);
}
@java.lang.Override
protected HandshakerServiceFutureStub build(io.grpc.Channel channel,
io.grpc.CallOptions callOptions) {
return new HandshakerServiceFutureStub(channel, callOptions);
}
}
private static final int METHODID_DO_HANDSHAKE = 0;
private static final class MethodHandlers<Req, Resp> implements
io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,
io.grpc.stub.ServerCalls.ServerStreamingMethod<Req, Resp>,
io.grpc.stub.ServerCalls.ClientStreamingMethod<Req, Resp>,
io.grpc.stub.ServerCalls.BidiStreamingMethod<Req, Resp> {
private final HandshakerServiceImplBase serviceImpl;
private final int methodId;
MethodHandlers(HandshakerServiceImplBase serviceImpl, int methodId) {
this.serviceImpl = serviceImpl;
this.methodId = methodId;
}
@java.lang.Override
@java.lang.SuppressWarnings("unchecked")
public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserver) {
switch (methodId) {
default:
throw new AssertionError();
}
}
@java.lang.Override
@java.lang.SuppressWarnings("unchecked")
public io.grpc.stub.StreamObserver<Req> invoke(
io.grpc.stub.StreamObserver<Resp> responseObserver) {
switch (methodId) {
case METHODID_DO_HANDSHAKE:
return (io.grpc.stub.StreamObserver<Req>) serviceImpl.doHandshake(
(io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerResp>) responseObserver);
default:
throw new AssertionError();
}
}
}
private static abstract class HandshakerServiceBaseDescriptorSupplier
implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier {
HandshakerServiceBaseDescriptorSupplier() {}
@java.lang.Override
public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() {
return io.grpc.alts.Handshaker.getDescriptor();
}
@java.lang.Override
public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() {
return getFileDescriptor().findServiceByName("HandshakerService");
}
}
private static final class HandshakerServiceFileDescriptorSupplier
extends HandshakerServiceBaseDescriptorSupplier {
HandshakerServiceFileDescriptorSupplier() {}
}
private static final class HandshakerServiceMethodDescriptorSupplier
extends HandshakerServiceBaseDescriptorSupplier
implements io.grpc.protobuf.ProtoMethodDescriptorSupplier {
private final String methodName;
HandshakerServiceMethodDescriptorSupplier(String methodName) {
this.methodName = methodName;
}
@java.lang.Override
public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() {
return getServiceDescriptor().findMethodByName(methodName);
}
}
private static volatile io.grpc.ServiceDescriptor serviceDescriptor;
public static io.grpc.ServiceDescriptor getServiceDescriptor() {
io.grpc.ServiceDescriptor result = serviceDescriptor;
if (result == null) {
synchronized (HandshakerServiceGrpc.class) {
result = serviceDescriptor;
if (result == null) {
serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME)
.setSchemaDescriptor(new HandshakerServiceFileDescriptorSupplier())
.addMethod(getDoHandshakeMethodHelper())
.build();
}
}
}
return result;
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,252 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.base.Preconditions.checkArgument;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.CallOptions;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.MethodDescriptor;
import io.grpc.alts.transportsecurity.AltsClientOptions;
import io.grpc.alts.transportsecurity.AltsTsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ProxyParameters;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilter;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
import io.grpc.netty.NettyChannelBuilder;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
/**
* ALTS version of {@code ManagedChannelBuilder}. This class sets up a secure and authenticated
* commmunication between two cloud VMs using ALTS.
*/
public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChannelBuilder> {
private final NettyChannelBuilder delegate;
private final AltsClientOptions.Builder handshakerOptionsBuilder =
new AltsClientOptions.Builder();
private TcpfFactory tcpfFactoryForTest;
private boolean enableUntrustedAlts;
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
public static final AltsChannelBuilder forTarget(String target) {
return new AltsChannelBuilder(target);
}
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
public static AltsChannelBuilder forAddress(String name, int port) {
return forTarget(GrpcUtil.authorityFromHostAndPort(name, port));
}
private AltsChannelBuilder(String target) {
delegate =
NettyChannelBuilder.forTarget(target)
.keepAliveTime(20, TimeUnit.SECONDS)
.keepAliveTimeout(10, TimeUnit.SECONDS)
.keepAliveWithoutCalls(true);
handshakerOptionsBuilder.setRpcProtocolVersions(
RpcProtocolVersionsUtil.getRpcProtocolVersions());
}
/** The server service account name for secure name checking. */
public AltsChannelBuilder withSecureNamingTarget(String targetName) {
handshakerOptionsBuilder.setTargetName(targetName);
return this;
}
/**
* Adds an expected target service accounts. One of the added service accounts should match peer
* service account in the handshaker result. Otherwise, the handshake fails.
*/
public AltsChannelBuilder addTargetServiceAccount(String targetServiceAccount) {
handshakerOptionsBuilder.addTargetServiceAccount(targetServiceAccount);
return this;
}
/**
* Enables untrusted ALTS for testing. If this function is called, we will not check whether ALTS
* is running on Google Cloud Platform.
*/
public AltsChannelBuilder enableUntrustedAltsForTesting() {
enableUntrustedAlts = true;
return this;
}
/** Sets a new handshaker service address for testing. */
public AltsChannelBuilder setHandshakerAddressForTesting(String handshakerAddress) {
HandshakerServiceChannel.setHandshakerAddressForTesting(handshakerAddress);
return this;
}
@Override
protected NettyChannelBuilder delegate() {
return delegate;
}
@Override
public ManagedChannel build() {
CheckGcpEnvironment.check(enableUntrustedAlts);
TcpfFactory tcpfFactory = new TcpfFactory();
InternalNettyChannelBuilder.setDynamicTransportParamsFactory(delegate(), tcpfFactory);
tcpfFactoryForTest = tcpfFactory;
return new AltsChannel(delegate().build());
}
@VisibleForTesting
@Nullable
TransportCreationParamsFilterFactory getTcpfFactoryForTest() {
return tcpfFactoryForTest;
}
@VisibleForTesting
@Nullable
AltsClientOptions getAltsClientOptionsForTest() {
if (tcpfFactoryForTest == null) {
return null;
}
return tcpfFactoryForTest.handshakerOptions;
}
private final class TcpfFactory implements TransportCreationParamsFilterFactory {
final AltsClientOptions handshakerOptions = handshakerOptionsBuilder.build();
private final TsiHandshakerFactory altsHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
ManagedChannel channel = HandshakerServiceChannel.get();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(channel), handshakerOptions);
}
};
@Override
public TransportCreationParamsFilter create(
SocketAddress serverAddress,
final String authority,
final String userAgent,
final ProxyParameters proxy) {
checkArgument(
serverAddress instanceof InetSocketAddress,
"%s must be a InetSocketAddress",
serverAddress);
final AltsProtocolNegotiator negotiator =
AltsProtocolNegotiator.create(altsHandshakerFactory);
return new TransportCreationParamsFilter() {
@Override
public SocketAddress getTargetServerAddress() {
return serverAddress;
}
@Override
public String getAuthority() {
return authority;
}
@Override
public String getUserAgent() {
return userAgent;
}
@Override
public AltsProtocolNegotiator getProtocolNegotiator() {
return negotiator;
}
};
}
}
static final class AltsChannel extends ManagedChannel {
private final ManagedChannel delegate;
AltsChannel(ManagedChannel delegate) {
this.delegate = delegate;
}
@Override
public ConnectivityState getState(boolean requestConnection) {
return delegate.getState(requestConnection);
}
@Override
public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) {
delegate.notifyWhenStateChanged(source, callback);
}
@Override
public AltsChannel shutdown() {
delegate.shutdown();
return this;
}
@Override
public boolean isShutdown() {
return delegate.isShutdown();
}
@Override
public boolean isTerminated() {
return delegate.isTerminated();
}
@Override
public AltsChannel shutdownNow() {
delegate.shutdownNow();
return this;
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return delegate.awaitTermination(timeout, unit);
}
@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
return delegate.newCall(methodDescriptor, callOptions);
}
@Override
public String authority() {
return delegate.authority();
}
@Override
public void resetConnectBackoff() {
delegate.resetConnectBackoff();
}
@Override
public void prepareToLoseNetwork() {
delegate.prepareToLoseNetwork();
}
}
}

View File

@ -0,0 +1,142 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.Attributes;
import io.grpc.Grpc;
import io.grpc.Status;
import io.grpc.alts.InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.grpc.alts.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
import io.grpc.alts.transportsecurity.AltsAuthContext;
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
import io.grpc.alts.transportsecurity.TsiPeer;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.ProtocolNegotiator;
import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.AsciiString;
/**
* A client-side GRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that
* provides ALTS security on the wire, similar to Netty's {@code SslHandler}.
*/
public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
private static final Attributes.Key<TsiPeer> TSI_PEER_KEY = Attributes.Key.of("TSI_PEER");
private static final Attributes.Key<AltsAuthContext> ALTS_CONTEXT_KEY =
Attributes.Key.of("ALTS_CONTEXT_KEY");
private static final AsciiString scheme = AsciiString.of("https");
public static Attributes.Key<TsiPeer> getTsiPeerAttributeKey() {
return TSI_PEER_KEY;
}
public static Attributes.Key<AltsAuthContext> getAltsAuthContextAttributeKey() {
return ALTS_CONTEXT_KEY;
}
/** Creates a negotiator used for ALTS. */
public static AltsProtocolNegotiator create(TsiHandshakerFactory handshakerFactory) {
return new AltsProtocolNegotiator() {
@Override
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return new BufferUntilAltsNegotiatedHandler(
grpcHandler,
new InternalTsiHandshakeHandler(
new InternalNettyTsiHandshaker(handshakerFactory.newHandshaker())),
new InternalTsiFrameHandler());
}
};
}
/** Buffers all writes until the ALTS handshake is complete. */
@VisibleForTesting
static class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler
implements ProtocolNegotiator.Handler {
private final GrpcHttp2ConnectionHandler grpcHandler;
BufferUntilAltsNegotiatedHandler(
GrpcHttp2ConnectionHandler grpcHandler, ChannelHandler... negotiationhandlers) {
super(negotiationhandlers);
// Save the gRPC handler. The ALTS handler doesn't support buffering before the handshake
// completes, so we wait until the handshake was successful before adding the grpc handler.
this.grpcHandler = grpcHandler;
}
// TODO: Remove this once https://github.com/grpc/grpc-java/pull/3715 is in.
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
fail(ctx, cause);
ctx.fireExceptionCaught(cause);
}
@Override
public AsciiString scheme() {
return scheme;
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof TsiHandshakeCompletionEvent) {
TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt;
if (altsEvt.isSuccess()) {
// Add the gRPC handler just before this handler. We only allow the grpcHandler to be
// null to support testing. In production, a grpc handler will always be provided.
if (grpcHandler != null) {
ctx.pipeline().addBefore(ctx.name(), null, grpcHandler);
AltsAuthContext altsContext = (AltsAuthContext) altsEvt.context();
Preconditions.checkNotNull(altsContext);
// Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if
// Rpc Protocol Versions mismatch.
RpcVersionsCheckResult checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersionsUtil.getRpcProtocolVersions(),
altsContext.getPeerRpcVersions());
if (!checkResult.getResult()) {
String errorMessage =
"Local Rpc Protocol Versions "
+ RpcProtocolVersionsUtil.getRpcProtocolVersions().toString()
+ "are not compatible with peer Rpc Protocol Versions "
+ altsContext.getPeerRpcVersions().toString();
fail(ctx, Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException());
}
grpcHandler.handleProtocolNegotiationCompleted(
Attributes.newBuilder()
.set(TSI_PEER_KEY, altsEvt.peer())
.set(ALTS_CONTEXT_KEY, altsContext)
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
.build());
}
// Now write any buffered data and remove this handler.
writeBufferedAndRemove(ctx);
} else {
fail(ctx, unavailableException("ALTS handshake failed", altsEvt.cause()));
}
}
super.userEventTriggered(ctx, evt);
}
private static RuntimeException unavailableException(String msg, Throwable cause) {
return Status.UNAVAILABLE.withCause(cause).withDescription(msg).asRuntimeException();
}
}
}

View File

@ -0,0 +1,247 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import io.grpc.BindableService;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.HandlerRegistry;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServerStreamTracer.Factory;
import io.grpc.ServerTransportFilter;
import io.grpc.alts.transportsecurity.AltsHandshakerOptions;
import io.grpc.alts.transportsecurity.AltsTsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
import io.grpc.netty.NettyServerBuilder;
import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
/**
* gRPC secure server builder used for ALTS. This class adds on the necessary ALTS support to create
* a production server on Google Cloud Platform.
*/
public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
final NettyServerBuilder delegate;
private boolean enableUntrustedAlts;
private AltsServerBuilder(NettyServerBuilder nettyDelegate) {
this.delegate = nettyDelegate;
}
/** Creates a gRPC server builder for the given port. */
public static AltsServerBuilder forPort(int port) {
NettyServerBuilder nettyDelegate =
NettyServerBuilder.forAddress(new InetSocketAddress(port))
.maxConnectionIdle(1, TimeUnit.HOURS)
.keepAliveTime(270, TimeUnit.SECONDS)
.keepAliveTimeout(20, TimeUnit.SECONDS)
.permitKeepAliveTime(10, TimeUnit.SECONDS)
.permitKeepAliveWithoutCalls(true);
return new AltsServerBuilder(nettyDelegate);
}
/**
* Enables untrusted ALTS for testing. If this function is called, we will not check whether ALTS
* is running on Google Cloud Platform.
*/
public AltsServerBuilder enableUntrustedAltsForTesting() {
enableUntrustedAlts = true;
return this;
}
/** Sets a new handshaker service address for testing. */
public AltsServerBuilder setHandshakerAddressForTesting(String handshakerAddress) {
HandshakerServiceChannel.setHandshakerAddressForTesting(handshakerAddress);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder handshakeTimeout(long timeout, TimeUnit unit) {
delegate.handshakeTimeout(timeout, unit);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder directExecutor() {
delegate.directExecutor();
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder addStreamTracerFactory(Factory factory) {
delegate.addStreamTracerFactory(factory);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder addTransportFilter(ServerTransportFilter filter) {
delegate.addTransportFilter(filter);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder executor(Executor executor) {
delegate.executor(executor);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder addService(ServerServiceDefinition service) {
delegate.addService(service);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder addService(BindableService bindableService) {
delegate.addService(bindableService);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder fallbackHandlerRegistry(HandlerRegistry fallbackRegistry) {
delegate.fallbackHandlerRegistry(fallbackRegistry);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder useTransportSecurity(File certChain, File privateKey) {
throw new UnsupportedOperationException("Can't set TLS settings for ALTS");
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder decompressorRegistry(DecompressorRegistry registry) {
delegate.decompressorRegistry(registry);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder compressorRegistry(CompressorRegistry registry) {
delegate.compressorRegistry(registry);
return this;
}
/** {@inheritDoc} */
@Override
public AltsServerBuilder intercept(ServerInterceptor interceptor) {
delegate.intercept(interceptor);
return this;
}
/** {@inheritDoc} */
@Override
public Server build() {
CheckGcpEnvironment.check(enableUntrustedAlts);
delegate.protocolNegotiator(
AltsProtocolNegotiator.create(
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
// Used the shared grpc channel to connecting to the ALTS handshaker service.
return AltsTsiHandshaker.newServer(
HandshakerServiceGrpc.newStub(HandshakerServiceChannel.get()),
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
}
}));
return new AltsServer(delegate.build());
}
static final class AltsServer extends io.grpc.Server {
private final Server delegate;
AltsServer(Server delegate) {
this.delegate = delegate;
}
@Override
public List<ServerServiceDefinition> getImmutableServices() {
return delegate.getImmutableServices();
}
@Override
public List<ServerServiceDefinition> getMutableServices() {
return delegate.getMutableServices();
}
@Override
public int getPort() {
return delegate.getPort();
}
@Override
public List<ServerServiceDefinition> getServices() {
return delegate.getServices();
}
@Override
public Server start() throws IOException {
delegate.start();
return this;
}
@Override
public Server shutdown() {
delegate.shutdown();
return this;
}
@Override
public Server shutdownNow() {
delegate.shutdownNow();
return this;
}
@Override
public boolean isShutdown() {
return delegate.isShutdown();
}
@Override
public boolean isTerminated() {
return delegate.isTerminated();
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return delegate.awaitTermination(timeout, unit);
}
@Override
public void awaitTermination() throws InterruptedException {
delegate.awaitTermination();
}
}
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import java.nio.ByteBuffer;
/** Unwraps {@link ByteBuf}s into {@link ByteBuffer}s. */
final class BufUnwrapper implements AutoCloseable {
private final ByteBuffer[] singleReadBuffer = new ByteBuffer[1];
private final ByteBuffer[] singleWriteBuffer = new ByteBuffer[1];
/**
* Called to get access to the underlying NIO buffers for a {@link ByteBuf} that will be used for
* writing.
*/
ByteBuffer[] writableNioBuffers(ByteBuf buf) {
// Set the writer index to the capacity to guarantee that the returned NIO buffers will have
// the capacity available.
int readerIndex = buf.readerIndex();
int writerIndex = buf.writerIndex();
buf.readerIndex(writerIndex);
buf.writerIndex(buf.capacity());
try {
return nioBuffers(buf, singleWriteBuffer);
} finally {
// Restore the writer index before returning.
buf.readerIndex(readerIndex);
buf.writerIndex(writerIndex);
}
}
/**
* Called to get access to the underlying NIO buffers for a {@link ByteBuf} that will be used for
* reading.
*/
ByteBuffer[] readableNioBuffers(ByteBuf buf) {
return nioBuffers(buf, singleReadBuffer);
}
@Override
public void close() {
singleReadBuffer[0] = null;
singleWriteBuffer[0] = null;
}
/**
* Optimized accessor for obtaining the underlying NIO buffers for a Netty {@link ByteBuf}. Based
* on code from Netty's {@code SslHandler}. This method returns NIO buffers that span the readable
* region of the {@link ByteBuf}.
*/
private static ByteBuffer[] nioBuffers(ByteBuf buf, ByteBuffer[] singleBuffer) {
// As CompositeByteBuf.nioBufferCount() can be expensive (as it needs to check all composed
// ByteBuf to calculate the count) we will just assume a CompositeByteBuf contains more than 1
// ByteBuf. The worst that can happen is that we allocate an extra ByteBuffer[] in
// CompositeByteBuf.nioBuffers() which is better than walking the composed ByteBuf in most
// cases.
if (!(buf instanceof CompositeByteBuf) && buf.nioBufferCount() == 1) {
// We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object
// allocation to a minimum.
singleBuffer[0] = buf.internalNioBuffer(buf.readerIndex(), buf.readableBytes());
return singleBuffer;
}
return buf.nioBuffers();
}
}

View File

@ -0,0 +1,98 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.annotations.VisibleForTesting;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.lang3.SystemUtils;
/** Class for checking if the system is running on Google Cloud Platform (GCP). */
final class CheckGcpEnvironment {
private static final Logger logger = Logger.getLogger(CheckGcpEnvironment.class.getName());
private static final String DMI_PRODUCT_NAME = "/sys/class/dmi/id/product_name";
private static final String WINDOWS_COMMAND = "powershell.exe";
private static Boolean cachedResult = null;
// Construct me not!
private CheckGcpEnvironment() {}
public static void check(boolean enableUntrustedAlts) {
if (enableUntrustedAlts) {
logger.log(
Level.WARNING,
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS "
+ "handshaker service.");
} else if (!isOnGcp()) {
throw new RuntimeException("ALTS is only allowed to run on Google Cloud Platform.");
}
}
private static synchronized boolean isOnGcp() {
if (cachedResult == null) {
cachedResult = isRunningOnGcp();
}
return cachedResult;
}
@VisibleForTesting
static boolean checkProductNameOnLinux(BufferedReader reader) throws IOException {
String name = reader.readLine().trim();
return name.equals("Google") || name.equals("Google Compute Engine");
}
@VisibleForTesting
static boolean checkBiosDataOnWindows(BufferedReader reader) throws IOException {
String line;
while ((line = reader.readLine()) != null) {
if (line.startsWith("Manufacturer")) {
String name = line.substring(line.indexOf(':') + 1).trim();
return name.equals("Google");
}
}
return false;
}
private static boolean isRunningOnGcp() {
try {
if (SystemUtils.IS_OS_LINUX) {
// Checks GCE residency on Linux platform.
return checkProductNameOnLinux(Files.newBufferedReader(Paths.get(DMI_PRODUCT_NAME), UTF_8));
} else if (SystemUtils.IS_OS_WINDOWS) {
// Checks GCE residency on Windows platform.
Process p =
new ProcessBuilder()
.command(WINDOWS_COMMAND, "Get-WmiObject", "-Class", "Win32_BIOS")
.start();
return checkBiosDataOnWindows(
new BufferedReader(new InputStreamReader(p.getInputStream(), UTF_8)));
}
} catch (IOException e) {
logger.log(Level.WARNING, "Fail to read platform information: ", e);
return false;
}
// Platforms other than Linux and Windows are not supported.
return false;
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import com.google.common.base.Preconditions;
import io.grpc.ManagedChannel;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.util.concurrent.ThreadFactory;
/**
* Class for creating a single shared grpc channel to the ALTS Handshaker Service. The channel to
* the handshaker service is local and is over plaintext. Each application will have at most one
* connection to the handshaker service.
*
* <p>TODO: Release the channel if it is not used.
*/
final class HandshakerServiceChannel {
// Default handshaker service address.
private static String handshakerAddress = "metadata.google.internal:8080";
// Shared channel to ALTS handshaker service.
private static ManagedChannel channel = null;
// Construct me not!
private HandshakerServiceChannel() {}
// Sets handshaker service address for testing and creates the channel to the handshaker service.
public static synchronized void setHandshakerAddressForTesting(String handshakerAddress) {
Preconditions.checkState(
channel == null || HandshakerServiceChannel.handshakerAddress.equals(handshakerAddress),
"HandshakerServiceChannel already created with a different handshakerAddress");
HandshakerServiceChannel.handshakerAddress = handshakerAddress;
if (channel == null) {
channel = createChannel();
}
}
/** Create a new channel to ALTS handshaker service, if it has not been created yet. */
private static ManagedChannel createChannel() {
/* Use its own event loop thread pool to avoid blocking. */
ThreadFactory clientThreadFactory = new DefaultThreadFactory("handshaker pool", true);
ManagedChannel channel =
NettyChannelBuilder.forTarget(handshakerAddress)
.directExecutor()
.eventLoopGroup(new NioEventLoopGroup(1, clientThreadFactory))
.usePlaintext(true)
.build();
return channel;
}
public static synchronized ManagedChannel get() {
if (channel == null) {
channel = createChannel();
}
return channel;
}
}

View File

@ -0,0 +1,155 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.grpc.Internal;
import io.grpc.alts.transportsecurity.TsiFrameProtector;
import io.grpc.alts.transportsecurity.TsiHandshaker;
import io.grpc.alts.transportsecurity.TsiPeer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
/** A wrapper for a {@link TsiHandshaker} that accepts netty {@link ByteBuf}s. */
@Internal
public final class InternalNettyTsiHandshaker {
private BufUnwrapper unwrapper = new BufUnwrapper();
private final TsiHandshaker internalHandshaker;
public InternalNettyTsiHandshaker(TsiHandshaker handshaker) {
internalHandshaker = checkNotNull(handshaker);
}
/**
* Gets data that is ready to be sent to the to the remote peer. This should be called in a loop
* until no bytes are written to the output buffer.
*
* @param out the buffer to receive the bytes.
*/
void getBytesToSendToPeer(ByteBuf out) throws GeneralSecurityException {
checkState(unwrapper != null, "protector already created");
try (BufUnwrapper unwrapper = this.unwrapper) {
// Write as many bytes as possible into the buffer.
int bytesWritten = 0;
for (ByteBuffer nioBuffer : unwrapper.writableNioBuffers(out)) {
if (!nioBuffer.hasRemaining()) {
// This buffer doesn't have any more space to write, go to the next buffer.
continue;
}
int prevPos = nioBuffer.position();
internalHandshaker.getBytesToSendToPeer(nioBuffer);
bytesWritten += nioBuffer.position() - prevPos;
// If the buffer position was not changed, the frame has been completely read into the
// buffers.
if (nioBuffer.position() == prevPos) {
break;
}
}
out.writerIndex(out.writerIndex() + bytesWritten);
}
}
/**
* Process handshake data received from the remote peer.
*
* @return {@code true}, if the handshake has all the data it needs to process and {@code false},
* if the method must be called again to complete processing.
*/
boolean processBytesFromPeer(ByteBuf data) throws GeneralSecurityException {
checkState(unwrapper != null, "protector already created");
try (BufUnwrapper unwrapper = this.unwrapper) {
int bytesRead = 0;
boolean done = false;
for (ByteBuffer nioBuffer : unwrapper.readableNioBuffers(data)) {
if (!nioBuffer.hasRemaining()) {
// This buffer has been fully read, continue to the next buffer.
continue;
}
int prevPos = nioBuffer.position();
done = internalHandshaker.processBytesFromPeer(nioBuffer);
bytesRead += nioBuffer.position() - prevPos;
if (done) {
break;
}
}
data.readerIndex(data.readerIndex() + bytesRead);
return done;
}
}
/**
* Returns true if and only if the handshake is still in progress
*
* @return true, if the handshake is still in progress, false otherwise.
*/
boolean isInProgress() {
return internalHandshaker.isInProgress();
}
/**
* Returns the peer extracted from a completed handshake.
*
* @return the extracted peer.
*/
TsiPeer extractPeer() throws GeneralSecurityException {
checkState(!internalHandshaker.isInProgress());
return internalHandshaker.extractPeer();
}
/**
* Returns the peer extracted from a completed handshake.
*
* @return the extracted peer.
*/
Object extractPeerObject() throws GeneralSecurityException {
checkState(!internalHandshaker.isInProgress());
return internalHandshaker.extractPeerObject();
}
/**
* Creates a frame protector from a completed handshake. No other methods may be called after the
* frame protector is created.
*
* @param maxFrameSize the requested max frame size, the callee is free to ignore.
* @return a new {@link TsiFrameProtector}.
*/
TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
unwrapper = null;
return internalHandshaker.createFrameProtector(maxFrameSize, alloc);
}
/**
* Creates a frame protector from a completed handshake. No other methods may be called after the
* frame protector is created.
*
* @return a new {@link TsiFrameProtector}.
*/
TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
unwrapper = null;
return internalHandshaker.createFrameProtector(alloc);
}
}

View File

@ -0,0 +1,177 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Internal;
import io.grpc.alts.InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent;
import io.grpc.alts.transportsecurity.TsiFrameProtector;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.channel.PendingWriteQueue;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;
/**
* Encrypts and decrypts TSI Frames. Writes are buffered here until {@link #flush} is called. Writes
* must not be made before the TSI handshake is complete.
*/
@Internal
public final class InternalTsiFrameHandler extends ByteToMessageDecoder
implements ChannelOutboundHandler {
private TsiFrameProtector protector;
private PendingWriteQueue pendingUnprotectedWrites;
public InternalTsiFrameHandler() {}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
assert pendingUnprotectedWrites == null;
pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx));
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception {
if (event instanceof TsiHandshakeCompletionEvent) {
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
if (tsiEvent.isSuccess()) {
setProtector(tsiEvent.protector());
}
// Ignore errors. Another handler in the pipeline must handle TSI Errors.
}
// Keep propagating the message, as others may want to read it.
super.userEventTriggered(ctx, event);
}
@VisibleForTesting
void setProtector(TsiFrameProtector protector) {
checkState(this.protector == null);
this.protector = checkNotNull(protector);
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
checkState(protector != null, "Cannot read frames while the TSI handshake is in progress");
protector.unprotect(in, out, ctx.alloc());
}
@Override
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
throws Exception {
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
ByteBuf msg = (ByteBuf) message;
if (!msg.isReadable()) {
// Nothing to encode.
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError = promise.setSuccess();
return;
}
// Just add the message to the pending queue. We'll write it on the next flush.
pendingUnprotectedWrites.add(msg, promise);
}
@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (!pendingUnprotectedWrites.isEmpty()) {
pendingUnprotectedWrites.removeAndFailAll(
new ChannelException("Pending write on removal of TSI handler"));
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
pendingUnprotectedWrites.removeAndFailAll(cause);
super.exceptionCaught(ctx, cause);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
ctx.bind(localAddress, promise);
}
@Override
public void connect(
ChannelHandlerContext ctx,
SocketAddress remoteAddress,
SocketAddress localAddress,
ChannelPromise promise) {
ctx.connect(remoteAddress, localAddress, promise);
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
ctx.disconnect(promise);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
ctx.close(promise);
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
ctx.deregister(promise);
}
@Override
public void read(ChannelHandlerContext ctx) {
ctx.read();
}
@Override
public void flush(ChannelHandlerContext ctx) throws GeneralSecurityException {
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
ProtectedPromise aggregatePromise =
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
if (pendingUnprotectedWrites.isEmpty()) {
// Return early if there's nothing to write. Otherwise protector.protectFlush() below may
// not check for "no-data" and go on writing the 0-byte "data" to the socket with the
// protection framing.
return;
}
// Drain the unprotected writes.
while (!pendingUnprotectedWrites.isEmpty()) {
ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current();
bufs.add(in.retain());
// Remove and release the buffer and add its promise to the aggregate.
aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove());
}
protector.protectFlush(
bufs, b -> ctx.writeAndFlush(b, aggregatePromise.newPromise()), ctx.alloc());
// We're done writing, start the flow of promise events.
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError = aggregatePromise.doneAllocatingPromises();
}
}

View File

@ -0,0 +1,216 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Internal;
import io.grpc.alts.transportsecurity.TsiFrameProtector;
import io.grpc.alts.transportsecurity.TsiPeer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.ReferenceCountUtil;
import java.security.GeneralSecurityException;
import java.util.List;
import java.util.concurrent.Future;
import javax.annotation.Nullable;
/**
* Performs The TSI Handshake. When the handshake is complete, it fires a user event with a {@link
* TsiHandshakeCompletionEvent} indicating the result of the handshake.
*/
@Internal
public final class InternalTsiHandshakeHandler extends ByteToMessageDecoder {
private static final int HANDSHAKE_FRAME_SIZE = 1024;
private final InternalNettyTsiHandshaker handshaker;
private boolean started;
/**
* This buffer doesn't store any state. We just hold onto it in case we end up allocating a buffer
* that ends up being unused.
*/
private ByteBuf buffer;
public InternalTsiHandshakeHandler(InternalNettyTsiHandshaker handshaker) {
this.handshaker = checkNotNull(handshaker);
}
/**
* Event that is fired once the TSI handshake is complete, which may be because it was successful
* or there was an error.
*/
public static final class TsiHandshakeCompletionEvent {
private final Throwable cause;
private final TsiPeer peer;
private final Object context;
private final TsiFrameProtector protector;
/** Creates a new event that indicates a successful handshake. */
@VisibleForTesting
TsiHandshakeCompletionEvent(
TsiFrameProtector protector, TsiPeer peer, @Nullable Object peerObject) {
this.cause = null;
this.peer = checkNotNull(peer);
this.protector = checkNotNull(protector);
this.context = peerObject;
}
/** Creates a new event that indicates an unsuccessful handshake/. */
TsiHandshakeCompletionEvent(Throwable cause) {
this.cause = checkNotNull(cause);
this.peer = null;
this.protector = null;
this.context = null;
}
/** Return {@code true} if the handshake was successful. */
public boolean isSuccess() {
return cause == null;
}
/**
* Return the {@link Throwable} if {@link #isSuccess()} returns {@code false} and so the
* handshake failed.
*/
@Nullable
public Throwable cause() {
return cause;
}
@Nullable
public TsiPeer peer() {
return peer;
}
@Nullable
public Object context() {
return context;
}
@Nullable
TsiFrameProtector protector() {
return protector;
}
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
maybeStart(ctx);
super.handlerAdded(ctx);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
maybeStart(ctx);
super.channelActive(ctx);
}
@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
close();
super.handlerRemoved0(ctx);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.fireUserEventTriggered(new TsiHandshakeCompletionEvent(cause));
super.exceptionCaught(ctx, cause);
}
@Override
protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
throws Exception {
// TODO: Not sure why override is needed. Investigate if it can be removed.
decode(ctx, in, out);
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
// Process the data. If we need to send more data, do so now.
if (handshaker.processBytesFromPeer(in) && handshaker.isInProgress()) {
sendHandshake(ctx);
}
// If the handshake is complete, transition to the framing state.
if (!handshaker.isInProgress()) {
try {
ctx.pipeline().remove(this);
ctx.fireUserEventTriggered(
new TsiHandshakeCompletionEvent(
handshaker.createFrameProtector(ctx.alloc()),
handshaker.extractPeer(),
handshaker.extractPeerObject()));
// No need to do anything with the in buffer, it will be re added to the pipeline when this
// handler is removed.
} finally {
close();
}
}
}
private void maybeStart(ChannelHandlerContext ctx) {
if (!started && ctx.channel().isActive()) {
started = true;
sendHandshake(ctx);
}
}
/** Sends as many bytes as are available from the handshaker to the remote peer. */
private void sendHandshake(ChannelHandlerContext ctx) {
boolean needToFlush = false;
// Iterate until there is nothing left to write.
while (true) {
buffer = getOrCreateBuffer(ctx.alloc());
try {
handshaker.getBytesToSendToPeer(buffer);
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
if (!buffer.isReadable()) {
break;
}
needToFlush = true;
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError = ctx.write(buffer);
buffer = null;
}
// If something was written, flush.
if (needToFlush) {
ctx.flush();
}
}
private ByteBuf getOrCreateBuffer(ByteBufAllocator alloc) {
if (buffer == null) {
buffer = alloc.buffer(HANDSHAKE_FRAME_SIZE);
}
return buffer;
}
private void close() {
ReferenceCountUtil.safeRelease(buffer);
buffer = null;
}
}

View File

@ -0,0 +1,149 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.base.Preconditions.checkState;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.concurrent.EventExecutor;
import java.util.ArrayList;
import java.util.List;
/**
* Promise used when flushing the {@code pendingUnprotectedWrites} queue. It manages the many-to
* many relationship between pending unprotected messages and the individual writes. Each protected
* frame will be written using the same instance of this promise and it will accumulate the results.
* Once all frames have been successfully written (or any failed), all of the promises for the
* pending unprotected writes are notified.
*
* <p>NOTE: this code is based on code in Netty's {@code Http2CodecUtil}.
*/
final class ProtectedPromise extends DefaultChannelPromise {
private final List<ChannelPromise> unprotectedPromises;
private int expectedCount;
private int successfulCount;
private int failureCount;
private boolean doneAllocating;
ProtectedPromise(Channel channel, EventExecutor executor, int numUnprotectedPromises) {
super(channel, executor);
unprotectedPromises = new ArrayList<>(numUnprotectedPromises);
}
/**
* Adds a promise for a pending unprotected write. This will be notified after all of the writes
* complete.
*/
void addUnprotectedPromise(ChannelPromise promise) {
unprotectedPromises.add(promise);
}
/**
* Allocate a new promise for the write of a protected frame. This will be used to aggregate the
* overall success of the unprotected promises.
*
* @return {@code this} promise.
*/
ChannelPromise newPromise() {
checkState(!doneAllocating, "Done allocating. No more promises can be allocated.");
expectedCount++;
return this;
}
/**
* Signify that no more {@link #newPromise()} allocations will be made. The aggregation can not be
* successful until this method is called.
*
* @return {@code this} promise.
*/
ChannelPromise doneAllocatingPromises() {
if (!doneAllocating) {
doneAllocating = true;
if (successfulCount == expectedCount) {
trySuccessInternal(null);
return super.setSuccess(null);
}
}
return this;
}
@Override
public boolean tryFailure(Throwable cause) {
if (awaitingPromises()) {
++failureCount;
if (failureCount == 1) {
tryFailureInternal(cause);
return super.tryFailure(cause);
}
// TODO: We break the interface a bit here.
// Multiple failure events can be processed without issue because this is an aggregation.
return true;
}
return false;
}
/**
* Fail this object if it has not already been failed.
*
* <p>This method will NOT throw an {@link IllegalStateException} if called multiple times because
* that may be expected.
*/
@Override
public ChannelPromise setFailure(Throwable cause) {
tryFailure(cause);
return this;
}
private boolean awaitingPromises() {
return successfulCount + failureCount < expectedCount;
}
@Override
public ChannelPromise setSuccess(Void result) {
trySuccess(result);
return this;
}
@Override
public boolean trySuccess(Void result) {
if (awaitingPromises()) {
++successfulCount;
if (successfulCount == expectedCount && doneAllocating) {
trySuccessInternal(result);
return super.trySuccess(result);
}
// TODO: We break the interface a bit here.
// Multiple success events can be processed without issue because this is an aggregation.
return true;
}
return false;
}
private void trySuccessInternal(Void result) {
for (int i = 0; i < unprotectedPromises.size(); ++i) {
unprotectedPromises.get(i).trySuccess(result);
}
}
private void tryFailureInternal(Throwable cause) {
for (int i = 0; i < unprotectedPromises.size(); ++i) {
unprotectedPromises.get(i).tryFailure(cause);
}
}
}

View File

@ -0,0 +1,130 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions.Version;
import javax.annotation.Nullable;
/** Utility class for Rpc Protocol Versions. */
final class RpcProtocolVersionsUtil {
private static final int MAX_RPC_VERSION_MAJOR = 2;
private static final int MAX_RPC_VERSION_MINOR = 1;
private static final int MIN_RPC_VERSION_MAJOR = 2;
private static final int MIN_RPC_VERSION_MINOR = 1;
private static final RpcProtocolVersions RPC_PROTOCOL_VERSIONS =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(MAX_RPC_VERSION_MAJOR)
.setMinor(MAX_RPC_VERSION_MINOR)
.build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(MIN_RPC_VERSION_MAJOR)
.setMinor(MIN_RPC_VERSION_MINOR)
.build())
.build();
/** Returns default Rpc Protocol Versions. */
static RpcProtocolVersions getRpcProtocolVersions() {
return RPC_PROTOCOL_VERSIONS;
}
/**
* Returns true if first Rpc Protocol Version is greater than or equal to the second one. Returns
* false otherwise.
*/
@VisibleForTesting
static boolean isGreaterThanOrEqualTo(Version first, Version second) {
if ((first.getMajor() > second.getMajor())
|| (first.getMajor() == second.getMajor() && first.getMinor() >= second.getMinor())) {
return true;
}
return false;
}
/**
* Performs check between local and peer Rpc Protocol Versions. This function returns true and the
* highest common version if there exists a common Rpc Protocol Version to use, and returns false
* and null otherwise.
*/
static RpcVersionsCheckResult checkRpcProtocolVersions(
RpcProtocolVersions localVersions, RpcProtocolVersions peerVersions) {
Version maxCommonVersion;
Version minCommonVersion;
// maxCommonVersion is MIN(local.max, peer.max)
if (isGreaterThanOrEqualTo(localVersions.getMaxRpcVersion(), peerVersions.getMaxRpcVersion())) {
maxCommonVersion = peerVersions.getMaxRpcVersion();
} else {
maxCommonVersion = localVersions.getMaxRpcVersion();
}
// minCommonVersion is MAX(local.min, peer.min)
if (isGreaterThanOrEqualTo(localVersions.getMinRpcVersion(), peerVersions.getMinRpcVersion())) {
minCommonVersion = localVersions.getMinRpcVersion();
} else {
minCommonVersion = peerVersions.getMinRpcVersion();
}
if (isGreaterThanOrEqualTo(maxCommonVersion, minCommonVersion)) {
return new RpcVersionsCheckResult.Builder()
.setResult(true)
.setHighestCommonVersion(maxCommonVersion)
.build();
}
return new RpcVersionsCheckResult.Builder().setResult(false).build();
}
/** Wrapper class that stores results of Rpc Protocol Versions check. */
static final class RpcVersionsCheckResult {
private final boolean result;
@Nullable private final Version highestCommonVersion;
private RpcVersionsCheckResult(Builder builder) {
result = builder.result;
highestCommonVersion = builder.highestCommonVersion;
}
boolean getResult() {
return result;
}
Version getHighestCommonVersion() {
return highestCommonVersion;
}
static final class Builder {
private boolean result;
@Nullable private Version highestCommonVersion = null;
public Builder setResult(boolean result) {
this.result = result;
return this;
}
public Builder setHighestCommonVersion(Version highestCommonVersion) {
this.highestCommonVersion = highestCommonVersion;
return this;
}
public RpcVersionsCheckResult build() {
return new RpcVersionsCheckResult(this);
}
}
}
}

View File

@ -0,0 +1,76 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
/**
* {@code AeadCrypter} performs authenticated encryption and decryption for a fixed key given unique
* nonces. Authenticated additional data is supported.
*/
interface AeadCrypter {
/**
* Encrypt plaintext into ciphertext buffer using the given nonce.
*
* @param ciphertext the encrypted plaintext and the tag will be written into this buffer.
* @param plaintext the input that should be encrypted.
* @param nonce the unique nonce used for the encryption.
* @throws GeneralSecurityException if ciphertext buffer is short or the nonce does not have the
* expected size.
*/
void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, byte[] nonce)
throws GeneralSecurityException;
/**
* Encrypt plaintext into ciphertext buffer using the given nonce with authenticated data.
*
* @param ciphertext the encrypted plaintext and the tag will be written into this buffer.
* @param plaintext the input that should be encrypted.
* @param aad additional data that should be authenticated, but not encrypted.
* @param nonce the unique nonce used for the encryption.
* @throws GeneralSecurityException if ciphertext buffer is short or the nonce does not have the
* expected size.
*/
void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException;
/**
* Decrypt ciphertext into plaintext buffer using the given nonce.
*
* @param plaintext the decrypted plaintext will be written into this buffer.
* @param ciphertext the ciphertext and tag that should be decrypted.
* @param nonce the nonce that was used for the encryption.
* @throws GeneralSecurityException if the tag is invalid or any of the inputs do not have the
* expected size.
*/
void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, byte[] nonce)
throws GeneralSecurityException;
/**
* Decrypt ciphertext into plaintext buffer using the given nonce.
*
* @param plaintext the decrypted plaintext will be written into this buffer.
* @param ciphertext the ciphertext and tag that should be decrypted.
* @param aad additional data that is checked for authenticity.
* @param nonce the nonce that was used for the encryption.
* @throws GeneralSecurityException if the tag is invalid or any of the inputs do not have the
* expected size.
*/
void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException;
}

View File

@ -0,0 +1,101 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.base.Preconditions.checkArgument;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import javax.annotation.Nullable;
import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
/** AES128-GCM implementation of {@link AeadCrypter} that uses default JCE provider. */
final class AesGcmAeadCrypter implements AeadCrypter {
private static final int KEY_LENGTH = 16;
private static final int TAG_LENGTH = 16;
static final int NONCE_LENGTH = 12;
private static final String AES = "AES";
private static final String AES_GCM = AES + "/GCM/NoPadding";
private final byte[] key;
private final Cipher cipher;
AesGcmAeadCrypter(byte[] key) throws GeneralSecurityException {
checkArgument(key.length == KEY_LENGTH);
this.key = key;
cipher = Cipher.getInstance(AES_GCM);
}
private int encryptAad(
ByteBuffer ciphertext, ByteBuffer plaintext, @Nullable ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException {
checkArgument(nonce.length == NONCE_LENGTH);
cipher.init(
Cipher.ENCRYPT_MODE,
new SecretKeySpec(this.key, AES),
new GCMParameterSpec(TAG_LENGTH * 8, nonce));
if (aad != null) {
cipher.updateAAD(aad);
}
return cipher.doFinal(plaintext, ciphertext);
}
private void decryptAad(
ByteBuffer plaintext, ByteBuffer ciphertext, @Nullable ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException {
checkArgument(nonce.length == NONCE_LENGTH);
cipher.init(
Cipher.DECRYPT_MODE,
new SecretKeySpec(this.key, AES),
new GCMParameterSpec(TAG_LENGTH * 8, nonce));
if (aad != null) {
cipher.updateAAD(aad);
}
cipher.doFinal(ciphertext, plaintext);
}
@Override
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, byte[] nonce)
throws GeneralSecurityException {
encryptAad(ciphertext, plaintext, null, nonce);
}
@Override
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException {
encryptAad(ciphertext, plaintext, aad, nonce);
}
@Override
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, byte[] nonce)
throws GeneralSecurityException {
decryptAad(plaintext, ciphertext, null, nonce);
}
@Override
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException {
decryptAad(plaintext, ciphertext, aad, nonce);
}
static int getKeyLength() {
return KEY_LENGTH;
}
}

View File

@ -0,0 +1,127 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.base.Preconditions.checkArgument;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
/**
* {@link AeadCrypter} implementation based on {@link AesGcmAeadCrypter} with nonce-based rekeying
* using HKDF-expand and random nonce-mask that is XORed with the given nonce/counter. The AES-GCM
* key is computed as HKDF-expand(kdfKey, nonce[2..7]), i.e., the first 2 bytes are ignored to
* require rekeying only after 2^16 operations and the last 4 bytes (including the direction bit)
* are ignored to allow for optimizations (use same AEAD context for both directions, store counter
* as unsigned long and boolean for direction).
*/
final class AesGcmHkdfAeadCrypter implements AeadCrypter {
private static final int KDF_KEY_LENGTH = 32;
// Rekey after 2^(2*8) = 2^16 operations by ignoring the first 2 nonce bytes for key derivation.
private static final int KDF_COUNTER_OFFSET = 2;
// Use remaining bytes of 64-bit counter included in nonce for key derivation.
private static final int KDF_COUNTER_LENGTH = 6;
private static final int NONCE_LENGTH = AesGcmAeadCrypter.NONCE_LENGTH;
private static final int KEY_LENGTH = KDF_KEY_LENGTH + NONCE_LENGTH;
private final byte[] kdfKey;
private final byte[] kdfCounter = new byte[KDF_COUNTER_LENGTH];
private final byte[] nonceMask;
private final byte[] nonceBuffer = new byte[NONCE_LENGTH];
private AeadCrypter aeadCrypter;
AesGcmHkdfAeadCrypter(byte[] key) {
checkArgument(key.length == KEY_LENGTH);
this.kdfKey = Arrays.copyOf(key, KDF_KEY_LENGTH);
this.nonceMask = Arrays.copyOfRange(key, KDF_KEY_LENGTH, KDF_KEY_LENGTH + NONCE_LENGTH);
}
@Override
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, byte[] nonce)
throws GeneralSecurityException {
maybeRekey(nonce);
maskNonce(nonceBuffer, nonceMask, nonce);
aeadCrypter.encrypt(ciphertext, plaintext, nonceBuffer);
}
@Override
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException {
maybeRekey(nonce);
maskNonce(nonceBuffer, nonceMask, nonce);
aeadCrypter.encrypt(ciphertext, plaintext, aad, nonceBuffer);
}
@Override
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, byte[] nonce)
throws GeneralSecurityException {
maybeRekey(nonce);
maskNonce(nonceBuffer, nonceMask, nonce);
aeadCrypter.decrypt(plaintext, ciphertext, nonceBuffer);
}
@Override
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, ByteBuffer aad, byte[] nonce)
throws GeneralSecurityException {
maybeRekey(nonce);
maskNonce(nonceBuffer, nonceMask, nonce);
aeadCrypter.decrypt(plaintext, ciphertext, aad, nonceBuffer);
}
private void maybeRekey(byte[] nonce) throws GeneralSecurityException {
if (aeadCrypter != null
&& arrayEqualOn(nonce, KDF_COUNTER_OFFSET, kdfCounter, 0, KDF_COUNTER_LENGTH)) {
return;
}
System.arraycopy(nonce, KDF_COUNTER_OFFSET, kdfCounter, 0, KDF_COUNTER_LENGTH);
int aeKeyLen = AesGcmAeadCrypter.getKeyLength();
byte[] aeKey = Arrays.copyOf(hkdfExpandSha256(kdfKey, kdfCounter), aeKeyLen);
aeadCrypter = new AesGcmAeadCrypter(aeKey);
}
private static void maskNonce(byte[] nonceBuffer, byte[] nonceMask, byte[] nonce) {
checkArgument(nonce.length == NONCE_LENGTH);
for (int i = 0; i < NONCE_LENGTH; i++) {
nonceBuffer[i] = (byte) (nonceMask[i] ^ nonce[i]);
}
}
private static byte[] hkdfExpandSha256(byte[] key, byte[] info) throws GeneralSecurityException {
Mac mac = Mac.getInstance("HMACSHA256");
mac.init(new SecretKeySpec(key, mac.getAlgorithm()));
mac.update(info);
mac.update((byte) 0x01);
return mac.doFinal();
}
private static boolean arrayEqualOn(byte[] a, int aPos, byte[] b, int bPos, int length) {
for (int i = 0; i < length; i++) {
if (a[aPos + i] != b[bPos + i]) {
return false;
}
}
return true;
}
static int getKeyLength() {
return KEY_LENGTH;
}
}

View File

@ -0,0 +1,101 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.alts.Altscontext.AltsContext;
import io.grpc.alts.Handshaker.HandshakerResult;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.alts.TransportSecurityCommon.SecurityLevel;
/** AltsAuthContext contains security-related context information about an ALTs connection. */
public final class AltsAuthContext {
final AltsContext context;
/** Create a new AltsAuthContext. */
public AltsAuthContext(HandshakerResult result) {
context =
AltsContext.newBuilder()
.setApplicationProtocol(result.getApplicationProtocol())
.setRecordProtocol(result.getRecordProtocol())
// TODO: Set security level based on the handshaker result.
.setSecurityLevel(SecurityLevel.INTEGRITY_AND_PRIVACY)
.setPeerServiceAccount(result.getPeerIdentity().getServiceAccount())
.setLocalServiceAccount(result.getLocalIdentity().getServiceAccount())
.setPeerRpcVersions(result.getPeerRpcVersions())
.build();
}
@VisibleForTesting
public static AltsAuthContext getDefaultInstance() {
return new AltsAuthContext(HandshakerResult.newBuilder().build());
}
/**
* Get application protocol.
*
* @return the context's application protocol.
*/
public String getApplicationProtocol() {
return context.getApplicationProtocol();
}
/**
* Get negotiated record protocol.
*
* @return the context's negotiated record protocol.
*/
public String getRecordProtocol() {
return context.getRecordProtocol();
}
/**
* Get security level.
*
* @return the context's security level.
*/
public SecurityLevel getSecurityLevel() {
return context.getSecurityLevel();
}
/**
* Get peer service account.
*
* @return the context's peer service account.
*/
public String getPeerServiceAccount() {
return context.getPeerServiceAccount();
}
/**
* Get local service account.
*
* @return the context's local service account.
*/
public String getLocalServiceAccount() {
return context.getLocalServiceAccount();
}
/**
* Get peer RPC versions.
*
* @return the context's peer RPC versions.
*/
public RpcProtocolVersions getPeerRpcVersions() {
return context.getPeerRpcVersions();
}
}

View File

@ -0,0 +1,171 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.List;
/** Performs encryption and decryption with AES-GCM using JCE. All methods are thread-compatible. */
final class AltsChannelCrypter implements ChannelCrypterNetty {
private static final int KEY_LENGTH = AesGcmHkdfAeadCrypter.getKeyLength();
private static final int COUNTER_LENGTH = 12;
// The counter will overflow after 2^64 operations and encryption/decryption will stop working.
private static final int COUNTER_OVERFLOW_LENGTH = 8;
private static final int TAG_LENGTH = 16;
private final AeadCrypter aeadCrypter;
private final byte[] outCounter = new byte[COUNTER_LENGTH];
private final byte[] inCounter = new byte[COUNTER_LENGTH];
private final byte[] oldCounter = new byte[COUNTER_LENGTH];
AltsChannelCrypter(byte[] key, boolean isClient) {
checkArgument(key.length == KEY_LENGTH);
byte[] counter = isClient ? inCounter : outCounter;
counter[counter.length - 1] = (byte) 0x80;
this.aeadCrypter = new AesGcmHkdfAeadCrypter(key);
}
static int getKeyLength() {
return KEY_LENGTH;
}
static int getCounterLength() {
return COUNTER_LENGTH;
}
@SuppressWarnings("BetaApi") // verify is stable in Guava
@Override
public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
checkArgument(outBuf.nioBufferCount() == 1);
// Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes());
plainBuf.writerIndex(0);
for (ByteBuf inBuf : plainBufs) {
plainBuf.writeBytes(inBuf);
}
verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH);
ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes());
ByteBuffer plain = out.duplicate();
plain.limit(out.limit() - TAG_LENGTH);
byte[] counter = incrementOutCounter();
int outPosition = out.position();
aeadCrypter.encrypt(out, plain, counter);
int bytesWritten = out.position() - outPosition;
outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
verify(!outBuf.isWritable());
}
@Override
public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)
throws GeneralSecurityException {
ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes());
cipherTextAndTag.writerIndex(0);
for (ByteBuf inBuf : ciphertextBufs) {
cipherTextAndTag.writeBytes(inBuf);
}
cipherTextAndTag.writeBytes(tag);
decrypt(out, cipherTextAndTag);
}
@SuppressWarnings("BetaApi") // verify is stable in Guava
@Override
public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
int bytesRead = ciphertextAndTag.readableBytes();
checkArgument(bytesRead == out.writableBytes());
checkArgument(out.nioBufferCount() == 1);
ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes());
checkArgument(ciphertextAndTag.nioBufferCount() == 1);
ByteBuffer ciphertextAndTagBuffer =
ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);
byte[] counter = incrementInCounter();
int outPosition = outBuffer.position();
aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter);
int bytesWritten = outBuffer.position() - outPosition;
out.writerIndex(out.writerIndex() + bytesWritten);
ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead);
verify(out.writableBytes() == TAG_LENGTH);
}
@Override
public int getSuffixLength() {
return TAG_LENGTH;
}
@Override
public void destroy() {
// no destroy required
}
/** Increments {@code counter}, store the unincremented value in {@code oldCounter}. */
static void incrementCounter(byte[] counter, byte[] oldCounter) throws GeneralSecurityException {
System.arraycopy(counter, 0, oldCounter, 0, counter.length);
int i = 0;
for (; i < COUNTER_OVERFLOW_LENGTH; i++) {
counter[i]++;
if (counter[i] != (byte) 0x00) {
break;
}
}
if (i == COUNTER_OVERFLOW_LENGTH) {
// Restore old counter value to ensure that encrypt and decrypt keep failing.
System.arraycopy(oldCounter, 0, counter, 0, counter.length);
throw new GeneralSecurityException("Counter has overflowed.");
}
}
/** Increments the input counter, returning the previous (unincremented) value. */
private byte[] incrementInCounter() throws GeneralSecurityException {
incrementCounter(inCounter, oldCounter);
return oldCounter;
}
/** Increments the output counter, returning the previous (unincremented) value. */
private byte[] incrementOutCounter() throws GeneralSecurityException {
incrementCounter(outCounter, oldCounter);
return oldCounter;
}
@VisibleForTesting
void incrementInCounterForTesting(int n) throws GeneralSecurityException {
for (int i = 0; i < n; i++) {
incrementInCounter();
}
}
@VisibleForTesting
void incrementOutCounterForTesting(int n) throws GeneralSecurityException {
for (int i = 0; i < n; i++) {
incrementOutCounter();
}
}
}

View File

@ -0,0 +1,75 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
/** Handshaker options for creating ALTS client channel. */
public final class AltsClientOptions extends AltsHandshakerOptions {
// targetName is the server service account name for secure name checking. This field is not yet
// supported.
@Nullable private final String targetName;
// targetServiceAccounts contains a list of expected target service accounts. One of these service
// accounts should match peer service account in the handshaker result. Otherwise, the handshake
// fails.
private final List<String> targetServiceAccounts;
private AltsClientOptions(Builder builder) {
super(builder.rpcProtocolVersions);
targetName = builder.targetName;
targetServiceAccounts =
Collections.unmodifiableList(new ArrayList<String>(builder.targetServiceAccounts));
}
public String getTargetName() {
return targetName;
}
public List<String> getTargetServiceAccounts() {
return targetServiceAccounts;
}
/** Builder for AltsClientOptions. */
public static final class Builder {
@Nullable private String targetName;
@Nullable private RpcProtocolVersions rpcProtocolVersions;
private ArrayList<String> targetServiceAccounts = new ArrayList<String>();
public Builder setTargetName(String targetName) {
this.targetName = targetName;
return this;
}
public Builder setRpcProtocolVersions(RpcProtocolVersions rpcProtocolVersions) {
this.rpcProtocolVersions = rpcProtocolVersions;
return this;
}
public Builder addTargetServiceAccount(String targetServiceAccount) {
targetServiceAccounts.add(targetServiceAccount);
return this;
}
public AltsClientOptions build() {
return new AltsClientOptions(this);
}
}
}

View File

@ -0,0 +1,365 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import com.google.common.base.Preconditions;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.GeneralSecurityException;
/** Framing and deframing methods and classes used by handshaker. */
public final class AltsFraming {
// The size of the frame field. Must correspond to the size of int, 4 bytes.
// Left package-private for testing.
private static final int FRAME_LENGTH_HEADER_SIZE = 4;
private static final int FRAME_MESSAGE_TYPE_HEADER_SIZE = 4;
private static final int MAX_DATA_LENGTH = 1024 * 1024;
private static final int INITIAL_BUFFER_CAPACITY = 1024 * 64;
// TODO: Make this the responsibility of the caller.
private static final int MESSAGE_TYPE = 6;
private AltsFraming() {}
static int getFrameLengthHeaderSize() {
return FRAME_LENGTH_HEADER_SIZE;
}
static int getFrameMessageTypeHeaderSize() {
return FRAME_MESSAGE_TYPE_HEADER_SIZE;
}
static int getMaxDataLength() {
return MAX_DATA_LENGTH;
}
static int getFramingOverhead() {
return FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE;
}
/**
* Creates a frame of length dataSize + FRAME_HEADER_SIZE using the input bytes, if dataSize <=
* input.remaining(). Otherwise, a frame of length input.remaining() + FRAME_HEADER_SIZE is
* created.
*/
static ByteBuffer toFrame(ByteBuffer input, int dataSize) throws GeneralSecurityException {
Preconditions.checkNotNull(input);
if (dataSize > input.remaining()) {
dataSize = input.remaining();
}
Producer producer = new Producer();
ByteBuffer inputAlias = input.duplicate();
inputAlias.limit(input.position() + dataSize);
producer.readBytes(inputAlias);
producer.flush();
input.position(inputAlias.position());
ByteBuffer output = producer.getRawFrame();
return output;
}
/**
* A helper class to write a frame.
*
* <p>This class guarantees that one of the following is true:
*
* <ul>
* <li>readBytes will read from the input
* <li>writeBytes will write to the output
* </ul>
*
* <p>Sample usage:
*
* <pre>{@code
* Producer producer = new Producer();
* ByteBuffer inputBuffer = readBytesFromMyStream();
* ByteBuffer outputBuffer = writeBytesToMyStream();
* while (inputBuffer.hasRemaining() || outputBuffer.hasRemaining()) {
* producer.readBytes(inputBuffer);
* producer.writeBytes(outputBuffer);
* }
* }</pre>
*
* <p>Alternatively, this class guarantees that one of the following is true:
*
* <ul>
* <li>readBytes will read from the input
* <li>{@code isComplete()} returns true and {@code getByteBuffer()} returns the contents of a
* processed frame.
* </ul>
*
* <p>Sample usage:
*
* <pre>{@code
* Producer producer = new Producer();
* while (!producer.isComplete()) {
* ByteBuffer inputBuffer = readBytesFromMyStream();
* producer.readBytes(inputBuffer);
* }
* producer.flush();
* ByteBuffer outputBuffer = producer.getRawFrame();
* }</pre>
*/
static final class Producer {
private ByteBuffer buffer;
private boolean isComplete;
Producer(int maxFrameSize) {
buffer = ByteBuffer.allocate(maxFrameSize);
reset();
Preconditions.checkArgument(maxFrameSize > getFramePrefixLength() + getFrameSuffixLength());
}
Producer() {
this(INITIAL_BUFFER_CAPACITY);
}
/** The length of the frame prefix data, including the message length/type fields. */
int getFramePrefixLength() {
int result = FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE;
return result;
}
int getFrameSuffixLength() {
return 0;
}
/**
* Reads bytes from input, parsing them into a frame. Returns false if and only if more data is
* needed. To obtain a full frame this method must be called repeatedly until it returns true.
*/
boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
Preconditions.checkNotNull(input);
if (isComplete) {
return true;
}
copy(buffer, input);
if (!buffer.hasRemaining()) {
flush();
}
return isComplete;
}
/**
* Completes the current frame, signaling that no further data is available to be passed to
* readBytes and that the client requires writeBytes to start returning data. isComplete() is
* guaranteed to return true after this call.
*/
void flush() throws GeneralSecurityException {
if (isComplete) {
return;
}
// Get the length of the complete frame.
int frameLength = buffer.position() + getFrameSuffixLength();
// Set the limit and move to the start.
buffer.flip();
// Advance the limit to allow a crypto suffix.
buffer.limit(buffer.limit() + getFrameSuffixLength());
// Write the data length and the message type.
int dataLength = frameLength - FRAME_LENGTH_HEADER_SIZE;
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(dataLength);
buffer.putInt(MESSAGE_TYPE);
// Move the position back to 0, the frame is ready.
buffer.position(0);
isComplete = true;
}
/** Resets the state, preparing to construct a new frame. Must be called between frames. */
private void reset() {
buffer.clear();
// Save some space for framing, we'll fill that in later.
buffer.position(getFramePrefixLength());
buffer.limit(buffer.limit() - getFrameSuffixLength());
isComplete = false;
}
/**
* Returns a ByteBuffer containing a complete raw frame, if it's available. Should only be
* called when isComplete() returns true, otherwise null is returned. The returned object
* aliases the internal buffer, that is, it shares memory with the internal buffer. No further
* operations are permitted on this object until the caller has processed the data it needs from
* the returned byte buffer.
*/
ByteBuffer getRawFrame() {
if (!isComplete) {
return null;
}
ByteBuffer result = buffer.duplicate();
reset();
return result;
}
}
/**
* A helper class to read a frame.
*
* <p>This class guarantees that one of the following is true:
*
* <ul>
* <li>readBytes will read from the input
* <li>writeBytes will write to the output
* </ul>
*
* <p>Sample usage:
*
* <pre>{@code
* Parser parser = new Parser();
* ByteBuffer inputBuffer = readBytesFromMyStream();
* ByteBuffer outputBuffer = writeBytesToMyStream();
* while (inputBuffer.hasRemaining() || outputBuffer.hasRemaining()) {
* parser.readBytes(inputBuffer);
* parser.writeBytes(outputBuffer); }
* }</pre>
*
* <p>Alternatively, this class guarantees that one of the following is true:
*
* <ul>
* <li>readBytes will read from the input
* <li>{@code isComplete()} returns true and {@code getByteBuffer()} returns the contents of a
* processed frame.
* </ul>
*
* <p>Sample usage:
*
* <pre>{@code
* Parser parser = new Parser();
* while (!parser.isComplete()) {
* ByteBuffer inputBuffer = readBytesFromMyStream();
* parser.readBytes(inputBuffer);
* }
* ByteBuffer outputBuffer = parser.getRawFrame();
* }</pre>
*/
public static final class Parser {
private ByteBuffer buffer = ByteBuffer.allocate(INITIAL_BUFFER_CAPACITY);
private boolean isComplete = false;
public Parser() {
Preconditions.checkArgument(
INITIAL_BUFFER_CAPACITY > getFramePrefixLength() + getFrameSuffixLength());
}
/**
* Reads bytes from input, parsing them into a frame. Returns false if and only if more data is
* needed. To obtain a full frame this method must be called repeatedly until it returns true.
*/
public boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
Preconditions.checkNotNull(input);
if (isComplete) {
return true;
}
// Read enough bytes to determine the length
while (buffer.position() < FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) {
buffer.put(input.get());
}
// If we have enough bytes to determine the length, read the length and ensure that our
// internal buffer is large enough.
if (buffer.position() == FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) {
ByteBuffer bufferAlias = buffer.duplicate();
bufferAlias.flip();
bufferAlias.order(ByteOrder.LITTLE_ENDIAN);
int dataLength = bufferAlias.getInt();
if (dataLength < FRAME_MESSAGE_TYPE_HEADER_SIZE || dataLength > MAX_DATA_LENGTH) {
throw new IllegalArgumentException("Invalid frame length " + dataLength);
}
// Maybe resize the buffer
int frameLength = dataLength + FRAME_LENGTH_HEADER_SIZE;
if (buffer.capacity() < frameLength) {
buffer = ByteBuffer.allocate(frameLength);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(dataLength);
}
buffer.limit(frameLength);
}
// TODO: Similarly extract and check message type.
// Read the remaining data into the internal buffer.
copy(buffer, input);
if (!buffer.hasRemaining()) {
buffer.flip();
isComplete = true;
}
return isComplete;
}
/** The length of the frame prefix data, including the message length/type fields. */
int getFramePrefixLength() {
int result = FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE;
return result;
}
int getFrameSuffixLength() {
return 0;
}
/** Returns true if we've parsed a complete frame. */
public boolean isComplete() {
return isComplete;
}
/** Resets the state, preparing to parse a new frame. Must be called between frames. */
private void reset() {
buffer.clear();
isComplete = false;
}
/**
* Returns a ByteBuffer containing a complete raw frame, if it's available. Should only be
* called when isComplete() returns true, otherwise null is returned. The returned object
* aliases the internal buffer, that is, it shares memory with the internal buffer. No further
* operations are permitted on this object until the caller has processed the data it needs from
* the returned byte buffer.
*/
public ByteBuffer getRawFrame() {
if (!isComplete) {
return null;
}
ByteBuffer result = buffer.duplicate();
reset();
return result;
}
}
/**
* Copy as much as possible to dst from src. Unlike {@link ByteBuffer#put(ByteBuffer)}, this stops
* early if there is no room left in dst.
*/
private static void copy(ByteBuffer dst, ByteBuffer src) {
if (dst.hasRemaining() && src.hasRemaining()) {
// Avoid an allocation if possible.
if (dst.remaining() >= src.remaining()) {
dst.put(src);
} else {
int count = Math.min(dst.remaining(), src.remaining());
ByteBuffer slice = src.slice();
slice.limit(count);
dst.put(slice);
src.position(src.position() + count);
}
}
}
}

View File

@ -0,0 +1,245 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.alts.Handshaker.HandshakeProtocol;
import io.grpc.alts.Handshaker.HandshakerReq;
import io.grpc.alts.Handshaker.HandshakerResp;
import io.grpc.alts.Handshaker.HandshakerResult;
import io.grpc.alts.Handshaker.HandshakerStatus;
import io.grpc.alts.Handshaker.NextHandshakeMessageReq;
import io.grpc.alts.Handshaker.ServerHandshakeParameters;
import io.grpc.alts.Handshaker.StartClientHandshakeReq;
import io.grpc.alts.Handshaker.StartServerHandshakeReq;
import io.grpc.alts.HandshakerServiceGrpc.HandshakerServiceStub;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.logging.Level;
import java.util.logging.Logger;
/** An API for conducting handshakes via ALTS handshaker service. */
class AltsHandshakerClient {
private static final Logger logger = Logger.getLogger(AltsHandshakerClient.class.getName());
private static final String APPLICATION_PROTOCOL = "grpc";
private static final String RECORD_PROTOCOL = "ALTSRP_GCM_AES128_REKEY";
private static final int KEY_LENGTH = AltsChannelCrypter.getKeyLength();
private final AltsHandshakerStub handshakerStub;
private final AltsHandshakerOptions handshakerOptions;
private HandshakerResult result;
private HandshakerStatus status;
/** Starts a new handshake interacting with the handshaker service. */
AltsHandshakerClient(HandshakerServiceStub stub, AltsHandshakerOptions options) {
handshakerStub = new AltsHandshakerStub(stub);
handshakerOptions = options;
}
@VisibleForTesting
AltsHandshakerClient(AltsHandshakerStub handshakerStub, AltsHandshakerOptions options) {
this.handshakerStub = handshakerStub;
handshakerOptions = options;
}
static String getApplicationProtocol() {
return APPLICATION_PROTOCOL;
}
static String getRecordProtocol() {
return RECORD_PROTOCOL;
}
/** Sets the start client fields for the passed handshake request. */
private void setStartClientFields(HandshakerReq.Builder req) {
// Sets the default values.
StartClientHandshakeReq.Builder startClientReq =
StartClientHandshakeReq.newBuilder()
.setHandshakeSecurityProtocol(HandshakeProtocol.ALTS)
.addApplicationProtocols(APPLICATION_PROTOCOL)
.addRecordProtocols(RECORD_PROTOCOL);
// Sets handshaker options.
if (handshakerOptions.getRpcProtocolVersions() != null) {
startClientReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
}
if (handshakerOptions instanceof AltsClientOptions) {
AltsClientOptions clientOptions = (AltsClientOptions) handshakerOptions;
if (!Strings.isNullOrEmpty(clientOptions.getTargetName())) {
startClientReq.setTargetName(clientOptions.getTargetName());
}
for (String serviceAccount : clientOptions.getTargetServiceAccounts()) {
startClientReq.addTargetIdentitiesBuilder().setServiceAccount(serviceAccount);
}
}
req.setClientStart(startClientReq);
}
/** Sets the start server fields for the passed handshake request. */
private void setStartServerFields(HandshakerReq.Builder req, ByteBuffer inBytes) {
ServerHandshakeParameters serverParameters =
ServerHandshakeParameters.newBuilder().addRecordProtocols(RECORD_PROTOCOL).build();
StartServerHandshakeReq.Builder startServerReq =
StartServerHandshakeReq.newBuilder()
.addApplicationProtocols(APPLICATION_PROTOCOL)
.putHandshakeParameters(HandshakeProtocol.ALTS.getNumber(), serverParameters)
.setInBytes(ByteString.copyFrom(inBytes.duplicate()));
if (handshakerOptions.getRpcProtocolVersions() != null) {
startServerReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
}
req.setServerStart(startServerReq);
}
/** Returns true if the handshake is complete. */
public boolean isFinished() {
// If we have a HandshakeResult, we are done.
if (result != null) {
return true;
}
// If we have an error status, we are done.
if (status != null && status.getCode() != Status.Code.OK.value()) {
return true;
}
return false;
}
/** Returns the handshake status. */
public HandshakerStatus getStatus() {
return status;
}
/** Returns the result data of the handshake, if the handshake is completed. */
public HandshakerResult getResult() {
return result;
}
/**
* Returns the resulting key of the handshake, if the handshake is completed. Note that the key
* data returned from the handshake may be more than the key length required for the record
* protocol, thus we need to truncate to the right size.
*/
public byte[] getKey() {
if (result == null) {
return null;
}
if (result.getKeyData().size() < KEY_LENGTH) {
throw new IllegalStateException("Could not get enough key data from the handshake.");
}
byte[] key = new byte[KEY_LENGTH];
result.getKeyData().copyTo(key, 0, 0, KEY_LENGTH);
return key;
}
/**
* Parses a handshake response, setting the status, result, and closing the handshaker, as needed.
*/
private void handleResponse(HandshakerResp resp) throws GeneralSecurityException {
status = resp.getStatus();
if (resp.hasResult()) {
result = resp.getResult();
close();
}
if (status.getCode() != Status.Code.OK.value()) {
String error = "Handshaker service error: " + status.getDetails();
logger.log(Level.INFO, error);
close();
throw new GeneralSecurityException(error);
}
}
/**
* Starts a client handshake. A GeneralSecurityException is thrown if the handshaker service is
* interrupted or fails. Note that isFinished() must be false before this function is called.
*
* @return the frame to give to the peer.
* @throws GeneralSecurityException or IllegalStateException
*/
public ByteBuffer startClientHandshake() throws GeneralSecurityException {
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
HandshakerReq.Builder req = HandshakerReq.newBuilder();
setStartClientFields(req);
HandshakerResp resp;
try {
resp = handshakerStub.send(req.build());
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
return resp.getOutFrames().asReadOnlyByteBuffer();
}
/**
* Starts a server handshake. A GeneralSecurityException is thrown if the handshaker service is
* interrupted or fails. Note that isFinished() must be false before this function is called.
*
* @param inBytes the bytes received from the peer.
* @return the frame to give to the peer.
* @throws GeneralSecurityException or IllegalStateException
*/
public ByteBuffer startServerHandshake(ByteBuffer inBytes) throws GeneralSecurityException {
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
HandshakerReq.Builder req = HandshakerReq.newBuilder();
setStartServerFields(req, inBytes);
HandshakerResp resp;
try {
resp = handshakerStub.send(req.build());
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
inBytes.position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}
/**
* Processes the next bytes in a handshake. A GeneralSecurityException is thrown if the handshaker
* service is interrupted or fails. Note that isFinished() must be false before this function is
* called.
*
* @param inBytes the bytes received from the peer.
* @return the frame to give to the peer.
* @throws GeneralSecurityException or IllegalStateException
*/
public ByteBuffer next(ByteBuffer inBytes) throws GeneralSecurityException {
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
HandshakerReq.Builder req =
HandshakerReq.newBuilder()
.setNext(
NextHandshakeMessageReq.newBuilder()
.setInBytes(ByteString.copyFrom(inBytes.duplicate()))
.build());
HandshakerResp resp;
try {
resp = handshakerStub.send(req.build());
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}
handleResponse(resp);
inBytes.position(inBytes.position() + resp.getBytesConsumed());
return resp.getOutFrames().asReadOnlyByteBuffer();
}
/** Closes the connection. */
public void close() {
handshakerStub.close();
}
}

View File

@ -0,0 +1,33 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import javax.annotation.Nullable;
/** Handshaker options for creating ALTS channel. */
public class AltsHandshakerOptions {
@Nullable private final RpcProtocolVersions rpcProtocolVersions;
public AltsHandshakerOptions(RpcProtocolVersions rpcProtocolVersions) {
this.rpcProtocolVersions = rpcProtocolVersions;
}
public RpcProtocolVersions getRpcProtocolVersions() {
return rpcProtocolVersions;
}
}

View File

@ -0,0 +1,114 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import io.grpc.alts.Handshaker.HandshakerReq;
import io.grpc.alts.Handshaker.HandshakerResp;
import io.grpc.alts.HandshakerServiceGrpc.HandshakerServiceStub;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.atomic.AtomicReference;
/** An interface to the ALTS handshaker service. */
class AltsHandshakerStub {
private final StreamObserver<HandshakerResp> reader = new Reader();
private final StreamObserver<HandshakerReq> writer;
private final ArrayBlockingQueue<Optional<HandshakerResp>> responseQueue =
new ArrayBlockingQueue<Optional<HandshakerResp>>(1);
private final AtomicReference<String> exceptionMessage = new AtomicReference<>();
AltsHandshakerStub(HandshakerServiceStub serviceStub) {
this.writer = serviceStub.doHandshake(this.reader);
}
@VisibleForTesting
AltsHandshakerStub() {
writer = null;
}
@VisibleForTesting
AltsHandshakerStub(StreamObserver<HandshakerReq> writer) {
this.writer = writer;
}
@VisibleForTesting
StreamObserver<HandshakerResp> getReaderForTest() {
return reader;
}
/** Send a handshaker request and return the handshaker response. */
public HandshakerResp send(HandshakerReq req) throws InterruptedException, IOException {
maybeThrowIoException();
if (!responseQueue.isEmpty()) {
throw new IOException("Received an unexpected response.");
}
writer.onNext(req);
Optional<HandshakerResp> result = responseQueue.take();
if (!result.isPresent()) {
maybeThrowIoException();
}
return result.get();
}
/** Throw exception if there is an outstanding exception. */
private void maybeThrowIoException() throws IOException {
if (exceptionMessage.get() != null) {
throw new IOException(exceptionMessage.get());
}
}
/** Close the connection. */
public void close() {
writer.onCompleted();
}
private class Reader implements StreamObserver<HandshakerResp> {
/** Receive a handshaker response from the server. */
@Override
public void onNext(HandshakerResp resp) {
try {
AltsHandshakerStub.this.responseQueue.add(Optional.of(resp));
} catch (IllegalStateException e) {
AltsHandshakerStub.this.exceptionMessage.compareAndSet(
null, "Received an unexpected response.");
AltsHandshakerStub.this.close();
}
}
/** Receive an error from the server. */
@Override
public void onError(Throwable t) {
AltsHandshakerStub.this.exceptionMessage.compareAndSet(
null, "Received a terminating error: " + t.toString());
// Trigger the release of any blocked send.
Optional<HandshakerResp> result = Optional.absent();
AltsHandshakerStub.this.responseQueue.offer(result);
}
/** Receive the closing message from the server. */
@Override
public void onCompleted() {
AltsHandshakerStub.this.exceptionMessage.compareAndSet(null, "Response stream closed.");
// Trigger the release of any blocked send.
Optional<HandshakerResp> result = Optional.absent();
AltsHandshakerStub.this.responseQueue.offer(result);
}
}
}

View File

@ -0,0 +1,404 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/** Frame protector that uses the ALTS framing. */
public final class AltsTsiFrameProtector implements TsiFrameProtector {
private static final int HEADER_LEN_FIELD_BYTES = 4;
private static final int HEADER_TYPE_FIELD_BYTES = 4;
private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES;
private static final int HEADER_TYPE_DEFAULT = 6;
// Total frame size including full header and tag.
private static final int MAX_ALLOWED_FRAME_BYTES = 16 * 1024;
private static final int LIMIT_MAX_ALLOWED_FRAME_BYTES = 1024 * 1024;
private final Protector protector;
private final Unprotector unprotector;
/** Create a new AltsTsiFrameProtector. */
public AltsTsiFrameProtector(
int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength());
maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_BYTES, maxProtectedFrameBytes);
protector = new Protector(maxProtectedFrameBytes, crypter);
unprotector = new Unprotector(crypter, alloc);
}
static int getHeaderLenFieldBytes() {
return HEADER_LEN_FIELD_BYTES;
}
static int getHeaderTypeFieldBytes() {
return HEADER_TYPE_FIELD_BYTES;
}
public static int getHeaderBytes() {
return HEADER_BYTES;
}
static int getHeaderTypeDefault() {
return HEADER_TYPE_DEFAULT;
}
public static int getMaxAllowedFrameBytes() {
return MAX_ALLOWED_FRAME_BYTES;
}
static int getLimitMaxAllowedFrameBytes() {
return LIMIT_MAX_ALLOWED_FRAME_BYTES;
}
@Override
public void protectFlush(
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
throws GeneralSecurityException {
protector.protectFlush(unprotectedBufs, ctxWrite, alloc);
}
@Override
public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
throws GeneralSecurityException {
unprotector.unprotect(in, out, alloc);
}
@Override
public void destroy() {
try {
unprotector.destroy();
} finally {
protector.destroy();
}
}
static final class Protector {
private final int maxUnprotectedBytesPerFrame;
private final int suffixBytes;
private ChannelCrypterNetty crypter;
Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter) {
this.suffixBytes = crypter.getSuffixLength();
this.maxUnprotectedBytesPerFrame = maxProtectedFrameBytes - HEADER_BYTES - suffixBytes;
this.crypter = crypter;
}
void destroy() {
// Shared with Unprotector and destroyed there.
crypter = null;
}
void protectFlush(
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
throws GeneralSecurityException {
checkState(crypter != null, "Cannot protectFlush after destroy.");
ByteBuf protectedBuf;
try {
protectedBuf = handleUnprotected(unprotectedBufs, alloc);
} finally {
for (ByteBuf buf : unprotectedBufs) {
buf.release();
}
}
if (protectedBuf != null) {
ctxWrite.accept(protectedBuf);
}
}
@SuppressWarnings("BetaApi") // verify is stable in Guava
private ByteBuf handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)
throws GeneralSecurityException {
long unprotectedBytes = 0;
for (ByteBuf buf : unprotectedBufs) {
unprotectedBytes += buf.readableBytes();
}
// Empty plaintext not allowed since this should be handled as no-op in layer above.
checkArgument(unprotectedBytes > 0);
// Compute number of frames and allocate a single buffer for all frames.
long frameNum = unprotectedBytes / maxUnprotectedBytesPerFrame + 1;
int lastFrameUnprotectedBytes = (int) (unprotectedBytes % maxUnprotectedBytesPerFrame);
if (lastFrameUnprotectedBytes == 0) {
frameNum--;
lastFrameUnprotectedBytes = maxUnprotectedBytesPerFrame;
}
long protectedBytes = frameNum * (HEADER_BYTES + suffixBytes) + unprotectedBytes;
ByteBuf protectedBuf = alloc.directBuffer(Math.toIntExact(protectedBytes));
try {
int bufferIdx = 0;
for (int frameIdx = 0; frameIdx < frameNum; ++frameIdx) {
int unprotectedBytesLeft =
(frameIdx == frameNum - 1) ? lastFrameUnprotectedBytes : maxUnprotectedBytesPerFrame;
// Write header (at most LIMIT_MAX_ALLOWED_FRAME_BYTES).
protectedBuf.writeIntLE(unprotectedBytesLeft + HEADER_TYPE_FIELD_BYTES + suffixBytes);
protectedBuf.writeIntLE(HEADER_TYPE_DEFAULT);
// Ownership of the backing buffer remains with protectedBuf.
ByteBuf frameOut = writeSlice(protectedBuf, unprotectedBytesLeft + suffixBytes);
List<ByteBuf> framePlain = new ArrayList<>();
while (unprotectedBytesLeft > 0) {
// Ownership of the buffer backing in remains with unprotectedBufs.
ByteBuf in = unprotectedBufs.get(bufferIdx);
if (in.readableBytes() <= unprotectedBytesLeft) {
// The complete buffer belongs to this frame.
framePlain.add(in);
unprotectedBytesLeft -= in.readableBytes();
bufferIdx++;
} else {
// The remainder of in will be part of the next frame.
framePlain.add(in.readSlice(unprotectedBytesLeft));
unprotectedBytesLeft = 0;
}
}
crypter.encrypt(frameOut, framePlain);
verify(!frameOut.isWritable());
}
protectedBuf.readerIndex(0);
protectedBuf.writerIndex(protectedBuf.capacity());
return protectedBuf.retain();
} finally {
protectedBuf.release();
}
}
}
static final class Unprotector {
private final int suffixBytes;
private final ChannelCrypterNetty crypter;
private DeframerState state = DeframerState.READ_HEADER;
private int requiredProtectedBytes;
private ByteBuf header;
private ByteBuf firstFrameTag;
private int unhandledIdx = 0;
private long unhandledBytes = 0;
private List<ByteBuf> unhandledBufs = new ArrayList<>(16);
Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
this.crypter = crypter;
this.suffixBytes = crypter.getSuffixLength();
this.header = alloc.directBuffer(HEADER_BYTES);
this.firstFrameTag = alloc.directBuffer(suffixBytes);
}
private void addUnhandled(ByteBuf in) {
if (in.isReadable()) {
ByteBuf buf = in.readRetainedSlice(in.readableBytes());
unhandledBufs.add(buf);
unhandledBytes += buf.readableBytes();
}
}
void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
throws GeneralSecurityException {
checkState(header != null, "Cannot unprotect after destroy.");
addUnhandled(in);
decodeFrame(alloc, out);
}
@SuppressWarnings("fallthrough")
private void decodeFrame(ByteBufAllocator alloc, List<Object> out)
throws GeneralSecurityException {
switch (state) {
case READ_HEADER:
if (unhandledBytes < HEADER_BYTES) {
return;
}
handleHeader();
// fall through
case READ_PROTECTED_PAYLOAD:
if (unhandledBytes < requiredProtectedBytes) {
return;
}
ByteBuf unprotectedBuf;
try {
unprotectedBuf = handlePayload(alloc);
} finally {
clearState();
}
if (unprotectedBuf != null) {
out.add(unprotectedBuf);
}
break;
default:
throw new AssertionError("impossible enum value");
}
}
private void handleHeader() {
while (header.isWritable()) {
ByteBuf in = unhandledBufs.get(unhandledIdx);
int headerBytesToRead = Math.min(in.readableBytes(), header.writableBytes());
header.writeBytes(in, headerBytesToRead);
unhandledBytes -= headerBytesToRead;
if (!in.isReadable()) {
unhandledIdx++;
}
}
requiredProtectedBytes = header.readIntLE() - HEADER_TYPE_FIELD_BYTES;
checkArgument(
requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small");
checkArgument(
requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_BYTES - HEADER_BYTES,
"Invalid header field: frame size too large");
int frameType = header.readIntLE();
checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type");
state = DeframerState.READ_PROTECTED_PAYLOAD;
}
@SuppressWarnings("BetaApi") // verify is stable in Guava
private ByteBuf handlePayload(ByteBufAllocator alloc) throws GeneralSecurityException {
int requiredCiphertextBytes = requiredProtectedBytes - suffixBytes;
int firstFrameUnprotectedLen = requiredCiphertextBytes;
// We get the ciphertexts of the first frame and copy over the tag into a single buffer.
List<ByteBuf> firstFrameCiphertext = new ArrayList<>();
while (requiredCiphertextBytes > 0) {
ByteBuf buf = unhandledBufs.get(unhandledIdx);
if (buf.readableBytes() <= requiredCiphertextBytes) {
// We use the whole buffer.
firstFrameCiphertext.add(buf);
requiredCiphertextBytes -= buf.readableBytes();
unhandledIdx++;
} else {
firstFrameCiphertext.add(buf.readSlice(requiredCiphertextBytes));
requiredCiphertextBytes = 0;
}
}
int requiredSuffixBytes = suffixBytes;
while (true) {
ByteBuf buf = unhandledBufs.get(unhandledIdx);
if (buf.readableBytes() <= requiredSuffixBytes) {
// We use the whole buffer.
requiredSuffixBytes -= buf.readableBytes();
firstFrameTag.writeBytes(buf);
if (requiredSuffixBytes == 0) {
break;
}
unhandledIdx++;
} else {
firstFrameTag.writeBytes(buf, requiredSuffixBytes);
break;
}
}
verify(unhandledIdx == unhandledBufs.size() - 1);
ByteBuf lastBuf = unhandledBufs.get(unhandledIdx);
// We get the remaining ciphertexts and tags contained in the last buffer.
List<ByteBuf> ciphertextsAndTags = new ArrayList<>();
List<Integer> unprotectedLens = new ArrayList<>();
long requiredUnprotectedBytesCompleteFrames = firstFrameUnprotectedLen;
while (lastBuf.readableBytes() >= HEADER_BYTES + suffixBytes) {
// Read frame size.
int frameSize = lastBuf.readIntLE();
int payloadSize = frameSize - HEADER_TYPE_FIELD_BYTES - suffixBytes;
// Break and undo read if we don't have the complete frame yet.
if (lastBuf.readableBytes() < frameSize) {
lastBuf.readerIndex(lastBuf.readerIndex() - HEADER_LEN_FIELD_BYTES);
break;
}
// Check the type header.
checkArgument(lastBuf.readIntLE() == 6);
// Create a new frame (except for out buffer).
ciphertextsAndTags.add(lastBuf.readSlice(payloadSize + suffixBytes));
// Update sizes for frame.
requiredUnprotectedBytesCompleteFrames += payloadSize;
unprotectedLens.add(payloadSize);
}
// We leave space for suffixBytes to allow for in-place encryption. This allows for calling
// doFinal in the JCE implementation which can be optimized better than update and doFinal.
ByteBuf unprotectedBuf =
alloc.directBuffer(Math.toIntExact(requiredUnprotectedBytesCompleteFrames + suffixBytes));
try {
ByteBuf out = writeSlice(unprotectedBuf, firstFrameUnprotectedLen + suffixBytes);
crypter.decrypt(out, firstFrameTag, firstFrameCiphertext);
verify(out.writableBytes() == suffixBytes);
unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
for (int frameIdx = 0; frameIdx < ciphertextsAndTags.size(); ++frameIdx) {
out = writeSlice(unprotectedBuf, unprotectedLens.get(frameIdx) + suffixBytes);
crypter.decrypt(out, ciphertextsAndTags.get(frameIdx));
verify(out.writableBytes() == suffixBytes);
unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
}
return unprotectedBuf.retain();
} finally {
unprotectedBuf.release();
}
}
private void clearState() {
int bufsSize = unhandledBufs.size();
ByteBuf lastBuf = unhandledBufs.get(bufsSize - 1);
boolean keepLast = lastBuf.isReadable();
for (int bufIdx = 0; bufIdx < (keepLast ? bufsSize - 1 : bufsSize); ++bufIdx) {
unhandledBufs.get(bufIdx).release();
}
unhandledBufs.clear();
unhandledBytes = 0;
unhandledIdx = 0;
if (keepLast) {
unhandledBufs.add(lastBuf);
unhandledBytes = lastBuf.readableBytes();
}
state = DeframerState.READ_HEADER;
requiredProtectedBytes = 0;
header.clear();
firstFrameTag.clear();
}
void destroy() {
for (ByteBuf unhandledBuf : unhandledBufs) {
unhandledBuf.release();
}
unhandledBufs.clear();
if (header != null) {
header.release();
header = null;
}
if (firstFrameTag != null) {
firstFrameTag.release();
firstFrameTag = null;
}
crypter.destroy();
}
}
private enum DeframerState {
READ_HEADER,
READ_PROTECTED_PAYLOAD
}
private static ByteBuf writeSlice(ByteBuf in, int len) {
checkArgument(len <= in.writableBytes());
ByteBuf out = in.slice(in.writerIndex(), len);
in.writerIndex(in.writerIndex() + len);
return out.writerIndex(0);
}
}

View File

@ -0,0 +1,195 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.alts.HandshakerServiceGrpc.HandshakerServiceStub;
import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
/**
* Negotiates a grpc channel key to be used by the TsiFrameProtector, using ALTs handshaker service.
*/
public final class AltsTsiHandshaker implements TsiHandshaker {
public static final String TSI_SERVICE_ACCOUNT_PEER_PROPERTY = "service_account";
private final boolean isClient;
private final AltsHandshakerClient handshaker;
private ByteBuffer outputFrame;
/** Starts a new TSI handshaker with client options. */
private AltsTsiHandshaker(
boolean isClient, HandshakerServiceStub stub, AltsHandshakerOptions options) {
this.isClient = isClient;
handshaker = new AltsHandshakerClient(stub, options);
}
@VisibleForTesting
AltsTsiHandshaker(boolean isClient, AltsHandshakerClient handshaker) {
this.isClient = isClient;
this.handshaker = handshaker;
}
/**
* Process the bytes received from the peer.
*
* @param bytes The buffer containing the handshake bytes from the peer.
* @return true, if the handshake has all the data it needs to process and false, if the method
* must be called again to complete processing.
*/
@Override
public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
// If we're the client and we haven't given an output frame, we shouldn't be processing any
// bytes.
if (outputFrame == null && isClient) {
return true;
}
// If we already have bytes to write, just return.
if (outputFrame != null && outputFrame.hasRemaining()) {
return true;
}
int remaining = bytes.remaining();
// Call handshaker service to proceess the bytes.
if (outputFrame == null) {
checkState(!isClient, "Client handshaker should not process any frame at the beginning.");
outputFrame = handshaker.startServerHandshake(bytes);
} else {
outputFrame = handshaker.next(bytes);
}
// If handshake has finished or we already have bytes to write, just return true.
if (handshaker.isFinished() || outputFrame.hasRemaining()) {
return true;
}
// We have done processing input bytes, but no bytes to write. Thus we need more data.
if (!bytes.hasRemaining()) {
return false;
}
// There are still remaining bytes. Thus we need to continue processing the bytes.
// Prevent infinite loop by checking some bytes are consumed by handshaker.
checkState(bytes.remaining() < remaining, "Handshaker did not consume any bytes.");
return processBytesFromPeer(bytes);
}
/**
* Returns the peer extracted from a completed handshake.
*
* @return the extracted peer.
*/
@Override
public TsiPeer extractPeer() throws GeneralSecurityException {
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
List<TsiPeer.Property<?>> peerProperties = new ArrayList<TsiPeer.Property<?>>();
peerProperties.add(
new TsiPeer.StringProperty(
TSI_SERVICE_ACCOUNT_PEER_PROPERTY,
handshaker.getResult().getPeerIdentity().getServiceAccount()));
return new TsiPeer(peerProperties);
}
/**
* Returns the peer extracted from a completed handshake.
*
* @return the extracted peer.
*/
@Override
public Object extractPeerObject() throws GeneralSecurityException {
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
return new AltsAuthContext(handshaker.getResult());
}
/** Creates a new TsiHandshaker for use by the client. */
public static TsiHandshaker newClient(HandshakerServiceStub stub, AltsHandshakerOptions options) {
return new AltsTsiHandshaker(true, stub, options);
}
/** Creates a new TsiHandshaker for use by the server. */
public static TsiHandshaker newServer(HandshakerServiceStub stub, AltsHandshakerOptions options) {
return new AltsTsiHandshaker(false, stub, options);
}
/**
* Gets bytes that need to be sent to the peer.
*
* @param bytes The buffer to put handshake bytes.
*/
@Override
public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
if (outputFrame == null) { // A null outputFrame indicates we haven't started the handshake.
if (isClient) {
outputFrame = handshaker.startClientHandshake();
} else {
// The server needs bytes to process before it can start the handshake.
return;
}
}
// Write as many bytes as we are able.
ByteBuffer outputFrameAlias = outputFrame;
if (outputFrame.remaining() > bytes.remaining()) {
outputFrameAlias = outputFrame.duplicate();
outputFrameAlias.limit(outputFrameAlias.position() + bytes.remaining());
}
bytes.put(outputFrameAlias);
outputFrame.position(outputFrameAlias.position());
}
/**
* Returns true if and only if the handshake is still in progress
*
* @return true, if the handshake is still in progress, false otherwise.
*/
@Override
public boolean isInProgress() {
return !handshaker.isFinished() || outputFrame.hasRemaining();
}
/**
* Creates a frame protector from a completed handshake. No other methods may be called after the
* frame protector is created.
*
* @param maxFrameSize the requested max frame size, the callee is free to ignore.
* @param alloc used for allocating ByteBufs.
* @return a new TsiFrameProtector.
*/
@Override
public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
byte[] key = handshaker.getKey();
Preconditions.checkState(key.length == AltsChannelCrypter.getKeyLength(), "Bad key length.");
return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
}
/**
* Creates a frame protector from a completed handshake. No other methods may be called after the
* frame protector is created.
*
* @param alloc used for allocating ByteBufs.
* @return a new TsiFrameProtector.
*/
@Override
public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc);
}
}

View File

@ -0,0 +1,76 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import io.netty.buffer.ByteBuf;
import java.security.GeneralSecurityException;
import java.util.List;
/**
* A @{code ChannelCrypterNetty} performs stateful encryption and decryption of independent input
* and output streams. Both decrypt and encrypt gather their input from a list of Netty @{link
* ByteBuf} instances.
*
* <p>Note that we provide implementations of this interface that provide integrity only and
* implementations that provide privacy and integrity. All methods should be thread-compatible.
*/
public interface ChannelCrypterNetty {
/**
* Encrypt plaintext into output buffer.
*
* @param out the protected input will be written into this buffer. The buffer must be direct and
* have enough space to hold all input buffers and the tag. Encrypt does not take ownership of
* this buffer.
* @param plain the input buffers that should be protected. Encrypt does not modify or take
* ownership of these buffers.
*/
void encrypt(ByteBuf out, List<ByteBuf> plain) throws GeneralSecurityException;
/**
* Decrypt ciphertext into the given output buffer and check tag.
*
* @param out the unprotected input will be written into this buffer. The buffer must be direct
* and have enough space to hold all ciphertext buffers and the tag, i.e., it must have
* additional space for the tag, even though this space will be unused in the final result.
* Decrypt does not take ownership of this buffer.
* @param tag the tag appended to the ciphertext. Decrypt does not modify or take ownership of
* this buffer.
* @param ciphertext the buffers that should be unprotected (excluding the tag). Decrypt does not
* modify or take ownership of these buffers.
*/
void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertext) throws GeneralSecurityException;
/**
* Decrypt ciphertext into the given output buffer and check tag.
*
* @param out the unprotected input will be written into this buffer. The buffer must be direct
* and have enough space to hold all ciphertext buffers and the tag, i.e., it must have
* additional space for the tag, even though this space will be unused in the final result.
* Decrypt does not take ownership of this buffer.
* @param ciphertextAndTag single buffer containing ciphertext and tag that should be unprotected.
* The buffer must be direct and either completely overlap with {@code out} or not overlap at
* all.
*/
void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException;
/** Returns the length of the tag in bytes. */
int getSuffixLength();
/** Must be called to release all associated resources (instance cannot be used afterwards). */
void destroy();
}

View File

@ -0,0 +1,56 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import java.security.GeneralSecurityException;
import java.util.List;
import java.util.function.Consumer;
/**
* This object protects and unprotects netty buffers once the handshake is done.
*
* <p>Implementations of this object must be thread compatible.
*/
public interface TsiFrameProtector {
/**
* Protects the buffers by performing framing and encrypting/appending MACs.
*
* @param unprotectedBufs contain the payload that will be protected
* @param ctxWrite is called with buffers containing protected frames and must release the given
* buffers
* @param alloc is used to allocate new buffers for the protected frames
*/
void protectFlush(
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
throws GeneralSecurityException;
/**
* Unprotects the buffers by removing the framing and decrypting/checking MACs.
*
* @param in contains (partial) protected frames
* @param out is only used to append unprotected payload buffers
* @param alloc is used to allocate new buffers for the unprotected frames
*/
void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
throws GeneralSecurityException;
/** Must be called to release all associated resources (instance cannot be used afterwards). */
void destroy();
}

View File

@ -0,0 +1,109 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
/**
* This object protects and unprotects buffers once the handshake is done.
*
* <p>A typical usage of this object would be:
*
* <pre>{@code
* ByteBuffer buffer = allocateDirect(ALLOCATE_SIZE);
* while (true) {
* while (true) {
* tsiHandshaker.getBytesToSendToPeer(buffer.clear());
* if (!buffer.hasRemaining()) break;
* yourTransportSendMethod(buffer.flip());
* assert(!buffer.hasRemaining()); // Guaranteed by yourTransportReceiveMethod(...)
* }
* if (!tsiHandshaker.isInProgress()) break;
* while (true) {
* assert(!buffer.hasRemaining());
* yourTransportReceiveMethod(buffer.clear());
* if (tsiHandshaker.processBytesFromPeer(buffer.flip())) break;
* }
* if (!tsiHandshaker.isInProgress()) break;
* assert(!buffer.hasRemaining());
* }
* yourCheckPeerMethod(tsiHandshaker.extractPeer());
* TsiFrameProtector tsiFrameProtector = tsiHandshaker.createFrameProtector(MAX_FRAME_SIZE);
* if (buffer.hasRemaining()) tsiFrameProtector.unprotect(buffer, messageBuffer);
* }</pre>
*
* <p>Implementations of this object must be thread compatible.
*/
public interface TsiHandshaker {
/**
* Gets bytes that need to be sent to the peer.
*
* @param bytes The buffer to put handshake bytes.
*/
void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException;
/**
* Process the bytes received from the peer.
*
* @param bytes The buffer containing the handshake bytes from the peer.
* @return true, if the handshake has all the data it needs to process and false, if the method
* must be called again to complete processing.
*/
boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException;
/**
* Returns true if and only if the handshake is still in progress
*
* @return true, if the handshake is still in progress, false otherwise.
*/
boolean isInProgress();
/**
* Returns the peer extracted from a completed handshake.
*
* @return the extracted peer.
*/
TsiPeer extractPeer() throws GeneralSecurityException;
/**
* Returns the peer extracted from a completed handshake.
*
* @return the extracted peer.
*/
public Object extractPeerObject() throws GeneralSecurityException;
/**
* Creates a frame protector from a completed handshake. No other methods may be called after the
* frame protector is created.
*
* @param maxFrameSize the requested max frame size, the callee is free to ignore.
* @param alloc used for allocating ByteBufs.
* @return a new TsiFrameProtector.
*/
TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc);
/**
* Creates a frame protector from a completed handshake. No other methods may be called after the
* frame protector is created.
*
* @param alloc used for allocating ByteBufs.
* @return a new TsiFrameProtector.
*/
TsiFrameProtector createFrameProtector(ByteBufAllocator alloc);
}

View File

@ -0,0 +1,24 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
/** Factory that manufactures instances of {@link TsiHandshaker}. */
public interface TsiHandshakerFactory {
/** Creates a new handshaker. */
TsiHandshaker newHandshaker();
}

View File

@ -0,0 +1,110 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nonnull;
/** A set of peer properties. */
public final class TsiPeer {
private final List<Property<?>> properties;
public TsiPeer(List<Property<?>> properties) {
this.properties = Collections.unmodifiableList(properties);
}
public List<Property<?>> getProperties() {
return properties;
}
/** Get peer property. */
public Property<?> getProperty(String name) {
for (Property<?> property : properties) {
if (property.getName().equals(name)) {
return property;
}
}
return null;
}
@Override
public String toString() {
return new ArrayList<>(properties).toString();
}
/** A peer property. */
public abstract static class Property<T> {
private final String name;
private final T value;
public Property(@Nonnull String name, @Nonnull T value) {
this.name = name;
this.value = value;
}
public final T getValue() {
return value;
}
public final String getName() {
return name;
}
@Override
public String toString() {
return String.format("%s=%s", name, value);
}
}
/** A peer property corresponding to a signed 64-bit integer. */
public static final class SignedInt64Property extends Property<Long> {
public SignedInt64Property(@Nonnull String name, @Nonnull Long value) {
super(name, value);
}
}
/** A peer property corresponding to an unsigned 64-bit integer. */
public static final class UnsignedInt64Property extends Property<BigInteger> {
public UnsignedInt64Property(@Nonnull String name, @Nonnull BigInteger value) {
super(name, value);
}
}
/** A peer property corresponding to a double. */
public static final class DoubleProperty extends Property<Double> {
public DoubleProperty(@Nonnull String name, @Nonnull Double value) {
super(name, value);
}
}
/** A peer property corresponding to a string. */
public static final class StringProperty extends Property<String> {
public StringProperty(@Nonnull String name, @Nonnull String value) {
super(name, value);
}
}
/** A peer property corresponding to a list of peer properties. */
public static final class PropertyList extends Property<List<Property<?>>> {
public PropertyList(@Nonnull String name, @Nonnull List<Property<?>> value) {
super(name, value);
}
}
}

View File

@ -0,0 +1,27 @@
syntax = "proto3";
import "transport_security_common.proto";
package grpc.gcp;
option java_package = "io.grpc.alts";
message AltsContext {
// The application protocol negotiated for this connection.
string application_protocol = 1;
// The record protocol negotiated for this connection.
string record_protocol = 2;
// The security level of the created secure channel.
SecurityLevel security_level = 3;
// The peer service account.
string peer_service_account = 4;
// The local service account.
string local_service_account = 5;
// The RPC protocol versions supported by the peer.
RpcProtocolVersions peer_rpc_versions = 6;
}

View File

@ -0,0 +1,206 @@
syntax = "proto3";
import "transport_security_common.proto";
package grpc.gcp;
option java_package = "io.grpc.alts";
enum HandshakeProtocol {
// Default value.
HANDSHAKE_PROTOCOL_UNSPECIFIED = 0;
// TLS handshake protocol.
TLS = 1;
// Application Layer Transport Security handshake protocol.
ALTS = 2;
}
enum NetworkProtocol {
NETWORK_PROTOCOL_UNSPECIFIED = 0;
TCP = 1;
UDP = 2;
}
message Endpoint {
// IP address. It should contain an IPv4 or IPv6 string literal, e.g.
// "192.168.0.1" or "2001:db8::1".
string ip_address = 1;
// Port number.
int32 port = 2;
// Network protocol (e.g., TCP, UDP) associated with this endpoint.
NetworkProtocol protocol = 3;
}
message Identity {
oneof identity_oneof {
// Service account of a connection endpoint.
string service_account = 1;
// Hostname of a connection endpoint.
string hostname = 2;
}
}
message StartClientHandshakeReq {
// Handshake security protocol requested by the client.
HandshakeProtocol handshake_security_protocol = 1;
// The application protocols supported by the client, e.g., "h2" (for http2),
// "grpc".
repeated string application_protocols = 2;
// The record protocols supported by the client, e.g.,
// "ALTSRP_GCM_AES128".
repeated string record_protocols = 3;
// (Optional) Describes which server identities are acceptable by the client.
// If target identities are provided and none of them matches the peer
// identity of the server, handshake will fail.
repeated Identity target_identities = 4;
// (Optional) Application may specify a local identity. Otherwise, the
// handshaker chooses a default local identity.
Identity local_identity = 5;
// (Optional) Local endpoint information of the connection to the server,
// such as local IP address, port number, and network protocol.
Endpoint local_endpoint = 6;
// (Optional) Endpoint information of the remote server, such as IP address,
// port number, and network protocool.
Endpoint remote_endpoint = 7;
// (Optional) If target name is provided, a secure naming check is performed
// to verify that the peer authenticated identity is indeed authorized to run
// the target name.
string target_name = 8;
// (Optional) RPC protocol versions supported by the client.
RpcProtocolVersions rpc_versions = 9;
}
message ServerHandshakeParameters {
// The record protocols supported by the server, e.g.,
// "ALTSRP_GCM_AES128".
repeated string record_protocols = 1;
// (Optional) A list of local identities supported by the server, if
// specified. Otherwise, the handshaker chooses a default local identity.
repeated Identity local_identities = 2;
}
message StartServerHandshakeReq {
// The application protocols supported by the server, e.g., "h2" (for http2),
// "grpc".
repeated string application_protocols = 1;
// Handshake parameters (record protocols and local identities supported by
// the server) mapped by the handshake protocol. Each handshake security
// protocol (e.g., TLS or ALTS) has its own set of record protocols and local
// identities. Since protobuf does not support enum as key to the map, the key
// to handshake_parameters is the integer value of HandshakeProtocol enum.
map<int32, ServerHandshakeParameters> handshake_parameters = 2;
// Bytes in out_frames returned from the peer's HandshakerResp. It is possible
// that the peer's out_frames are split into multiple HandshakReq messages.
bytes in_bytes = 3;
// (Optional) Local endpoint information of the connection to the client,
// such as local IP address, port number, and network protocol.
Endpoint local_endpoint = 4;
// (Optional) Endpoint information of the remote client, such as IP address,
// port number, and network protocool.
Endpoint remote_endpoint = 5;
// (Optional) RPC protocol versions supported by the server.
RpcProtocolVersions rpc_versions = 6;
}
message NextHandshakeMessageReq {
// Bytes in out_frames returned from the peer's HandshakerResp. It is possible
// that the peer's out_frames are split into multiple NextHandshakerMessageReq
// messages.
bytes in_bytes = 1;
}
message HandshakerReq {
oneof req_oneof {
// The start client handshake request message.
StartClientHandshakeReq client_start = 1;
// The start server handshake request message.
StartServerHandshakeReq server_start = 2;
// The next handshake request message.
NextHandshakeMessageReq next = 3;
}
}
message HandshakerResult {
// The application protocol negotiated for this connection.
string application_protocol = 1;
// The record protocol negotiated for this connection.
string record_protocol = 2;
// Cryptographic key data. The key data may be more than the key length
// required for the record protocol, thus the client of the handshaker
// service needs to truncate the key data into the right key length.
bytes key_data = 3;
// The authenticated identity of the peer.
Identity peer_identity = 4;
// The local identity used in the handshake.
Identity local_identity = 5;
// Indicate whether the handshaker service client should keep the channel
// between the handshaker service open, e.g., in order to handle
// post-handshake messages in the future.
bool keep_channel_open = 6;
// The RPC protocol versions supported by the peer.
RpcProtocolVersions peer_rpc_versions = 7;
}
message HandshakerStatus {
// The status code. This could be the gRPC status code.
uint32 code = 1;
// The status details.
string details = 2;
}
message HandshakerResp {
// Frames to be given to the peer for the NextHandshakeMessageReq. May be
// empty if no out_frames have to be sent to the peer or if in_bytes in the
// HandshakerReq are incomplete. All the non-empty out frames must be sent to
// the peer even if the handshaker status is not OK as these frames may
// contain the alert frames.
bytes out_frames = 1;
// Number of bytes in the in_bytes consumed by the handshaker. It is possible
// that part of in_bytes in HandshakerReq was unrelated to the handshake
// process.
uint32 bytes_consumed = 2;
// This is set iff the handshake was successful. out_frames may still be set
// to frames that needs to be forwarded to the peer.
HandshakerResult result = 3;
// Status of the handshaker.
HandshakerStatus status = 4;
}
service HandshakerService {
// Accepts a stream of handshaker request, returning a stream of handshaker
// response.
rpc DoHandshake(stream HandshakerReq)
returns (stream HandshakerResp) {
}
}

View File

@ -0,0 +1,26 @@
syntax = "proto3";
package grpc.gcp;
option java_package = "io.grpc.alts";
// The security level of the created channel. The list is sorted in increasing
// level of security. This order must always be maintained.
enum SecurityLevel {
SECURITY_NONE = 0;
INTEGRITY_ONLY = 1;
INTEGRITY_AND_PRIVACY = 2;
}
// Max and min supported RPC protocol versions.
message RpcProtocolVersions {
// RPC version contains a major version and a minor version.
message Version {
uint32 major = 1;
uint32 minor = 2;
}
// Maximum supported RPC version.
Version max_rpc_version = 1;
// Minimum supported RPC version.
Version min_rpc_version = 2;
}

View File

@ -0,0 +1,94 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import com.google.common.base.Defaults;
import io.grpc.ManagedChannel;
import io.grpc.alts.AltsChannelBuilder.AltsChannel;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.alts.transportsecurity.AltsClientOptions;
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
import io.grpc.netty.ProtocolNegotiator;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.InetSocketAddress;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public final class AltsChannelBuilderTest {
@Test
public void buildsNettyChannel() throws Exception {
AltsChannelBuilder builder =
AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting();
TransportCreationParamsFilterFactory tcpfFactory = builder.getTcpfFactoryForTest();
AltsClientOptions altsClientOptions = builder.getAltsClientOptionsForTest();
assertThat(tcpfFactory).isNull();
assertThat(altsClientOptions).isNull();
ManagedChannel channel = builder.build();
assertThat(channel).isInstanceOf(AltsChannel.class);
tcpfFactory = builder.getTcpfFactoryForTest();
altsClientOptions = builder.getAltsClientOptionsForTest();
assertThat(tcpfFactory).isNotNull();
ProtocolNegotiator protocolNegotiator =
tcpfFactory
.create(new InetSocketAddress(8080), "fakeAuthority", "fakeUserAgent", null)
.getProtocolNegotiator();
assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class);
assertThat(altsClientOptions).isNotNull();
RpcProtocolVersions expectedVersions =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.build();
assertThat(altsClientOptions.getRpcProtocolVersions()).isEqualTo(expectedVersions);
}
@Test
public void allAltsChannelMethodsForward() throws Exception {
ManagedChannel mockDelegate = mock(ManagedChannel.class);
AltsChannel altsChannel = new AltsChannel(mockDelegate);
for (Method method : ManagedChannel.class.getDeclaredMethods()) {
if (Modifier.isStatic(method.getModifiers()) || Modifier.isPrivate(method.getModifiers())) {
continue;
}
Class<?>[] argTypes = method.getParameterTypes();
Object[] args = new Object[argTypes.length];
for (int i = 0; i < argTypes.length; i++) {
args[i] = Defaults.defaultValue(argTypes[i]);
}
method.invoke(altsChannel, args);
method.invoke(verify(mockDelegate), args);
}
}
}

View File

@ -0,0 +1,494 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import io.grpc.Attributes;
import io.grpc.Grpc;
import io.grpc.alts.Handshaker.HandshakerResult;
import io.grpc.alts.transportsecurity.AltsAuthContext;
import io.grpc.alts.transportsecurity.FakeTsiHandshaker;
import io.grpc.alts.transportsecurity.TsiFrameProtector;
import io.grpc.alts.transportsecurity.TsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
import io.grpc.alts.transportsecurity.TsiPeer;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2FrameReader;
import io.netty.handler.codec.http2.Http2FrameWriter;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link AltsProtocolNegotiator}. */
@RunWith(JUnit4.class)
public class AltsProtocolNegotiatorTest {
private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
private final List<ReferenceCounted> references = new ArrayList<>();
private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>();
private EmbeddedChannel channel;
private Throwable caughtException;
private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
private ChannelHandler handler;
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.emptyList());
private AltsAuthContext mockedAltsContext =
new AltsAuthContext(
HandshakerResult.newBuilder()
.setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
.build());
private final TsiHandshaker mockHandshaker =
new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
@Override
public TsiPeer extractPeer() throws GeneralSecurityException {
return mockedTsiPeer;
}
@Override
public Object extractPeerObject() throws GeneralSecurityException {
return mockedAltsContext;
}
};
private final InternalNettyTsiHandshaker serverHandshaker =
new InternalNettyTsiHandshaker(mockHandshaker);
@Before
public void setup() throws Exception {
ChannelHandler userEventHandler =
new ChannelDuplexHandler() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent) {
tsiEvent = (InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent) evt;
} else {
super.userEventTriggered(ctx, evt);
}
}
};
ChannelHandler uncaughtExceptionHandler =
new ChannelDuplexHandler() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
caughtException = cause;
super.exceptionCaught(ctx, cause);
}
};
TsiHandshakerFactory handshakerFactory =
new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
@Override
public TsiHandshaker newHandshaker() {
return new DelegatingTsiHandshaker(super.newHandshaker()) {
@Override
public TsiPeer extractPeer() throws GeneralSecurityException {
return mockedTsiPeer;
}
@Override
public Object extractPeerObject() throws GeneralSecurityException {
return mockedAltsContext;
}
};
}
};
handler = AltsProtocolNegotiator.create(handshakerFactory).newHandler(grpcHandler);
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
}
@After
public void teardown() throws Exception {
if (channel != null) {
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError = channel.close();
}
for (ReferenceCounted reference : references) {
ReferenceCountUtil.safeRelease(reference);
}
}
@Test
public void handshakeShouldBeSuccessful() throws Exception {
doHandshake();
}
@Test
@SuppressWarnings("unchecked") // List cast
public void protectShouldRoundtrip() throws Exception {
// Write the message 1 character at a time. The message should be buffered
// and not interfere with the handshake.
final AtomicInteger writeCount = new AtomicInteger();
String message = "hello";
for (int ix = 0; ix < message.length(); ++ix) {
ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError =
channel
.write(in)
.addListener(
new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
writeCount.incrementAndGet();
}
}
});
}
channel.flush();
// Now do the handshake. The buffered message will automatically be protected
// and sent.
doHandshake();
// Capture the protected data written to the wire.
assertEquals(1, channel.outboundMessages().size());
ByteBuf protectedData = channel.<ByteBuf>readOutbound();
assertEquals(message.length(), writeCount.get());
// Read the protected message at the server and verify it matches the original message.
TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(channel.alloc());
List<ByteBuf> unprotected = new ArrayList<>();
serverProtector.unprotect(protectedData, (List<Object>) (List<?>) unprotected, channel.alloc());
// We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it
// a settings frame is written (and an HTTP2 preface). This is hard coded into Netty, so we
// have to remove it here. See {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}.
int settingsFrameLength = 9;
CompositeByteBuf unprotectedAll =
new CompositeByteBuf(channel.alloc(), false, unprotected.size() + 1, unprotected);
ByteBuf unprotectedData = unprotectedAll.slice(settingsFrameLength, message.length());
assertEquals(message, unprotectedData.toString(UTF_8));
// Protect the same message at the server.
AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
serverProtector.protectFlush(
Collections.singletonList(unprotectedData),
b -> newlyProtectedData.set(b),
channel.alloc());
// Read the protected message at the client and verify that it matches the original message.
channel.writeInbound(newlyProtectedData.get());
assertEquals(1, channel.inboundMessages().size());
assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8));
}
@Test
public void unprotectLargeIncomingFrame() throws Exception {
// We use a server frameprotector with twice the standard frame size.
int serverFrameSize = 4096 * 2;
// This should fit into one frame.
byte[] unprotectedBytes = new byte[serverFrameSize - 500];
Arrays.fill(unprotectedBytes, (byte) 7);
ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes);
unprotectedData.writerIndex(unprotectedBytes.length);
// Perform handshake.
doHandshake();
// Protect the message on the server.
TsiFrameProtector serverProtector =
serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
serverProtector.protectFlush(
Collections.singletonList(unprotectedData), b -> channel.writeInbound(b), channel.alloc());
channel.flushInbound();
// Read the protected message at the client and verify that it matches the original message.
assertEquals(1, channel.inboundMessages().size());
ByteBuf receivedData1 = channel.<ByteBuf>readInbound();
int receivedLen1 = receivedData1.readableBytes();
byte[] receivedBytes = new byte[receivedLen1];
receivedData1.readBytes(receivedBytes, 0, receivedLen1);
assertThat(unprotectedBytes.length).isEqualTo(receivedBytes.length);
assertThat(unprotectedBytes).isEqualTo(receivedBytes);
}
@Test
public void flushShouldFailAllPromises() throws Exception {
doHandshake();
channel
.pipeline()
.addFirst(
new ChannelDuplexHandler() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
throw new Exception("Fake exception");
}
});
// Write the message 1 character at a time.
String message = "hello";
final AtomicInteger failures = new AtomicInteger();
for (int ix = 0; ix < message.length(); ++ix) {
ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
@SuppressWarnings("unused") // go/futurereturn-lsc
Future<?> possiblyIgnoredError =
channel
.write(in)
.addListener(
new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
failures.incrementAndGet();
}
}
});
}
channel.flush();
// Verify that the promises fail.
assertEquals(message.length(), failures.get());
}
@Test
public void doNotFlushEmptyBuffer() throws Exception {
doHandshake();
assertEquals(1, protectors.size());
InterceptingProtector protector = protectors.poll();
String message = "hello";
ByteBuf in = Unpooled.copiedBuffer(message, UTF_8);
assertEquals(0, protector.flushes.get());
Future<?> done = channel.write(in);
channel.flush();
done.get(5, TimeUnit.SECONDS);
assertEquals(1, protector.flushes.get());
done = channel.write(Unpooled.EMPTY_BUFFER);
channel.flush();
done.get(5, TimeUnit.SECONDS);
assertEquals(1, protector.flushes.get());
}
@Test
public void peerPropagated() throws Exception {
doHandshake();
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getTsiPeerAttributeKey()))
.isEqualTo(mockedTsiPeer);
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getAltsAuthContextAttributeKey()))
.isEqualTo(mockedAltsContext);
assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
.isEqualTo("embedded");
}
private void doHandshake() throws Exception {
// Capture the client frame and add to the server.
assertEquals(1, channel.outboundMessages().size());
ByteBuf clientFrame = channel.<ByteBuf>readOutbound();
assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
// Get the server response handshake frames.
ByteBuf serverFrame = channel.alloc().buffer();
serverHandshaker.getBytesToSendToPeer(serverFrame);
channel.writeInbound(serverFrame);
// Capture the next client frame and add to the server.
assertEquals(1, channel.outboundMessages().size());
clientFrame = channel.<ByteBuf>readOutbound();
assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
// Get the server response handshake frames.
serverFrame = channel.alloc().buffer();
serverHandshaker.getBytesToSendToPeer(serverFrame);
channel.writeInbound(serverFrame);
// Ensure that both sides have confirmed that the handshake has completed.
assertFalse(serverHandshaker.isInProgress());
if (caughtException != null) {
throw new RuntimeException(caughtException);
}
assertNotNull(tsiEvent);
}
private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
// Netty Boilerplate. We don't really need any of this, but there is a tight coupling
// between a Http2ConnectionHandler and its dependencies.
Http2Connection connection = new DefaultHttp2Connection(true);
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
Http2FrameReader frameReader = new DefaultHttp2FrameReader(false);
DefaultHttp2ConnectionEncoder encoder =
new DefaultHttp2ConnectionEncoder(connection, frameWriter);
DefaultHttp2ConnectionDecoder decoder =
new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader);
return new CapturingGrpcHttp2ConnectionHandler(decoder, encoder, new Http2Settings());
}
private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
private Attributes attrs;
private CapturingGrpcHttp2ConnectionHandler(
Http2ConnectionDecoder decoder,
Http2ConnectionEncoder encoder,
Http2Settings initialSettings) {
super(null, decoder, encoder, initialSettings);
}
@Override
public void handleProtocolNegotiationCompleted(Attributes attrs) {
// If we are added to the pipeline, we need to remove ourselves. The HTTP2 handler
channel.pipeline().remove(this);
this.attrs = attrs;
}
}
private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory {
private TsiHandshakerFactory delegate;
DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) {
this.delegate = delegate;
}
@Override
public TsiHandshaker newHandshaker() {
return delegate.newHandshaker();
}
}
private class DelegatingTsiHandshaker implements TsiHandshaker {
private final TsiHandshaker delegate;
DelegatingTsiHandshaker(TsiHandshaker delegate) {
this.delegate = delegate;
}
@Override
public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
delegate.getBytesToSendToPeer(bytes);
}
@Override
public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
return delegate.processBytesFromPeer(bytes);
}
@Override
public boolean isInProgress() {
return delegate.isInProgress();
}
@Override
public TsiPeer extractPeer() throws GeneralSecurityException {
return delegate.extractPeer();
}
@Override
public Object extractPeerObject() throws GeneralSecurityException {
return delegate.extractPeerObject();
}
@Override
public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
InterceptingProtector protector =
new InterceptingProtector(delegate.createFrameProtector(alloc));
protectors.add(protector);
return protector;
}
@Override
public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
InterceptingProtector protector =
new InterceptingProtector(delegate.createFrameProtector(maxFrameSize, alloc));
protectors.add(protector);
return protector;
}
}
private static class InterceptingProtector implements TsiFrameProtector {
private final TsiFrameProtector delegate;
final AtomicInteger flushes = new AtomicInteger();
InterceptingProtector(TsiFrameProtector delegate) {
this.delegate = delegate;
}
@Override
public void protectFlush(
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
throws GeneralSecurityException {
flushes.incrementAndGet();
delegate.protectFlush(unprotectedBufs, ctxWrite, alloc);
}
@Override
public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
throws GeneralSecurityException {
delegate.unprotect(in, out, alloc);
}
@Override
public void destroy() {
delegate.destroy();
}
}
}

View File

@ -0,0 +1,30 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public final class AltsServerBuilderTest {
@Test
public void buildsNettyServer() throws Exception {
AltsServerBuilder.forPort(1234).enableUntrustedAltsForTesting().build();
}
}

View File

@ -0,0 +1,137 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static org.junit.Assert.assertEquals;
import com.google.common.truth.Truth;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.UnpooledByteBufAllocator;
import java.nio.ByteBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class BufUnwrapperTest {
private final ByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
@Test
public void closeEmptiesBuffers() {
BufUnwrapper unwrapper = new BufUnwrapper();
ByteBuf buf = alloc.buffer();
try {
ByteBuffer[] readableBufs = unwrapper.readableNioBuffers(buf);
ByteBuffer[] writableBufs = unwrapper.writableNioBuffers(buf);
Truth.assertThat(readableBufs).hasLength(1);
Truth.assertThat(readableBufs[0]).isNotNull();
Truth.assertThat(writableBufs).hasLength(1);
Truth.assertThat(writableBufs[0]).isNotNull();
unwrapper.close();
Truth.assertThat(readableBufs[0]).isNull();
Truth.assertThat(writableBufs[0]).isNull();
} finally {
buf.release();
}
}
@Test
public void readableNioBuffers_worksWithNormal() {
ByteBuf buf = alloc.buffer(1).writeByte('a');
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
ByteBuffer[] internalBufs = unwrapper.readableNioBuffers(buf);
Truth.assertThat(internalBufs).hasLength(1);
assertEquals('a', internalBufs[0].get(0));
} finally {
buf.release();
}
}
@Test
public void readableNioBuffers_worksWithComposite() {
CompositeByteBuf buf = alloc.compositeBuffer();
buf.addComponent(true, alloc.buffer(1).writeByte('a'));
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
ByteBuffer[] internalBufs = unwrapper.readableNioBuffers(buf);
Truth.assertThat(internalBufs).hasLength(1);
assertEquals('a', internalBufs[0].get(0));
} finally {
buf.release();
}
}
@Test
public void writableNioBuffers_indexesPreserved() {
ByteBuf buf = alloc.buffer(1);
int ridx = buf.readerIndex();
int widx = buf.writerIndex();
int cap = buf.capacity();
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
ByteBuffer[] internalBufs = unwrapper.writableNioBuffers(buf);
Truth.assertThat(internalBufs).hasLength(1);
internalBufs[0].put((byte) 'a');
assertEquals(ridx, buf.readerIndex());
assertEquals(widx, buf.writerIndex());
assertEquals(cap, buf.capacity());
} finally {
buf.release();
}
}
@Test
public void writableNioBuffers_worksWithNormal() {
ByteBuf buf = alloc.buffer(1);
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
ByteBuffer[] internalBufs = unwrapper.writableNioBuffers(buf);
Truth.assertThat(internalBufs).hasLength(1);
internalBufs[0].put((byte) 'a');
buf.writerIndex(1);
assertEquals('a', buf.readByte());
} finally {
buf.release();
}
}
@Test
public void writableNioBuffers_worksWithComposite() {
CompositeByteBuf buf = alloc.compositeBuffer();
buf.addComponent(alloc.buffer(1));
buf.capacity(1);
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
ByteBuffer[] internalBufs = unwrapper.writableNioBuffers(buf);
Truth.assertThat(internalBufs).hasLength(1);
internalBufs[0].put((byte) 'a');
buf.writerIndex(1);
assertEquals('a', buf.readByte());
} finally {
buf.release();
}
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import java.io.BufferedReader;
import java.io.StringReader;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public final class CheckGcpEnvironmentTest {
@Test
public void checkGcpLinuxPlatformData() throws Exception {
BufferedReader reader;
reader = new BufferedReader(new StringReader("HP Z440 Workstation"));
assertFalse(CheckGcpEnvironment.checkProductNameOnLinux(reader));
reader = new BufferedReader(new StringReader("Google"));
assertTrue(CheckGcpEnvironment.checkProductNameOnLinux(reader));
reader = new BufferedReader(new StringReader("Google Compute Engine"));
assertTrue(CheckGcpEnvironment.checkProductNameOnLinux(reader));
reader = new BufferedReader(new StringReader("Google Compute Engine "));
assertTrue(CheckGcpEnvironment.checkProductNameOnLinux(reader));
}
@Test
public void checkGcpWindowsPlatformData() throws Exception {
BufferedReader reader;
reader = new BufferedReader(new StringReader("Product : Google"));
assertFalse(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
reader = new BufferedReader(new StringReader("Manufacturer : LENOVO"));
assertFalse(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
reader = new BufferedReader(new StringReader("Manufacturer : Google Compute Engine"));
assertFalse(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
reader = new BufferedReader(new StringReader("Manufacturer : Google"));
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
reader = new BufferedReader(new StringReader("Manufacturer:Google"));
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
reader = new BufferedReader(new StringReader("Manufacturer : Google "));
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
reader = new BufferedReader(new StringReader("BIOSVersion : 1.0\nManufacturer : Google\n"));
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
}
}

View File

@ -0,0 +1,193 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import io.grpc.alts.transportsecurity.FakeTsiHandshaker;
import io.grpc.alts.transportsecurity.TsiHandshaker;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.util.ReferenceCounted;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@RunWith(JUnit4.class)
public class InternalNettyTsiHandshakerTest {
private final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
private final List<ReferenceCounted> references = new ArrayList<>();
private final InternalNettyTsiHandshaker clientHandshaker =
new InternalNettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerClient());
private final InternalNettyTsiHandshaker serverHandshaker =
new InternalNettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer());
@After
public void teardown() {
for (ReferenceCounted reference : references) {
reference.release(reference.refCnt());
}
}
@Test
public void failsOnNullHandshaker() {
try {
new InternalNettyTsiHandshaker(null);
fail("Exception expected");
} catch (NullPointerException ex) {
// Do nothing.
}
}
@Test
public void processPeerHandshakeShouldAcceptPartialFrames() throws GeneralSecurityException {
for (int i = 0; i < 1024; i++) {
ByteBuf clientData = ref(alloc.buffer(1));
clientHandshaker.getBytesToSendToPeer(clientData);
if (clientData.isReadable()) {
if (serverHandshaker.processBytesFromPeer(clientData)) {
// Done.
return;
}
}
}
fail("Failed to process the handshake frame.");
}
@Test
public void handshakeShouldSucceed() throws GeneralSecurityException {
doHandshake();
}
@Test
public void isInProgress() throws GeneralSecurityException {
assertTrue(clientHandshaker.isInProgress());
assertTrue(serverHandshaker.isInProgress());
doHandshake();
assertFalse(clientHandshaker.isInProgress());
assertFalse(serverHandshaker.isInProgress());
}
@Test
public void extractPeer_notNull() throws GeneralSecurityException {
doHandshake();
assertNotNull(serverHandshaker.extractPeer());
assertNotNull(clientHandshaker.extractPeer());
}
@Test
public void extractPeer_failsBeforeHandshake() throws GeneralSecurityException {
try {
clientHandshaker.extractPeer();
fail("Exception expected");
} catch (IllegalStateException ex) {
// Do nothing.
}
}
@Test
public void extractPeerObject_notNull() throws GeneralSecurityException {
doHandshake();
assertNotNull(serverHandshaker.extractPeerObject());
assertNotNull(clientHandshaker.extractPeerObject());
}
@Test
public void extractPeerObject_failsBeforeHandshake() throws GeneralSecurityException {
try {
clientHandshaker.extractPeerObject();
fail("Exception expected");
} catch (IllegalStateException ex) {
// Do nothing.
}
}
/**
* InternalNettyTsiHandshaker just converts {@link ByteBuffer} to {@link ByteBuf}, so check that
* the other methods are otherwise the same.
*/
@Test
public void handshakerMethodsMatch() {
List<String> expectedMethods = new ArrayList<>();
for (Method m : TsiHandshaker.class.getDeclaredMethods()) {
expectedMethods.add(m.getName());
}
List<String> actualMethods = new ArrayList<>();
for (Method m : InternalNettyTsiHandshaker.class.getDeclaredMethods()) {
actualMethods.add(m.getName());
}
assertThat(actualMethods).containsAllIn(expectedMethods);
}
static void doHandshake(
InternalNettyTsiHandshaker clientHandshaker,
InternalNettyTsiHandshaker serverHandshaker,
ByteBufAllocator alloc,
Function<ByteBuf, ByteBuf> ref)
throws GeneralSecurityException {
// Get the server response handshake frames.
for (int i = 0; i < 10; i++) {
if (!(clientHandshaker.isInProgress() || serverHandshaker.isInProgress())) {
return;
}
ByteBuf clientData = ref.apply(alloc.buffer());
clientHandshaker.getBytesToSendToPeer(clientData);
if (clientData.isReadable()) {
serverHandshaker.processBytesFromPeer(clientData);
}
ByteBuf serverData = ref.apply(alloc.buffer());
serverHandshaker.getBytesToSendToPeer(serverData);
if (serverData.isReadable()) {
clientHandshaker.processBytesFromPeer(serverData);
}
}
throw new AssertionError("Failed to complete the handshake.");
}
private void doHandshake() throws GeneralSecurityException {
doHandshake(clientHandshaker, serverHandshaker, alloc, buf -> ref(buf));
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
}

View File

@ -0,0 +1,248 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import io.grpc.alts.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions.Version;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link RpcProtocolVersionsUtil}. */
@RunWith(JUnit4.class)
public final class RpcProtocolVersionsUtilTest {
@Test
public void compareVersions() throws Exception {
assertTrue(
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
Version.newBuilder().setMajor(3).setMinor(2).build(),
Version.newBuilder().setMajor(2).setMinor(1).build()));
assertTrue(
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
Version.newBuilder().setMajor(3).setMinor(2).build(),
Version.newBuilder().setMajor(2).setMinor(1).build()));
assertTrue(
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
Version.newBuilder().setMajor(3).setMinor(2).build(),
Version.newBuilder().setMajor(3).setMinor(2).build()));
assertFalse(
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
Version.newBuilder().setMajor(2).setMinor(3).build(),
Version.newBuilder().setMajor(3).setMinor(2).build()));
assertFalse(
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
Version.newBuilder().setMajor(3).setMinor(1).build(),
Version.newBuilder().setMajor(3).setMinor(2).build()));
}
@Test
public void checkRpcVersions() throws Exception {
// local.max > peer.max and local.min > peer.min
RpcVersionsCheckResult checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// local.max > peer.max and local.min < peer.min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// local.max > peer.max and local.min = peer.min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// local.max < peer.max and local.min > peer.min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// local.max = peer.max and local.min > peer.min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// local.max < peer.max and local.min < peer.min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// local.max < peer.max and local.min = peer.min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// local.max = peer.max and local.min < peer.min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// all equal
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build());
assertTrue(checkResult.getResult());
assertEquals(
Version.newBuilder().setMajor(2).setMinor(1).build(),
checkResult.getHighestCommonVersion());
// max is smaller than min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build());
assertFalse(checkResult.getResult());
assertEquals(null, checkResult.getHighestCommonVersion());
// no overlap, local > peer
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(6).setMinor(5).build())
.setMinRpcVersion(Version.newBuilder().setMajor(4).setMinor(3).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(0).setMinor(0).build())
.build());
assertFalse(checkResult.getResult());
assertEquals(null, checkResult.getHighestCommonVersion());
// no overlap, local < peer
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(0).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(6).setMinor(5).build())
.setMinRpcVersion(Version.newBuilder().setMajor(4).setMinor(3).build())
.build());
assertFalse(checkResult.getResult());
assertEquals(null, checkResult.getHighestCommonVersion());
// no overlap, max < min
checkResult =
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(4).setMinor(3).build())
.setMinRpcVersion(Version.newBuilder().setMajor(6).setMinor(5).build())
.build(),
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder().setMajor(1).setMinor(0).build())
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
.build());
assertFalse(checkResult.getResult());
assertEquals(null, checkResult.getHighestCommonVersion());
}
}

View File

@ -0,0 +1,494 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertWithMessage;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import javax.xml.bind.DatatypeConverter;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AesGcmHkdfAeadCrypter}. */
@RunWith(JUnit4.class)
public final class AesGcmHkdfAeadCrypterTest {
private static class TestVector {
final String comment;
final byte[] key;
final byte[] nonce;
final byte[] aad;
final byte[] plaintext;
final byte[] ciphertext;
TestVector(TestVectorBuilder builder) {
comment = builder.comment;
key = builder.key;
nonce = builder.nonce;
aad = builder.aad;
plaintext = builder.plaintext;
ciphertext = builder.ciphertext;
}
static TestVectorBuilder builder() {
return new TestVectorBuilder();
}
}
private static class TestVectorBuilder {
String comment;
byte[] key;
byte[] nonce;
byte[] aad;
byte[] plaintext;
byte[] ciphertext;
TestVector build() {
if (comment == null
&& key == null
&& nonce == null
&& aad == null
&& plaintext == null
&& ciphertext == null) {
throw new IllegalStateException("All fields must be set before calling build().");
}
return new TestVector(this);
}
TestVectorBuilder withComment(String comment) {
this.comment = comment;
return this;
}
TestVectorBuilder withKey(String key) {
this.key = DatatypeConverter.parseHexBinary(key);
return this;
}
TestVectorBuilder withNonce(String nonce) {
this.nonce = DatatypeConverter.parseHexBinary(nonce);
return this;
}
TestVectorBuilder withAad(String aad) {
this.aad = DatatypeConverter.parseHexBinary(aad);
return this;
}
TestVectorBuilder withPlaintext(String plaintext) {
this.plaintext = DatatypeConverter.parseHexBinary(plaintext);
return this;
}
TestVectorBuilder withCiphertext(String ciphertext) {
this.ciphertext = DatatypeConverter.parseHexBinary(ciphertext);
return this;
}
}
@Test
public void testVectorEncrypt() throws GeneralSecurityException {
int i = 0;
for (TestVector testVector : testVectors) {
int bufferSize = testVector.ciphertext.length;
byte[] ciphertext = new byte[bufferSize];
ByteBuffer ciphertextBuffer = ByteBuffer.wrap(ciphertext);
AesGcmHkdfAeadCrypter aeadCrypter = new AesGcmHkdfAeadCrypter(testVector.key);
aeadCrypter.encrypt(
ciphertextBuffer,
ByteBuffer.wrap(testVector.plaintext),
ByteBuffer.wrap(testVector.aad),
testVector.nonce);
String msg = "Failure for test vector " + i;
assertWithMessage(msg)
.that(ciphertextBuffer.remaining())
.isEqualTo(bufferSize - testVector.ciphertext.length);
byte[] exactCiphertext = Arrays.copyOf(ciphertext, testVector.ciphertext.length);
assertWithMessage(msg).that(exactCiphertext).isEqualTo(testVector.ciphertext);
i++;
}
}
@Test
public void testVectorDecrypt() throws GeneralSecurityException {
int i = 0;
for (TestVector testVector : testVectors) {
// The plaintext buffer might require space for the tag to decrypt (e.g., for conscrypt).
int bufferSize = testVector.ciphertext.length;
byte[] plaintext = new byte[bufferSize];
ByteBuffer plaintextBuffer = ByteBuffer.wrap(plaintext);
AesGcmHkdfAeadCrypter aeadCrypter = new AesGcmHkdfAeadCrypter(testVector.key);
aeadCrypter.decrypt(
plaintextBuffer,
ByteBuffer.wrap(testVector.ciphertext),
ByteBuffer.wrap(testVector.aad),
testVector.nonce);
String msg = "Failure for test vector " + i;
assertWithMessage(msg)
.that(plaintextBuffer.remaining())
.isEqualTo(bufferSize - testVector.plaintext.length);
byte[] exactPlaintext = Arrays.copyOf(plaintext, testVector.plaintext.length);
assertWithMessage(msg).that(exactPlaintext).isEqualTo(testVector.plaintext);
i++;
}
}
/*
* NIST vectors from:
* http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
*
* IEEE vectors from:
* http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf
* Key expanded by setting
* expandedKey = (key||(key ^ {0x01, .., 0x01})||key ^ {0x02,..,0x02}))[0:44].
*/
private static final TestVector[] testVectors =
new TestVector[] {
TestVector.builder()
.withComment("Derived from NIST test vector 1")
.withKey(
"0000000000000000000000000000000001010101010101010101010101010101020202020202020202"
+ "020202")
.withNonce("000000000000000000000000")
.withAad("")
.withPlaintext("")
.withCiphertext("85e873e002f6ebdc4060954eb8675508")
.build(),
TestVector.builder()
.withComment("Derived from NIST test vector 2")
.withKey(
"0000000000000000000000000000000001010101010101010101010101010101020202020202020202"
+ "020202")
.withNonce("000000000000000000000000")
.withAad("")
.withPlaintext("00000000000000000000000000000000")
.withCiphertext("51e9a8cb23ca2512c8256afff8e72d681aca19a1148ac115e83df4888cc00d11")
.build(),
TestVector.builder()
.withComment("Derived from NIST test vector 3")
.withKey(
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
+ "688d96")
.withNonce("cafebabefacedbaddecaf888")
.withAad("")
.withPlaintext(
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
+ "cf0e2449a6b525b16aedf5aa0de657ba637b391aafd255")
.withCiphertext(
"1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4a"
+ "c8cf09afb1663daa7b4017e6fc2c177c0c087c0df1162129952213cee1bc6e9c8495dd705e1f"
+ "3d")
.build(),
TestVector.builder()
.withComment("Derived from NIST test vector 4")
.withKey(
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
+ "688d96")
.withNonce("cafebabefacedbaddecaf888")
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
.withPlaintext(
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
.withCiphertext(
"1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4a"
+ "c8cf09afb1663daa7b4017e6fc2c177c0c087c4764565d077e9124001ddb27fc0848c5")
.build(),
TestVector.builder()
.withComment(
"Derived from adapted NIST test vector 4"
+ " for KDF counter boundary (flip nonce bit 15)")
.withKey(
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
+ "688d96")
.withNonce("ca7ebabefacedbaddecaf888")
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
.withPlaintext(
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
.withCiphertext(
"e650d3c0fb879327f2d03287fa93cd07342b136215adbca00c3bd5099ec41832b1d18e0423ed26bb12"
+ "c6cd09debb29230a94c0cee15903656f85edb6fc509b1b28216382172ecbcc31e1e9b1")
.build(),
TestVector.builder()
.withComment(
"Derived from adapted NIST test vector 4"
+ " for KDF counter boundary (flip nonce bit 16)")
.withKey(
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
+ "688d96")
.withNonce("cafebbbefacedbaddecaf888")
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
.withPlaintext(
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
.withCiphertext(
"c0121e6c954d0767f96630c33450999791b2da2ad05c4190169ccad9ac86ff1c721e3d82f2ad22ab46"
+ "3bab4a0754b7dd68ca4de7ea2531b625eda01f89312b2ab957d5c7f8568dd95fcdcd1f")
.build(),
TestVector.builder()
.withComment(
"Derived from adapted NIST test vector 4"
+ " for KDF counter boundary (flip nonce bit 63)")
.withKey(
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
+ "688d96")
.withNonce("cafebabefacedb2ddecaf888")
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
.withPlaintext(
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
.withCiphertext(
"8af37ea5684a4d81d4fd817261fd9743099e7e6a025eaacf8e54b124fb5743149e05cb89f4a49467fe"
+ "2e5e5965f29a19f99416b0016b54585d12553783ba59e9f782e82e097c336bf7989f08")
.build(),
TestVector.builder()
.withComment(
"Derived from adapted NIST test vector 4"
+ " for KDF counter boundary (flip nonce bit 64)")
.withKey(
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
+ "688d96")
.withNonce("cafebabefacedbaddfcaf888")
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
.withPlaintext(
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
.withCiphertext(
"fbd528448d0346bfa878634864d407a35a039de9db2f1feb8e965b3ae9356ce6289441d77f8f0df294"
+ "891f37ea438b223e3bf2bdc53d4c5a74fb680bb312a8dec6f7252cbcd7f5799750ad78")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.1.1 54-byte auth")
.withKey(
"ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d"
+ "600dde")
.withNonce("12153524c0895e81b2c28465")
.withAad(
"d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f10111213141516171819"
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001")
.withPlaintext("")
.withCiphertext("3ea0b584f3c85e93f9320ea591699efb")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.1.2 54-byte auth")
.withKey(
"e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97"
+ "a50755")
.withNonce("12153524c0895e81b2c28465")
.withAad(
"d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f10111213141516171819"
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001")
.withPlaintext("")
.withCiphertext("294e028bf1fe6f14c4e8f7305c933eb5")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.2.1 60-byte crypt")
.withKey(
"ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d"
+ "600dde")
.withNonce("12153524c0895e81b2c28465")
.withAad("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
+ "363738393a0002")
.withCiphertext(
"db3d25719c6b0a3ca6145c159d5c6ed9aff9c6e0b79f17019ea923b8665ddf52137ad611f0d1bf417a"
+ "7ca85e45afe106ff9c7569d335d086ae6c03f00987ccd6")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.2.2 60-byte crypt")
.withKey(
"e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97"
+ "a50755")
.withNonce("12153524c0895e81b2c28465")
.withAad("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
+ "363738393a0002")
.withCiphertext(
"1641f28ec13afcc8f7903389787201051644914933e9202bb9d06aa020c2a67ef51dfe7bc00a856c55"
+ "b8f8133e77f659132502bad63f5713d57d0c11e0f871ed")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.3.1 60-byte auth")
.withKey(
"071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fcce"
+ "cd3f07")
.withNonce("f0761e8dcd3d000176d457ed")
.withAad(
"e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f2021"
+ "22232425262728292a2b2c2d2e2f303132333435363738393a0003")
.withPlaintext("")
.withCiphertext("58837a10562b0f1f8edbe58ca55811d3")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.3.2 60-byte auth")
.withKey(
"691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365"
+ "ff1ea2")
.withNonce("f0761e8dcd3d000176d457ed")
.withAad(
"e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f2021"
+ "22232425262728292a2b2c2d2e2f303132333435363738393a0003")
.withPlaintext("")
.withCiphertext("c2722ff6ca29a257718a529d1f0c6a3b")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.4.1 54-byte crypt")
.withKey(
"071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fcce"
+ "cd3f07")
.withNonce("f0761e8dcd3d000176d457ed")
.withAad("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333400"
+ "04")
.withCiphertext(
"fd96b715b93a13346af51e8acdf792cdc7b2686f8574c70e6b0cbf16291ded427ad73fec48cd298e05"
+ "28a1f4c644a949fc31dc9279706ddba33f")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.4.2 54-byte crypt")
.withKey(
"691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365"
+ "ff1ea2")
.withNonce("f0761e8dcd3d000176d457ed")
.withAad("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333400"
+ "04")
.withCiphertext(
"b68f6300c2e9ae833bdc070e24021a3477118e78ccf84e11a485d861476c300f175353d5cdf92008a4"
+ "f878e6cc3577768085c50a0e98fda6cbb8")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.5.1 65-byte auth")
.withKey(
"013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d84"
+ "6f0eb9")
.withNonce("7cfde9f9e33724c68932d612")
.withAad(
"84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f10111213141516171819"
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
+ "0005")
.withPlaintext("")
.withCiphertext("cca20eecda6283f09bb3543dd99edb9b")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.5.2 65-byte auth")
.withKey(
"83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2"
+ "d89068")
.withNonce("7cfde9f9e33724c68932d612")
.withAad(
"84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f10111213141516171819"
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
+ "0005")
.withPlaintext("")
.withCiphertext("b232cc1da5117bf15003734fa599d271")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.6.1 61-byte crypt")
.withKey(
"013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d84"
+ "6f0eb9")
.withNonce("7cfde9f9e33724c68932d612")
.withAad("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
+ "363738393a3b0006")
.withCiphertext(
"ff1910d35ad7e5657890c7c560146fd038707f204b66edbc3d161f8ace244b985921023c436e3a1c35"
+ "32ecd5d09a056d70be583f0d10829d9387d07d33d872e490")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.6.2 61-byte crypt")
.withKey(
"83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2"
+ "d89068")
.withNonce("7cfde9f9e33724c68932d612")
.withAad("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
+ "363738393a3b0006")
.withCiphertext(
"0db4cf956b5f97eca4eab82a6955307f9ae02a32dd7d93f83d66ad04e1cfdc5182ad12abdea5bbb619"
+ "a1bd5fb9a573590fba908e9c7a46c1f7ba0905d1b55ffda4")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.7.1 79-byte crypt")
.withKey(
"88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f4"
+ "7058ab")
.withNonce("7ae8e2ca4ec500012e58495c")
.withAad(
"68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f2021"
+ "22232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647"
+ "48494a4b4c4d0007")
.withPlaintext("")
.withCiphertext("813f0e630f96fb2d030f58d83f5cdfd0")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.7.2 79-byte crypt")
.withKey(
"4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476"
+ "fab7ba")
.withNonce("7ae8e2ca4ec500012e58495c")
.withAad(
"68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f2021"
+ "22232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647"
+ "48494a4b4c4d0007")
.withPlaintext("")
.withCiphertext("77e5a44c21eb07188aacbd74d1980e97")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.8.1 61-byte crypt")
.withKey(
"88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f4"
+ "7058ab")
.withNonce("7ae8e2ca4ec500012e58495c")
.withAad("68f2e77696ce7ae8e2ca4ec588e54d002e58495c")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
+ "363738393a3b3c3d3e3f404142434445464748490008")
.withCiphertext(
"958ec3f6d60afeda99efd888f175e5fcd4c87b9bcc5c2f5426253a8b506296c8c43309ab2adb593946"
+ "2541d95e80811e04e706b1498f2c407c7fb234f8cc01a647550ee6b557b35a7e3945381821"
+ "f4")
.build(),
TestVector.builder()
.withComment("Derived from IEEE 2.8.2 61-byte crypt")
.withKey(
"4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476"
+ "fab7ba")
.withNonce("7ae8e2ca4ec500012e58495c")
.withAad("68f2e77696ce7ae8e2ca4ec588e54d002e58495c")
.withPlaintext(
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
+ "363738393a3b3c3d3e3f404142434445464748490008")
.withCiphertext(
"b44d072011cd36d272a9b7a98db9aa90cbc5c67b93ddce67c854503214e2e896ec7e9db649ed4bcf6f"
+ "850aac0223d0cf92c83db80795c3a17ecc1248bb00591712b1ae71e268164196252162810b"
+ "00")
.build()
};
}

View File

@ -0,0 +1,81 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static org.junit.Assert.assertEquals;
import io.grpc.alts.Handshaker.HandshakerResult;
import io.grpc.alts.Handshaker.Identity;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import io.grpc.alts.TransportSecurityCommon.SecurityLevel;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsAuthContext}. */
@RunWith(JUnit4.class)
public final class AltsAuthContextTest {
private static final int TEST_MAX_RPC_VERSION_MAJOR = 3;
private static final int TEST_MAX_RPC_VERSION_MINOR = 5;
private static final int TEST_MIN_RPC_VERSION_MAJOR = 2;
private static final int TEST_MIN_RPC_VERSION_MINOR = 1;
private static final SecurityLevel TEST_SECURITY_LEVEL = SecurityLevel.INTEGRITY_AND_PRIVACY;
private static final String TEST_APPLICATION_PROTOCOL = "grpc";
private static final String TEST_LOCAL_SERVICE_ACCOUNT = "local@gserviceaccount.com";
private static final String TEST_PEER_SERVICE_ACCOUNT = "peer@gserviceaccount.com";
private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
private HandshakerResult handshakerResult;
private RpcProtocolVersions rpcVersions;
@Before
public void setUp() {
rpcVersions =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(TEST_MAX_RPC_VERSION_MAJOR)
.setMinor(TEST_MAX_RPC_VERSION_MINOR)
.build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(TEST_MIN_RPC_VERSION_MAJOR)
.setMinor(TEST_MIN_RPC_VERSION_MINOR)
.build())
.build();
handshakerResult =
HandshakerResult.newBuilder()
.setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
.setRecordProtocol(TEST_RECORD_PROTOCOL)
.setPeerIdentity(Identity.newBuilder().setServiceAccount(TEST_PEER_SERVICE_ACCOUNT))
.setLocalIdentity(Identity.newBuilder().setServiceAccount(TEST_LOCAL_SERVICE_ACCOUNT))
.setPeerRpcVersions(rpcVersions)
.build();
}
@Test
public void testAltsAuthContext() {
AltsAuthContext authContext = new AltsAuthContext(handshakerResult);
assertEquals(TEST_APPLICATION_PROTOCOL, authContext.getApplicationProtocol());
assertEquals(TEST_RECORD_PROTOCOL, authContext.getRecordProtocol());
assertEquals(TEST_SECURITY_LEVEL, authContext.getSecurityLevel());
assertEquals(TEST_PEER_SERVICE_ACCOUNT, authContext.getPeerServiceAccount());
assertEquals(TEST_LOCAL_SERVICE_ACCOUNT, authContext.getLocalServiceAccount());
assertEquals(rpcVersions, authContext.getPeerRpcVersions());
}
}

View File

@ -0,0 +1,150 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.alts.transportsecurity.AltsChannelCrypter.incrementCounter;
import static org.junit.Assert.fail;
import com.google.common.testing.GcFinalization;
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsChannelCrypter}. */
@RunWith(JUnit4.class)
public final class AltsChannelCrypterTest extends ChannelCrypterNettyTestBase {
@Before
public void setUp() throws GeneralSecurityException {
ResourceLeakDetector.setLevel(Level.PARANOID);
client = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true);
server = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false);
}
@After
public void tearDown() throws GeneralSecurityException {
for (ReferenceCounted reference : references) {
reference.release();
}
references.clear();
client.destroy();
server.destroy();
// Increase our chances to detect ByteBuf leaks.
GcFinalization.awaitFullGc();
}
@Test
public void encryptDecryptKdfCounterIncr() throws GeneralSecurityException {
AltsChannelCrypter client =
new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true);
AltsChannelCrypter server =
new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false);
String message = "Hello world";
FrameEncrypt frameEncrypt1 = createFrameEncrypt(message);
client.encrypt(frameEncrypt1.out, frameEncrypt1.plain);
FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt1);
server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext);
assertThat(frameEncrypt1.plain.get(0).slice(0, frameDecrypt1.out.readableBytes()))
.isEqualTo(frameDecrypt1.out);
// Increase counters to get a new KDF counter value (first two bytes are skipped).
client.incrementOutCounterForTesting(1 << 17);
server.incrementInCounterForTesting(1 << 17);
FrameEncrypt frameEncrypt2 = createFrameEncrypt(message);
client.encrypt(frameEncrypt2.out, frameEncrypt2.plain);
FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt2);
server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext);
assertThat(frameEncrypt2.plain.get(0).slice(0, frameDecrypt2.out.readableBytes()))
.isEqualTo(frameDecrypt2.out);
}
@Test
public void overflowsClient() throws GeneralSecurityException {
byte[] maxFirst =
new byte[] {
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
(byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00
};
byte[] maxFirstPred = Arrays.copyOf(maxFirst, maxFirst.length);
maxFirstPred[0]--;
byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()];
byte[] counter = Arrays.copyOf(maxFirstPred, maxFirstPred.length);
incrementCounter(counter, oldCounter);
assertThat(oldCounter).isEqualTo(maxFirstPred);
assertThat(counter).isEqualTo(maxFirst);
try {
incrementCounter(counter, oldCounter);
fail("Exception expected");
} catch (GeneralSecurityException ex) {
assertThat(ex).hasMessageThat().contains("Counter has overflowed");
}
assertThat(oldCounter).isEqualTo(maxFirst);
assertThat(counter).isEqualTo(maxFirst);
}
@Test
public void overflowsServer() throws GeneralSecurityException {
byte[] maxSecond =
new byte[] {
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
(byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x80
};
byte[] maxSecondPred = Arrays.copyOf(maxSecond, maxSecond.length);
maxSecondPred[0]--;
byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()];
byte[] counter = Arrays.copyOf(maxSecondPred, maxSecondPred.length);
incrementCounter(counter, oldCounter);
assertThat(oldCounter).isEqualTo(maxSecondPred);
assertThat(counter).isEqualTo(maxSecond);
try {
incrementCounter(counter, oldCounter);
fail("Exception expected");
} catch (GeneralSecurityException ex) {
assertThat(ex).hasMessageThat().contains("Counter has overflowed");
}
assertThat(oldCounter).isEqualTo(maxSecond);
assertThat(counter).isEqualTo(maxSecond);
}
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsClientOptions}. */
@RunWith(JUnit4.class)
public final class AltsClientOptionsTest {
@Test
public void setAndGet() throws Exception {
String targetName = "foo";
String serviceAccount1 = "bar1";
String serviceAccount2 = "bar2";
RpcProtocolVersions rpcVersions =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.build();
AltsClientOptions options =
new AltsClientOptions.Builder()
.setTargetName(targetName)
.addTargetServiceAccount(serviceAccount1)
.addTargetServiceAccount(serviceAccount2)
.setRpcProtocolVersions(rpcVersions)
.build();
assertThat(options.getTargetName()).isEqualTo(targetName);
assertThat(options.getTargetServiceAccounts()).containsAllOf(serviceAccount1, serviceAccount2);
assertThat(options.getRpcProtocolVersions()).isEqualTo(rpcVersions);
}
}

View File

@ -0,0 +1,126 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.GeneralSecurityException;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsFraming}. */
@RunWith(JUnit4.class)
public class AltsFramingTest {
@Test
public void parserFrameLengthNegativeFails() throws GeneralSecurityException {
AltsFraming.Parser parser = new AltsFraming.Parser();
// frame length + one remaining byte (required)
ByteBuffer buffer = ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + 1);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(-1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
try {
parser.readBytes(buffer);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid frame length");
}
}
@Test
public void parserFrameLengthSmallerMessageTypeFails() throws GeneralSecurityException {
AltsFraming.Parser parser = new AltsFraming.Parser();
// frame length + one remaining byte (required)
ByteBuffer buffer = ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + 1);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(AltsFraming.getFrameMessageTypeHeaderSize() - 1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
try {
parser.readBytes(buffer);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid frame length");
}
}
@Test
public void parserFrameLengthTooLargeFails() throws GeneralSecurityException {
AltsFraming.Parser parser = new AltsFraming.Parser();
// frame length + one remaining byte (required)
ByteBuffer buffer = ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + 1);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(AltsFraming.getMaxDataLength() + 1); // write invalid length
buffer.put((byte) 0); // write some byte
buffer.flip();
try {
parser.readBytes(buffer);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid frame length");
}
}
@Test
public void parserFrameLengthMaxOk() throws GeneralSecurityException {
AltsFraming.Parser parser = new AltsFraming.Parser();
// length of type header + data
int dataLength = AltsFraming.getMaxDataLength();
// complete frame + 1 byte
ByteBuffer buffer =
ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + dataLength + 1);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(dataLength); // write invalid length
buffer.putInt(6); // default message type
buffer.put(new byte[dataLength - AltsFraming.getFrameMessageTypeHeaderSize()]); // write data
buffer.put((byte) 0);
buffer.flip();
parser.readBytes(buffer);
assertThat(parser.isComplete()).isTrue();
assertThat(buffer.remaining()).isEqualTo(1);
}
@Test
public void parserFrameLengthZeroOk() throws GeneralSecurityException {
AltsFraming.Parser parser = new AltsFraming.Parser();
int dataLength = AltsFraming.getFrameMessageTypeHeaderSize();
// complete frame + 1 byte
ByteBuffer buffer =
ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + dataLength + 1);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(dataLength); // write invalid length
buffer.putInt(6); // default message type
buffer.put((byte) 0);
buffer.flip();
parser.readBytes(buffer);
assertThat(parser.isComplete()).isTrue();
assertThat(buffer.remaining()).isEqualTo(1);
}
}

View File

@ -0,0 +1,263 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.alts.Handshaker.HandshakeProtocol;
import io.grpc.alts.Handshaker.HandshakerReq;
import io.grpc.alts.Handshaker.Identity;
import io.grpc.alts.Handshaker.StartClientHandshakeReq;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
/** Unit tests for {@link AltsHandshakerClient}. */
@RunWith(JUnit4.class)
public class AltsHandshakerClientTest {
private static final int IN_BYTES_SIZE = 100;
private static final int BYTES_CONSUMED = 30;
private static final int PREFIX_POSITION = 20;
private static final String TEST_TARGET_NAME = "target name";
private static final String TEST_TARGET_SERVICE_ACCOUNT = "peer service account";
private AltsHandshakerStub mockStub;
private AltsHandshakerClient handshaker;
private AltsClientOptions clientOptions;
@Before
public void setUp() {
mockStub = mock(AltsHandshakerStub.class);
clientOptions =
new AltsClientOptions.Builder()
.setTargetName(TEST_TARGET_NAME)
.addTargetServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)
.build();
handshaker = new AltsHandshakerClient(mockStub, clientOptions);
}
@Test
public void startClientHandshakeFailure() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getErrorResponse());
try {
handshaker.startClientHandshake();
fail("Exception expected");
} catch (GeneralSecurityException ex) {
assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
}
}
@Test
public void startClientHandshakeSuccess() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getOkResponse(0));
ByteBuffer outFrame = handshaker.startClientHandshake();
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
assertFalse(handshaker.isFinished());
assertNull(handshaker.getResult());
assertNull(handshaker.getKey());
}
@Test
public void startClientHandshakeWithOptions() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getOkResponse(0));
ByteBuffer outFrame = handshaker.startClientHandshake();
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
HandshakerReq req =
HandshakerReq.newBuilder()
.setClientStart(
StartClientHandshakeReq.newBuilder()
.setHandshakeSecurityProtocol(HandshakeProtocol.ALTS)
.addApplicationProtocols(AltsHandshakerClient.getApplicationProtocol())
.addRecordProtocols(AltsHandshakerClient.getRecordProtocol())
.setTargetName(TEST_TARGET_NAME)
.addTargetIdentities(
Identity.newBuilder().setServiceAccount(TEST_TARGET_SERVICE_ACCOUNT))
.build())
.build();
verify(mockStub).send(req);
}
@Test
public void startServerHandshakeFailure() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getErrorResponse());
try {
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
handshaker.startServerHandshake(inBytes);
fail("Exception expected");
} catch (GeneralSecurityException ex) {
assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
}
}
@Test
public void startServerHandshakeSuccess() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
assertFalse(handshaker.isFinished());
assertNull(handshaker.getResult());
assertNull(handshaker.getKey());
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
}
@Test
public void startServerHandshakeEmptyOutFrame() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED));
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
assertEquals(0, outFrame.remaining());
assertFalse(handshaker.isFinished());
assertNull(handshaker.getResult());
assertNull(handshaker.getKey());
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
}
@Test
public void startServerHandshakeWithPrefixBuffer() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
inBytes.position(PREFIX_POSITION);
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
assertFalse(handshaker.isFinished());
assertNull(handshaker.getResult());
assertNull(handshaker.getKey());
assertEquals(PREFIX_POSITION + BYTES_CONSUMED, inBytes.position());
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED - PREFIX_POSITION, inBytes.remaining());
}
@Test
public void nextFailure() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getErrorResponse());
try {
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
handshaker.next(inBytes);
fail("Exception expected");
} catch (GeneralSecurityException ex) {
assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
}
}
@Test
public void nextSuccess() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
ByteBuffer outFrame = handshaker.next(inBytes);
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
assertFalse(handshaker.isFinished());
assertNull(handshaker.getResult());
assertNull(handshaker.getKey());
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
}
@Test
public void nextEmptyOutFrame() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED));
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
ByteBuffer outFrame = handshaker.next(inBytes);
assertEquals(0, outFrame.remaining());
assertFalse(handshaker.isFinished());
assertNull(handshaker.getResult());
assertNull(handshaker.getKey());
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
}
@Test
public void nextFinished() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getFinishedResponse(BYTES_CONSUMED));
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
ByteBuffer outFrame = handshaker.next(inBytes);
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
assertTrue(handshaker.isFinished());
assertArrayEquals(handshaker.getKey(), MockAltsHandshakerResp.getTestKeyData());
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
}
@Test
public void setRpcVersions() throws Exception {
when(mockStub.send(Matchers.<HandshakerReq>any()))
.thenReturn(MockAltsHandshakerResp.getOkResponse(0));
RpcProtocolVersions rpcVersions =
RpcProtocolVersions.newBuilder()
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(3).setMinor(4).build())
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(5).setMinor(6).build())
.build();
clientOptions =
new AltsClientOptions.Builder()
.setTargetName(TEST_TARGET_NAME)
.addTargetServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)
.setRpcProtocolVersions(rpcVersions)
.build();
handshaker = new AltsHandshakerClient(mockStub, clientOptions);
handshaker.startClientHandshake();
ArgumentCaptor<HandshakerReq> reqCaptor = ArgumentCaptor.forClass(HandshakerReq.class);
verify(mockStub).send(reqCaptor.capture());
assertEquals(rpcVersions, reqCaptor.getValue().getClientStart().getRpcVersions());
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsHandshakerOptions}. */
@RunWith(JUnit4.class)
public final class AltsHandshakerOptionsTest {
@Test
public void setAndGet() throws Exception {
RpcProtocolVersions rpcVersions =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
.build();
AltsHandshakerOptions options = new AltsHandshakerOptions(rpcVersions);
assertThat(options.getRpcProtocolVersions()).isEqualTo(rpcVersions);
}
}

View File

@ -0,0 +1,199 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import com.google.protobuf.ByteString;
import io.grpc.alts.Handshaker.HandshakerReq;
import io.grpc.alts.Handshaker.HandshakerResp;
import io.grpc.alts.Handshaker.NextHandshakeMessageReq;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsHandshakerStub}. */
@RunWith(JUnit4.class)
public class AltsHandshakerStubTest {
/** Mock status of handshaker service. */
private static enum Status {
OK,
ERROR,
COMPLETE
}
private AltsHandshakerStub stub;
private MockWriter writer;
private ExecutorService executor;
@Before
public void setUp() {
executor = Executors.newSingleThreadExecutor();
writer = new MockWriter();
stub = new AltsHandshakerStub(writer);
writer.setReader(stub.getReaderForTest());
}
@After
public void tearDown() {
executor.shutdown();
}
/** Send a message as in_bytes and expect same message as out_frames echo back. */
private void sendSuccessfulMessage() throws Exception {
String message = "hello world";
HandshakerReq.Builder req =
HandshakerReq.newBuilder()
.setNext(
NextHandshakeMessageReq.newBuilder()
.setInBytes(ByteString.copyFromUtf8(message))
.build());
HandshakerResp resp = stub.send(req.build());
assertEquals(resp.getOutFrames().toStringUtf8(), message);
}
/** Send a message and expect an IOException on error. */
private void sendAndExpectError() throws InterruptedException {
try {
stub.send(HandshakerReq.newBuilder().build());
fail("Exception expected");
} catch (IOException ex) {
assertThat(ex).hasMessageThat().contains("Received a terminating error");
}
}
/** Send a message and expect an IOException on closing. */
private void sendAndExpectComplete() throws InterruptedException {
try {
stub.send(HandshakerReq.newBuilder().build());
fail("Exception expected");
} catch (IOException ex) {
assertThat(ex).hasMessageThat().contains("Response stream closed");
}
}
/** Send a message and expect an IOException on unexpected message. */
private void sendAndExpectUnexpectedMessage() throws InterruptedException {
try {
stub.send(HandshakerReq.newBuilder().build());
fail("Exception expected");
} catch (IOException ex) {
assertThat(ex).hasMessageThat().contains("Received an unexpected response");
}
}
@Test
public void sendSuccessfulMessageTest() throws Exception {
writer.setServiceStatus(Status.OK);
sendSuccessfulMessage();
stub.close();
}
@Test
public void getServiceErrorTest() throws InterruptedException {
writer.setServiceStatus(Status.ERROR);
sendAndExpectError();
stub.close();
}
@Test
public void getServiceCompleteTest() throws Exception {
writer.setServiceStatus(Status.COMPLETE);
sendAndExpectComplete();
stub.close();
}
@Test
public void getUnexpectedMessageTest() throws Exception {
writer.setServiceStatus(Status.OK);
writer.sendUnexpectedResponse();
sendAndExpectUnexpectedMessage();
stub.close();
}
@Test
public void closeEarlyTest() throws InterruptedException {
stub.close();
sendAndExpectComplete();
}
private class MockWriter implements StreamObserver<HandshakerReq> {
private StreamObserver<HandshakerResp> reader;
private Status status = Status.OK;
private void setReader(StreamObserver<HandshakerResp> reader) {
this.reader = reader;
}
private void setServiceStatus(Status status) {
this.status = status;
}
/** Send a handshaker response to reader. */
private void sendUnexpectedResponse() {
reader.onNext(HandshakerResp.newBuilder().build());
}
/** Mock writer onNext. Will respond based on the server status. */
@Override
public void onNext(final HandshakerReq req) {
executor.execute(
new Runnable() {
@Override
public void run() {
switch (status) {
case OK:
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
reader.onNext(resp.setOutFrames(req.getNext().getInBytes()).build());
break;
case ERROR:
reader.onError(new RuntimeException());
break;
case COMPLETE:
reader.onCompleted();
break;
default:
return;
}
}
});
}
@Override
public void onError(Throwable t) {}
/** Mock writer onComplete. */
@Override
public void onCompleted() {
executor.execute(
new Runnable() {
@Override
public void run() {
reader.onCompleted();
}
});
}
}
}

View File

@ -0,0 +1,480 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getDirectBuffer;
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom;
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.writeSlice;
import static org.junit.Assert.fail;
import com.google.common.testing.GcFinalization;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsTsiFrameProtector}. */
@RunWith(JUnit4.class)
public class AltsTsiFrameProtectorTest {
private static final int FRAME_MIN_SIZE =
AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes();
private final List<ReferenceCounted> references = new ArrayList<ReferenceCounted>();
@Before
public void setUp() {
ResourceLeakDetector.setLevel(Level.PARANOID);
}
@After
public void teardown() {
for (ReferenceCounted reference : references) {
reference.release();
}
references.clear();
// Increase our chances to detect ByteBuf leaks.
GcFinalization.awaitFullGc();
}
@Test
public void parserHeader_frameLengthNegativeFails() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), this::ref);
in.writeIntLE(-1);
in.writeIntLE(6);
try {
unprotector.unprotect(in, out, alloc);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
}
unprotector.destroy();
}
@Test
public void parserHeader_frameTooSmall() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
in.writeIntLE(FRAME_MIN_SIZE - 1);
in.writeIntLE(6);
try {
unprotector.unprotect(in, out, alloc);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
}
unprotector.destroy();
}
@Test
public void parserHeader_frameTooLarge() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
in.writeIntLE(
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
- AltsTsiFrameProtector.getHeaderLenFieldBytes()
+ 1);
in.writeIntLE(6);
try {
unprotector.unprotect(in, out, alloc);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too large");
}
unprotector.destroy();
}
@Test
public void parserHeader_frameTypeInvalid() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(5);
try {
unprotector.unprotect(in, out, alloc);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid header field: frame type");
}
unprotector.destroy();
}
@Test
public void parserHeader_frameZeroOk() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(6);
unprotector.unprotect(in, out, alloc);
assertThat(in.readableBytes()).isEqualTo(0);
unprotector.destroy();
}
@Test
public void parserHeader_EmptyUnprotectNoRetain() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf emptyBuf = getDirectBuffer(0, this::ref);
unprotector.unprotect(emptyBuf, out, alloc);
assertThat(emptyBuf.refCnt()).isEqualTo(1);
unprotector.destroy();
}
@Test
public void parserHeader_frameMaxOk() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
in.writeIntLE(
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
- AltsTsiFrameProtector.getHeaderLenFieldBytes());
in.writeIntLE(6);
unprotector.unprotect(in, out, alloc);
assertThat(in.readableBytes()).isEqualTo(0);
unprotector.destroy();
}
@Test
public void parserHeader_frameOkFragment() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
in.writeIntLE(FRAME_MIN_SIZE);
in.writeIntLE(6);
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
ByteBuf in2 = in.readSlice(1);
unprotector.unprotect(in1, out, alloc);
assertThat(in1.readableBytes()).isEqualTo(0);
unprotector.unprotect(in2, out, alloc);
assertThat(in2.readableBytes()).isEqualTo(0);
unprotector.destroy();
}
@Test
public void parseHeader_frameFailFragment() throws GeneralSecurityException {
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf in =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
in.writeIntLE(FRAME_MIN_SIZE - 1);
in.writeIntLE(6);
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
ByteBuf in2 = in.readSlice(1);
unprotector.unprotect(in1, out, alloc);
assertThat(in1.readableBytes()).isEqualTo(0);
try {
unprotector.unprotect(in2, out, alloc);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
}
assertThat(in2.readableBytes()).isEqualTo(0);
unprotector.destroy();
}
@Test
public void parseFrame_oneFrameNoFragment() throws GeneralSecurityException {
int payloadBytes = 1024;
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf outFrame =
getDirectBuffer(
AltsTsiFrameProtector.getHeaderBytes()
+ payloadBytes
+ FakeChannelCrypter.getTagBytes(),
this::ref);
outFrame.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
+ payloadBytes
+ FakeChannelCrypter.getTagBytes());
outFrame.writeIntLE(6);
List<ByteBuf> framePlain = Collections.singletonList(plain);
ByteBuf frameOut = writeSlice(outFrame, payloadBytes + FakeChannelCrypter.getTagBytes());
crypter.encrypt(frameOut, framePlain);
plain.readerIndex(0);
unprotector.unprotect(outFrame, out, alloc);
assertThat(outFrame.readableBytes()).isEqualTo(0);
assertThat(out.size()).isEqualTo(1);
ByteBuf out1 = ref((ByteBuf) out.get(0));
assertThat(out1).isEqualTo(plain);
unprotector.destroy();
}
@Test
public void parseFrame_twoFramesNoFragment() throws GeneralSecurityException {
int payloadBytes = 1536;
int payloadBytes1 = 1024;
int payloadBytes2 = payloadBytes - payloadBytes1;
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf outFrame =
getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes,
this::ref);
outFrame.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
+ payloadBytes1
+ FakeChannelCrypter.getTagBytes());
outFrame.writeIntLE(6);
List<ByteBuf> framePlain1 = Collections.singletonList(plain.readSlice(payloadBytes1));
ByteBuf frameOut1 = writeSlice(outFrame, payloadBytes1 + FakeChannelCrypter.getTagBytes());
outFrame.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
+ payloadBytes2
+ FakeChannelCrypter.getTagBytes());
outFrame.writeIntLE(6);
List<ByteBuf> framePlain2 = Collections.singletonList(plain);
ByteBuf frameOut2 = writeSlice(outFrame, payloadBytes2 + FakeChannelCrypter.getTagBytes());
crypter.encrypt(frameOut1, framePlain1);
crypter.encrypt(frameOut2, framePlain2);
plain.readerIndex(0);
unprotector.unprotect(outFrame, out, alloc);
assertThat(out.size()).isEqualTo(1);
ByteBuf out1 = ref((ByteBuf) out.get(0));
assertThat(out1).isEqualTo(plain);
assertThat(outFrame.refCnt()).isEqualTo(1);
assertThat(outFrame.readableBytes()).isEqualTo(0);
unprotector.destroy();
}
@Test
public void parseFrame_twoFramesNoFragment_Leftover() throws GeneralSecurityException {
int payloadBytes = 1536;
int payloadBytes1 = 1024;
int payloadBytes2 = payloadBytes - payloadBytes1;
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf protectedBuf =
getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes
+ AltsTsiFrameProtector.getHeaderBytes(),
this::ref);
protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
+ payloadBytes1
+ FakeChannelCrypter.getTagBytes());
protectedBuf.writeIntLE(6);
List<ByteBuf> framePlain1 = Collections.singletonList(plain.readSlice(payloadBytes1));
ByteBuf frameOut1 = writeSlice(protectedBuf, payloadBytes1 + FakeChannelCrypter.getTagBytes());
protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
+ payloadBytes2
+ FakeChannelCrypter.getTagBytes());
protectedBuf.writeIntLE(6);
List<ByteBuf> framePlain2 = Collections.singletonList(plain);
ByteBuf frameOut2 = writeSlice(protectedBuf, payloadBytes2 + FakeChannelCrypter.getTagBytes());
// This is an invalid header length field, make sure it triggers an error
// when the remainder of the header is given.
protectedBuf.writeIntLE((byte) -1);
crypter.encrypt(frameOut1, framePlain1);
crypter.encrypt(frameOut2, framePlain2);
plain.readerIndex(0);
unprotector.unprotect(protectedBuf, out, alloc);
assertThat(out.size()).isEqualTo(1);
ByteBuf out1 = ref((ByteBuf) out.get(0));
assertThat(out1).isEqualTo(plain);
// The protectedBuf is buffered inside the unprotector.
assertThat(protectedBuf.readableBytes()).isEqualTo(0);
assertThat(protectedBuf.refCnt()).isEqualTo(2);
protectedBuf.writeIntLE(6);
try {
unprotector.unprotect(protectedBuf, out, alloc);
fail("Exception expected");
} catch (IllegalArgumentException ex) {
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
}
unprotector.destroy();
// Make sure that unprotector does not hold onto buffered ByteBuf instance after destroy.
assertThat(protectedBuf.refCnt()).isEqualTo(1);
// Make sure that destroying twice does not throw.
unprotector.destroy();
}
@Test
public void parseFrame_twoFramesFragmentSecond() throws GeneralSecurityException {
int payloadBytes = 1536;
int payloadBytes1 = 1024;
int payloadBytes2 = payloadBytes - payloadBytes1;
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
List<Object> out = new ArrayList<>();
FakeChannelCrypter crypter = new FakeChannelCrypter();
AltsTsiFrameProtector.Unprotector unprotector =
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
ByteBuf plain = getRandom(payloadBytes, this::ref);
ByteBuf protectedBuf =
getDirectBuffer(
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
+ payloadBytes
+ AltsTsiFrameProtector.getHeaderBytes(),
this::ref);
protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
+ payloadBytes1
+ FakeChannelCrypter.getTagBytes());
protectedBuf.writeIntLE(6);
List<ByteBuf> framePlain1 = Collections.singletonList(plain.readSlice(payloadBytes1));
ByteBuf frameOut1 = writeSlice(protectedBuf, payloadBytes1 + FakeChannelCrypter.getTagBytes());
protectedBuf.writeIntLE(
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
+ payloadBytes2
+ FakeChannelCrypter.getTagBytes());
protectedBuf.writeIntLE(6);
List<ByteBuf> framePlain2 = Collections.singletonList(plain);
ByteBuf frameOut2 = writeSlice(protectedBuf, payloadBytes2 + FakeChannelCrypter.getTagBytes());
crypter.encrypt(frameOut1, framePlain1);
crypter.encrypt(frameOut2, framePlain2);
plain.readerIndex(0);
unprotector.unprotect(
protectedBuf.readSlice(
payloadBytes
+ AltsTsiFrameProtector.getHeaderBytes()
+ FakeChannelCrypter.getTagBytes()
+ AltsTsiFrameProtector.getHeaderBytes()),
out,
alloc);
assertThat(out.size()).isEqualTo(1);
ByteBuf out1 = ref((ByteBuf) out.get(0));
assertThat(out1).isEqualTo(plain.readSlice(payloadBytes1));
assertThat(protectedBuf.refCnt()).isEqualTo(2);
unprotector.unprotect(protectedBuf, out, alloc);
assertThat(out.size()).isEqualTo(2);
ByteBuf out2 = ref((ByteBuf) out.get(1));
assertThat(out2).isEqualTo(plain);
assertThat(protectedBuf.refCnt()).isEqualTo(1);
unprotector.destroy();
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
}

View File

@ -0,0 +1,269 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.alts.Handshaker.HandshakerResult;
import io.grpc.alts.Handshaker.Identity;
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
import java.nio.ByteBuffer;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Matchers;
/** Unit tests for {@link AltsTsiHandshaker}. */
@RunWith(JUnit4.class)
public class AltsTsiHandshakerTest {
private static final String TEST_KEY_DATA = "super secret 123";
private static final String TEST_APPLICATION_PROTOCOL = "grpc";
private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
private static final String TEST_CLIENT_SERVICE_ACCOUNT = "client@developer.gserviceaccount.com";
private static final String TEST_SERVER_SERVICE_ACCOUNT = "server@developer.gserviceaccount.com";
private static final int OUT_FRAME_SIZE = 100;
private static final int TRANSPORT_BUFFER_SIZE = 200;
private static final int TEST_MAX_RPC_VERSION_MAJOR = 3;
private static final int TEST_MAX_RPC_VERSION_MINOR = 2;
private static final int TEST_MIN_RPC_VERSION_MAJOR = 2;
private static final int TEST_MIN_RPC_VERSION_MINOR = 1;
private static final RpcProtocolVersions TEST_RPC_PROTOCOL_VERSIONS =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(TEST_MAX_RPC_VERSION_MAJOR)
.setMinor(TEST_MAX_RPC_VERSION_MINOR)
.build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(TEST_MIN_RPC_VERSION_MAJOR)
.setMinor(TEST_MIN_RPC_VERSION_MINOR)
.build())
.build();
private AltsHandshakerClient mockClient;
private AltsHandshakerClient mockServer;
private AltsTsiHandshaker handshakerClient;
private AltsTsiHandshaker handshakerServer;
@Before
public void setUp() throws Exception {
mockClient = mock(AltsHandshakerClient.class);
mockServer = mock(AltsHandshakerClient.class);
handshakerClient = new AltsTsiHandshaker(true, mockClient);
handshakerServer = new AltsTsiHandshaker(false, mockServer);
}
private HandshakerResult getHandshakerResult(boolean isClient) {
HandshakerResult.Builder builder =
HandshakerResult.newBuilder()
.setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
.setRecordProtocol(TEST_RECORD_PROTOCOL)
.setKeyData(ByteString.copyFromUtf8(TEST_KEY_DATA))
.setPeerRpcVersions(TEST_RPC_PROTOCOL_VERSIONS);
if (isClient) {
builder.setPeerIdentity(
Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
builder.setLocalIdentity(
Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
} else {
builder.setPeerIdentity(
Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
builder.setLocalIdentity(
Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
}
return builder.build();
}
@Test
public void processBytesFromPeerFalseStart() throws Exception {
verify(mockClient, never()).startClientHandshake();
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
verify(mockClient, never()).next(Matchers.<ByteBuffer>any());
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerStartServer() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and there is an output frame.
transportBuffer.position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(false);
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerStartServerEmptyOutput() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and output frame is empty.
// Expect processBytesFromPeer return False, because more data are needed from the peer.
transportBuffer.position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
when(mockServer.isFinished()).thenReturn(false);
assertFalse(handshakerServer.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerStartServerFinished() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
// Mock handshake complete after processBytesFromPeer.
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(true);
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerNoBytesConsumed() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
when(mockServer.isFinished()).thenReturn(false);
try {
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
fail("Expected IllegalStateException");
} catch (IllegalStateException expected) {
assertEquals("Handshaker did not consume any bytes.", expected.getMessage());
}
}
@Test
public void processBytesFromPeerClientNext() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
when(mockClient.isFinished()).thenReturn(false);
handshakerClient.getBytesToSendToPeer(transportBuffer);
transportBuffer.position(transportBuffer.limit());
assertFalse(handshakerClient.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerClientNextFinished() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
when(mockClient.isFinished()).thenReturn(true);
handshakerClient.getBytesToSendToPeer(transportBuffer);
assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
}
@Test
public void extractPeerFailure() throws Exception {
when(mockClient.isFinished()).thenReturn(false);
try {
handshakerClient.extractPeer();
fail("Expected IllegalStateException");
} catch (IllegalStateException expected) {
assertEquals("Handshake is not complete.", expected.getMessage());
}
}
@Test
public void extractPeerObjectFailure() throws Exception {
when(mockClient.isFinished()).thenReturn(false);
try {
handshakerClient.extractPeerObject();
fail("Expected IllegalStateException");
} catch (IllegalStateException expected) {
assertEquals("Handshake is not complete.", expected.getMessage());
}
}
@Test
public void extractClientPeerSuccess() throws Exception {
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
when(mockClient.isFinished()).thenReturn(true);
when(mockClient.getResult()).thenReturn(getHandshakerResult(/* isClient = */ true));
handshakerClient.getBytesToSendToPeer(transportBuffer);
TsiPeer clientPeer = handshakerClient.extractPeer();
assertEquals(1, clientPeer.getProperties().size());
assertEquals(
TEST_SERVER_SERVICE_ACCOUNT,
clientPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
AltsAuthContext clientContext = (AltsAuthContext) handshakerClient.extractPeerObject();
assertEquals(TEST_APPLICATION_PROTOCOL, clientContext.getApplicationProtocol());
assertEquals(TEST_RECORD_PROTOCOL, clientContext.getRecordProtocol());
assertEquals(TEST_SERVER_SERVICE_ACCOUNT, clientContext.getPeerServiceAccount());
assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, clientContext.getLocalServiceAccount());
assertEquals(TEST_RPC_PROTOCOL_VERSIONS, clientContext.getPeerRpcVersions());
}
@Test
public void extractServerPeerSuccess() throws Exception {
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
when(mockServer.startServerHandshake(Matchers.<ByteBuffer>any())).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(true);
when(mockServer.getResult()).thenReturn(getHandshakerResult(/* isClient = */ false));
handshakerServer.processBytesFromPeer(transportBuffer);
handshakerServer.getBytesToSendToPeer(transportBuffer);
TsiPeer serverPeer = handshakerServer.extractPeer();
assertEquals(1, serverPeer.getProperties().size());
assertEquals(
TEST_CLIENT_SERVICE_ACCOUNT,
serverPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
AltsAuthContext serverContext = (AltsAuthContext) handshakerServer.extractPeerObject();
assertEquals(TEST_APPLICATION_PROTOCOL, serverContext.getApplicationProtocol());
assertEquals(TEST_RECORD_PROTOCOL, serverContext.getRecordProtocol());
assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, serverContext.getPeerServiceAccount());
assertEquals(TEST_SERVER_SERVICE_ACCOUNT, serverContext.getLocalServiceAccount());
assertEquals(TEST_RPC_PROTOCOL_VERSIONS, serverContext.getPeerRpcVersions());
}
}

View File

@ -0,0 +1,194 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static org.junit.Assert.assertEquals;
import com.google.common.testing.GcFinalization;
import io.grpc.alts.Handshaker.HandshakeProtocol;
import io.grpc.alts.Handshaker.HandshakerReq;
import io.grpc.alts.Handshaker.HandshakerResp;
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsTsiHandshaker}. */
@RunWith(JUnit4.class)
public class AltsTsiTest {
private static final int OVERHEAD =
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
private final List<ReferenceCounted> references = new ArrayList<>();
private AltsHandshakerClient client;
private AltsHandshakerClient server;
@Before
public void setUp() throws Exception {
ResourceLeakDetector.setLevel(Level.PARANOID);
// Use MockAltsHandshakerStub for all the tests.
AltsHandshakerOptions handshakerOptions = new AltsHandshakerOptions(null);
MockAltsHandshakerStub clientStub = new MockAltsHandshakerStub();
MockAltsHandshakerStub serverStub = new MockAltsHandshakerStub();
client = new AltsHandshakerClient(clientStub, handshakerOptions);
server = new AltsHandshakerClient(serverStub, handshakerOptions);
}
@After
public void tearDown() {
for (ReferenceCounted reference : references) {
reference.release();
}
references.clear();
// Increase our chances to detect ByteBuf leaks.
GcFinalization.awaitFullGc();
}
private Handshakers newHandshakers() {
TsiHandshaker clientHandshaker = new AltsTsiHandshaker(true, client);
TsiHandshaker serverHandshaker = new AltsTsiHandshaker(false, server);
return new Handshakers(clientHandshaker, serverHandshaker);
}
@Test
public void verifyHandshakePeer() throws Exception {
Handshakers handshakers = newHandshakers();
TsiTest.performHandshake(TsiTest.getDefaultTransportBufferSize(), handshakers);
TsiPeer clientPeer = handshakers.getClient().extractPeer();
assertEquals(1, clientPeer.getProperties().size());
assertEquals(
MockAltsHandshakerResp.getTestPeerAccount(),
clientPeer.getProperty("service_account").getValue());
TsiPeer serverPeer = handshakers.getServer().extractPeer();
assertEquals(1, serverPeer.getProperties().size());
assertEquals(
MockAltsHandshakerResp.getTestPeerAccount(),
serverPeer.getProperty("service_account").getValue());
}
@Test
public void handshake() throws GeneralSecurityException {
TsiTest.handshakeTest(newHandshakers());
}
@Test
public void handshakeSmallBuffer() throws GeneralSecurityException {
TsiTest.handshakeSmallBufferTest(newHandshakers());
}
@Test
public void pingPong() throws GeneralSecurityException {
TsiTest.pingPongTest(newHandshakers(), this::ref);
}
@Test
public void pingPongExactFrameSize() throws GeneralSecurityException {
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
}
@Test
public void pingPongSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
}
@Test
public void pingPongSmallFrame() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
}
@Test
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
}
@Test
public void corruptedCounter() throws GeneralSecurityException {
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
}
@Test
public void corruptedCiphertext() throws GeneralSecurityException {
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
}
@Test
public void corruptedTag() throws GeneralSecurityException {
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
}
@Test
public void reflectedCiphertext() throws GeneralSecurityException {
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
}
private static class MockAltsHandshakerStub extends AltsHandshakerStub {
private boolean started = false;
@Override
public HandshakerResp send(HandshakerReq req) {
if (started) {
// Expect handshake next message.
if (req.getReqOneofCase().getNumber() != 3) {
return MockAltsHandshakerResp.getErrorResponse();
}
return MockAltsHandshakerResp.getFinishedResponse(req.getNext().getInBytes().size());
} else {
List<String> recordProtocols;
int bytesConsumed = 0;
switch (req.getReqOneofCase().getNumber()) {
case 1:
recordProtocols = req.getClientStart().getRecordProtocolsList();
break;
case 2:
recordProtocols =
req.getServerStart()
.getHandshakeParametersMap()
.get(HandshakeProtocol.ALTS.getNumber())
.getRecordProtocolsList();
bytesConsumed = req.getServerStart().getInBytes().size();
break;
default:
return MockAltsHandshakerResp.getErrorResponse();
}
if (recordProtocols.isEmpty()) {
return MockAltsHandshakerResp.getErrorResponse();
}
started = true;
return MockAltsHandshakerResp.getOkResponse(bytesConsumed);
}
}
@Override
public void close() {}
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
}

View File

@ -0,0 +1,71 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
public final class ByteBufTestUtils {
public interface RegisterRef {
ByteBuf register(ByteBuf buf);
}
private static final Random random = new SecureRandom();
// The {@code ref} argument can be used to register the buffer for {@code release}.
// TODO: allow the allocator to be passed in.
public static ByteBuf getDirectBuffer(int len, RegisterRef ref) {
return ref.register(Unpooled.directBuffer(len));
}
/** Get random bytes. */
public static ByteBuf getRandom(int len, RegisterRef ref) {
ByteBuf buf = getDirectBuffer(len, ref);
byte[] bytes = new byte[len];
random.nextBytes(bytes);
buf.writeBytes(bytes);
return buf;
}
/** Fragment byte buffer into multiple pieces. */
public static List<ByteBuf> fragmentByteBuf(ByteBuf in, int num, RegisterRef ref) {
ByteBuf buf = in.slice();
Preconditions.checkArgument(num > 0);
List<ByteBuf> fragmentedBufs = new ArrayList<>(num);
int fragmentSize = buf.readableBytes() / num;
while (buf.isReadable()) {
int readBytes = num == 0 ? buf.readableBytes() : fragmentSize;
ByteBuf tmpBuf = getDirectBuffer(readBytes, ref);
tmpBuf.writeBytes(buf, readBytes);
fragmentedBufs.add(tmpBuf);
num--;
}
return fragmentedBufs;
}
static ByteBuf writeSlice(ByteBuf in, int len) {
Preconditions.checkArgument(len <= in.writableBytes());
ByteBuf out = in.slice(in.writerIndex(), len);
in.writerIndex(in.writerIndex() + len);
return out.writerIndex(0);
}
}

View File

@ -0,0 +1,222 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getDirectBuffer;
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.fail;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCounted;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.crypto.AEADBadTagException;
import org.junit.Test;
/** Abstract class for unit tests of {@link ChannelCrypterNetty}. */
public abstract class ChannelCrypterNettyTestBase {
private static final String DECRYPTION_FAILURE_MESSAGE = "Tag mismatch";
protected final List<ReferenceCounted> references = new ArrayList<>();
public ChannelCrypterNetty client;
public ChannelCrypterNetty server;
static final class FrameEncrypt {
List<ByteBuf> plain;
ByteBuf out;
}
static final class FrameDecrypt {
List<ByteBuf> ciphertext;
ByteBuf out;
ByteBuf tag;
}
FrameEncrypt createFrameEncrypt(String message) {
byte[] messageBytes = message.getBytes(UTF_8);
FrameEncrypt frame = new FrameEncrypt();
ByteBuf plain = getDirectBuffer(messageBytes.length, this::ref);
plain.writeBytes(messageBytes);
frame.plain = Collections.singletonList(plain);
frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
return frame;
}
FrameDecrypt frameDecryptOfEncrypt(FrameEncrypt frameEncrypt) {
int tagLen = client.getSuffixLength();
FrameDecrypt frameDecrypt = new FrameDecrypt();
ByteBuf out = frameEncrypt.out;
frameDecrypt.ciphertext =
Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen));
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
return frameDecrypt;
}
@Test
public void encryptDecrypt() throws GeneralSecurityException {
String message = "Hello world";
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
.isEqualTo(frameDecrypt.out);
}
@Test
public void encryptDecryptLarge() throws GeneralSecurityException {
FrameEncrypt frameEncrypt = new FrameEncrypt();
ByteBuf plain = getRandom(17 * 1024, this::ref);
frameEncrypt.plain = Collections.singletonList(plain);
frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), this::ref);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
// Call decrypt overload that takes ciphertext and tag.
server.decrypt(frameDecrypt.out, frameEncrypt.out);
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
.isEqualTo(frameDecrypt.out);
}
@Test
public void encryptDecryptMultiple() throws GeneralSecurityException {
String message = "Hello world";
for (int i = 0; i < 512; ++i) {
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
.isEqualTo(frameDecrypt.out);
}
}
@Test
public void encryptDecryptComposite() throws GeneralSecurityException {
String message = "Hello world";
int lastLen = 2;
byte[] messageBytes = message.getBytes(UTF_8);
FrameEncrypt frameEncrypt = new FrameEncrypt();
ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, this::ref);
ByteBuf plain2 = getDirectBuffer(lastLen, this::ref);
plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen);
plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen);
ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2);
frameEncrypt.plain = Collections.singletonList(plain);
frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
int tagLen = client.getSuffixLength();
FrameDecrypt frameDecrypt = new FrameDecrypt();
ByteBuf out = frameEncrypt.out;
int outLen = out.readableBytes();
ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, this::ref);
ByteBuf cipher2 = getDirectBuffer(lastLen, this::ref);
cipher1.writeBytes(out, 0, outLen - lastLen - tagLen);
cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen);
ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2);
frameDecrypt.ciphertext = Collections.singletonList(cipher);
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
.isEqualTo(frameDecrypt.out);
}
@Test
public void reflection() throws GeneralSecurityException {
String message = "Hello world";
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
try {
client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
}
}
@Test
public void skipMessage() throws GeneralSecurityException {
String message = "Hello world";
FrameEncrypt frameEncrypt1 = createFrameEncrypt(message);
client.encrypt(frameEncrypt1.out, frameEncrypt1.plain);
FrameEncrypt frameEncrypt2 = createFrameEncrypt(message);
client.encrypt(frameEncrypt2.out, frameEncrypt2.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt2);
try {
client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
}
}
@Test
public void corruptMessage() throws GeneralSecurityException {
String message = "Hello world";
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
frameEncrypt.out.setByte(3, frameEncrypt.out.getByte(3) + 1);
try {
client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
}
}
@Test
public void replayMessage() throws GeneralSecurityException {
String message = "Hello world";
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt);
FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt);
server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext);
try {
server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
}
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.base.Preconditions.checkState;
import io.netty.buffer.ByteBuf;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.List;
import javax.crypto.AEADBadTagException;
public final class FakeChannelCrypter implements ChannelCrypterNetty {
private static final int TAG_BYTES = 16;
private static final byte TAG_BYTE = (byte) 0xa1;
private boolean destroyCalled = false;
public static int getTagBytes() {
return TAG_BYTES;
}
@Override
public void encrypt(ByteBuf out, List<ByteBuf> plain) throws GeneralSecurityException {
checkState(!destroyCalled);
for (ByteBuf buf : plain) {
out.writeBytes(buf);
for (int i = 0; i < TAG_BYTES; ++i) {
out.writeByte(TAG_BYTE);
}
}
}
@Override
public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertext)
throws GeneralSecurityException {
checkState(!destroyCalled);
for (ByteBuf buf : ciphertext) {
out.writeBytes(buf);
}
boolean tagValid = tag.forEachByte((byte value) -> value == TAG_BYTE) == -1;
if (!tagValid) {
throw new AEADBadTagException("Tag mismatch!");
}
}
@Override
public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
checkState(!destroyCalled);
ByteBuf ciphertext = ciphertextAndTag.readSlice(ciphertextAndTag.readableBytes() - TAG_BYTES);
decrypt(out, /*tag=*/ ciphertextAndTag, Collections.singletonList(ciphertext));
}
@Override
public int getSuffixLength() {
return TAG_BYTES;
}
@Override
public void destroy() {
destroyCalled = true;
}
}

View File

@ -0,0 +1,227 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* A fake handshaker compatible with security/transport_security/fake_transport_security.h See
* {@link TsiHandshaker} for documentation.
*/
public class FakeTsiHandshaker implements TsiHandshaker {
private static final Logger logger = Logger.getLogger(FakeTsiHandshaker.class.getName());
private static final TsiHandshakerFactory clientHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
return new FakeTsiHandshaker(true);
}
};
private static final TsiHandshakerFactory serverHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker() {
return new FakeTsiHandshaker(false);
}
};
private boolean isClient;
private ByteBuffer sendBuffer = null;
private AltsFraming.Parser frameParser = new AltsFraming.Parser();
private State sendState;
private State receiveState;
enum State {
CLIENT_NONE,
SERVER_NONE,
CLIENT_INIT,
SERVER_INIT,
CLIENT_FINISHED,
SERVER_FINISHED;
// Returns the next State. In order to advance to sendState=N, receiveState must be N-1.
public State next() {
if (ordinal() + 1 < values().length) {
return values()[ordinal() + 1];
}
throw new UnsupportedOperationException("Can't call next() on last element: " + this);
}
}
public static TsiHandshakerFactory clientHandshakerFactory() {
return clientHandshakerFactory;
}
public static TsiHandshakerFactory serverHandshakerFactory() {
return serverHandshakerFactory;
}
public static TsiHandshaker newFakeHandshakerClient() {
return clientHandshakerFactory.newHandshaker();
}
public static TsiHandshaker newFakeHandshakerServer() {
return serverHandshakerFactory.newHandshaker();
}
protected FakeTsiHandshaker(boolean isClient) {
this.isClient = isClient;
if (isClient) {
sendState = State.CLIENT_NONE;
receiveState = State.SERVER_NONE;
} else {
sendState = State.SERVER_NONE;
receiveState = State.CLIENT_NONE;
}
}
private State getNextState(State state) {
switch (state) {
case CLIENT_NONE:
return State.CLIENT_INIT;
case SERVER_NONE:
return State.SERVER_INIT;
case CLIENT_INIT:
return State.CLIENT_FINISHED;
case SERVER_INIT:
return State.SERVER_FINISHED;
default:
return null;
}
}
private String getNextMessage() {
State result = getNextState(sendState);
return result == null ? "BAD STATE" : result.toString();
}
private String getExpectedMessage() {
State result = getNextState(receiveState);
return result == null ? "BAD STATE" : result.toString();
}
private void incrementSendState() {
sendState = getNextState(sendState);
}
private void incrementReceiveState() {
receiveState = getNextState(receiveState);
}
@Override
public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
Preconditions.checkNotNull(bytes);
// If we're done, return nothing.
if (sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED) {
return;
}
// Prepare the next message, if neeeded.
if (sendBuffer == null) {
if (sendState.next() != receiveState) {
// We're still waiting for bytes from the peer, so bail.
return;
}
ByteBuffer payload = ByteBuffer.wrap(getNextMessage().getBytes(UTF_8));
sendBuffer = AltsFraming.toFrame(payload, payload.remaining());
logger.log(Level.FINE, "Buffered message: {0}", getNextMessage());
}
while (bytes.hasRemaining() && sendBuffer.hasRemaining()) {
bytes.put(sendBuffer.get());
}
if (!sendBuffer.hasRemaining()) {
// Get ready to send the next message.
sendBuffer = null;
incrementSendState();
}
}
@Override
public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
Preconditions.checkNotNull(bytes);
frameParser.readBytes(bytes);
if (frameParser.isComplete()) {
ByteBuffer messageBytes = frameParser.getRawFrame();
int offset = AltsFraming.getFramingOverhead();
int length = messageBytes.limit() - offset;
String message = new String(messageBytes.array(), offset, length, UTF_8);
logger.log(Level.FINE, "Read message: {0}", message);
if (!message.equals(getExpectedMessage())) {
throw new IllegalArgumentException(
"Bad handshake message. Got "
+ message
+ " (length = "
+ message.length()
+ ") expected "
+ getExpectedMessage()
+ " (length = "
+ getExpectedMessage().length()
+ ")");
}
incrementReceiveState();
return true;
}
return false;
}
@Override
public boolean isInProgress() {
boolean finishedReceiving =
receiveState == State.CLIENT_FINISHED || receiveState == State.SERVER_FINISHED;
boolean finishedSending =
sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED;
return !finishedSending || !finishedReceiving;
}
@Override
public TsiPeer extractPeer() {
return new TsiPeer(Collections.emptyList());
}
@Override
public Object extractPeerObject() {
return AltsAuthContext.getDefaultInstance();
}
@Override
public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
// We use an all-zero key, since this is the fake handshaker.
byte[] key = new byte[AltsChannelCrypter.getKeyLength()];
return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
}
@Override
public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc);
}
}

View File

@ -0,0 +1,209 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import com.google.common.testing.GcFinalization;
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link TsiHandshaker}. */
@RunWith(JUnit4.class)
public class FakeTsiTest {
private static final int OVERHEAD =
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
private final List<ReferenceCounted> references = new ArrayList<>();
private static Handshakers newHandshakers() {
TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient();
TsiHandshaker serverHandshaker = FakeTsiHandshaker.newFakeHandshakerServer();
return new Handshakers(clientHandshaker, serverHandshaker);
}
@Before
public void setUp() {
ResourceLeakDetector.setLevel(Level.PARANOID);
}
@After
public void tearDown() {
for (ReferenceCounted reference : references) {
reference.release();
}
references.clear();
// Increase our chances to detect ByteBuf leaks.
GcFinalization.awaitFullGc();
}
@Test
public void handshakeStateOrderTest() {
try {
Handshakers handshakers = newHandshakers();
TsiHandshaker clientHandshaker = handshakers.getClient();
TsiHandshaker serverHandshaker = handshakers.getServer();
byte[] transportBufferBytes = new byte[TsiTest.getDefaultTransportBufferSize()];
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
transportBuffer.limit(0); // Start off with an empty buffer
transportBuffer.clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertEquals(
FakeTsiHandshaker.State.CLIENT_INIT.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
serverHandshaker.processBytesFromPeer(transportBuffer);
assertFalse(transportBuffer.hasRemaining());
// client shouldn't offer any more bytes
transportBuffer.clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertFalse(transportBuffer.hasRemaining());
transportBuffer.clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertEquals(
FakeTsiHandshaker.State.SERVER_INIT.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
clientHandshaker.processBytesFromPeer(transportBuffer);
assertFalse(transportBuffer.hasRemaining());
// server shouldn't offer any more bytes
transportBuffer.clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertFalse(transportBuffer.hasRemaining());
transportBuffer.clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertEquals(
FakeTsiHandshaker.State.CLIENT_FINISHED.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
serverHandshaker.processBytesFromPeer(transportBuffer);
assertFalse(transportBuffer.hasRemaining());
// client shouldn't offer any more bytes
transportBuffer.clear();
clientHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertFalse(transportBuffer.hasRemaining());
transportBuffer.clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertEquals(
FakeTsiHandshaker.State.SERVER_FINISHED.toString().trim(),
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
clientHandshaker.processBytesFromPeer(transportBuffer);
assertFalse(transportBuffer.hasRemaining());
// server shouldn't offer any more bytes
transportBuffer.clear();
serverHandshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
assertFalse(transportBuffer.hasRemaining());
} catch (GeneralSecurityException e) {
throw new AssertionError(e);
}
}
@Test
public void handshake() throws GeneralSecurityException {
TsiTest.handshakeTest(newHandshakers());
}
@Test
public void handshakeSmallBuffer() throws GeneralSecurityException {
TsiTest.handshakeSmallBufferTest(newHandshakers());
}
@Test
public void pingPong() throws GeneralSecurityException {
TsiTest.pingPongTest(newHandshakers(), this::ref);
}
@Test
public void pingPongExactFrameSize() throws GeneralSecurityException {
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
}
@Test
public void pingPongSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
}
@Test
public void pingPongSmallFrame() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
}
@Test
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
}
@Test
public void corruptedCounter() throws GeneralSecurityException {
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
}
@Test
public void corruptedCiphertext() throws GeneralSecurityException {
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
}
@Test
public void corruptedTag() throws GeneralSecurityException {
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
}
@Test
public void reflectedCiphertext() throws GeneralSecurityException {
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
}
private ByteBuf ref(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
}

View File

@ -0,0 +1,117 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.protobuf.ByteString;
import io.grpc.Status;
import io.grpc.alts.Handshaker.HandshakerResp;
import io.grpc.alts.Handshaker.HandshakerResult;
import io.grpc.alts.Handshaker.HandshakerStatus;
import io.grpc.alts.Handshaker.Identity;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.SecureRandom;
import java.util.Random;
/** A class for mocking ALTS Handshaker Responses. */
class MockAltsHandshakerResp {
private static final String TEST_ERROR_DETAILS = "handshake error";
private static final String TEST_APPLICATION_PROTOCOL = "grpc";
private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
private static final String TEST_OUT_FRAME = "output frame";
private static final String TEST_LOCAL_ACCOUNT = "local@developer.gserviceaccount.com";
private static final String TEST_PEER_ACCOUNT = "peer@developer.gserviceaccount.com";
private static final byte[] TEST_KEY_DATA = initializeTestKeyData();
private static final int FRAME_HEADER_SIZE = 4;
static String getTestErrorDetails() {
return TEST_ERROR_DETAILS;
}
static String getTestPeerAccount() {
return TEST_PEER_ACCOUNT;
}
private static byte[] initializeTestKeyData() {
Random random = new SecureRandom();
byte[] randombytes = new byte[AltsChannelCrypter.getKeyLength()];
random.nextBytes(randombytes);
return randombytes;
}
static byte[] getTestKeyData() {
return TEST_KEY_DATA;
}
/** Returns a mock output frame. */
static ByteString getOutFrame() {
int frameSize = TEST_OUT_FRAME.length();
ByteBuffer buffer = ByteBuffer.allocate(FRAME_HEADER_SIZE + frameSize);
buffer.order(ByteOrder.LITTLE_ENDIAN);
buffer.putInt(frameSize);
buffer.put(TEST_OUT_FRAME.getBytes(UTF_8));
buffer.flip();
return ByteString.copyFrom(buffer);
}
/** Returns a mock error handshaker response. */
static HandshakerResp getErrorResponse() {
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
resp.setStatus(
HandshakerStatus.newBuilder()
.setCode(Status.Code.UNKNOWN.value())
.setDetails(TEST_ERROR_DETAILS)
.build());
return resp.build();
}
/** Returns a mock normal handshaker response. */
static HandshakerResp getOkResponse(int bytesConsumed) {
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
resp.setOutFrames(getOutFrame());
resp.setBytesConsumed(bytesConsumed);
resp.setStatus(HandshakerStatus.newBuilder().setCode(Status.Code.OK.value()).build());
return resp.build();
}
/** Returns a mock normal handshaker response. */
static HandshakerResp getEmptyOutFrameResponse(int bytesConsumed) {
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
resp.setBytesConsumed(bytesConsumed);
resp.setStatus(HandshakerStatus.newBuilder().setCode(Status.Code.OK.value()).build());
return resp.build();
}
/** Returns a mock final handshaker response with handshake result. */
static HandshakerResp getFinishedResponse(int bytesConsumed) {
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
HandshakerResult.Builder result =
HandshakerResult.newBuilder()
.setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
.setRecordProtocol(TEST_RECORD_PROTOCOL)
.setPeerIdentity(Identity.newBuilder().setServiceAccount(TEST_PEER_ACCOUNT).build())
.setLocalIdentity(Identity.newBuilder().setServiceAccount(TEST_LOCAL_ACCOUNT).build())
.setKeyData(ByteString.copyFrom(TEST_KEY_DATA));
resp.setOutFrames(getOutFrame());
resp.setBytesConsumed(bytesConsumed);
resp.setStatus(HandshakerStatus.newBuilder().setCode(Status.Code.OK.value()).build());
resp.setResult(result.build());
return resp.build();
}
}

View File

@ -0,0 +1,364 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.transportsecurity;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getDirectBuffer;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.fail;
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.crypto.AEADBadTagException;
/** Utility class that provides tests for implementations of @{link TsiHandshaker}. */
public final class TsiTest {
private static final String DECRYPTION_FAILURE_RE = "Tag mismatch!";
private TsiTest() {}
/** A @{code TsiHandshaker} pair for running tests. */
public static class Handshakers {
private final TsiHandshaker client;
private final TsiHandshaker server;
public Handshakers(TsiHandshaker client, TsiHandshaker server) {
this.client = client;
this.server = server;
}
public TsiHandshaker getClient() {
return client;
}
public TsiHandshaker getServer() {
return server;
}
}
private static final int DEFAULT_TRANSPORT_BUFFER_SIZE = 2048;
private static final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
private static final String EXAMPLE_MESSAGE1 = "hello world";
private static final String EXAMPLE_MESSAGE2 = "oysteroystersoysterseateateat";
private static final int EXAMPLE_MESSAGE1_LEN = EXAMPLE_MESSAGE1.getBytes(UTF_8).length;
private static final int EXAMPLE_MESSAGE2_LEN = EXAMPLE_MESSAGE2.getBytes(UTF_8).length;
static int getDefaultTransportBufferSize() {
return DEFAULT_TRANSPORT_BUFFER_SIZE;
}
/**
* Performs a handshake between the client handshaker and server handshaker using a transport of
* length transportBufferSize.
*/
static void performHandshake(int transportBufferSize, Handshakers handshakers)
throws GeneralSecurityException {
TsiHandshaker clientHandshaker = handshakers.getClient();
TsiHandshaker serverHandshaker = handshakers.getServer();
byte[] transportBufferBytes = new byte[transportBufferSize];
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
transportBuffer.limit(0); // Start off with an empty buffer
while (clientHandshaker.isInProgress() || serverHandshaker.isInProgress()) {
for (TsiHandshaker handshaker : new TsiHandshaker[] {clientHandshaker, serverHandshaker}) {
if (handshaker.isInProgress()) {
// Process any bytes on the wire.
if (transportBuffer.hasRemaining()) {
handshaker.processBytesFromPeer(transportBuffer);
}
// Put new bytes on the wire, if needed.
if (handshaker.isInProgress()) {
transportBuffer.clear();
handshaker.getBytesToSendToPeer(transportBuffer);
transportBuffer.flip();
}
}
}
}
clientHandshaker.extractPeer();
serverHandshaker.extractPeer();
}
public static void handshakeTest(Handshakers handshakers) throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
}
public static void handshakeSmallBufferTest(Handshakers handshakers)
throws GeneralSecurityException {
performHandshake(9, handshakers);
}
/** Sends a message between the sender and receiver. */
private static void sendMessage(
TsiFrameProtector sender,
TsiFrameProtector receiver,
int recvFragmentSize,
String message,
RegisterRef ref)
throws GeneralSecurityException {
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
while (protect.isReadable()) {
ByteBuf buf = protect;
if (recvFragmentSize > 0) {
int size = Math.min(protect.readableBytes(), recvFragmentSize);
buf = protect.readSlice(size);
}
receiver.unprotect(buf, unprotectOut, alloc);
}
ByteBuf plaintextRecvd = getDirectBuffer(message.getBytes(UTF_8).length, ref);
for (Object unprotect : unprotectOut) {
ByteBuf unprotectBuf = ref.register((ByteBuf) unprotect);
plaintextRecvd.writeBytes(unprotectBuf);
}
assertThat(plaintextRecvd).isEqualTo(Unpooled.wrappedBuffer(message.getBytes(UTF_8)));
}
/** Ping pong test. */
public static void pingPongTest(Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc);
TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc);
sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref);
sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE2, ref);
clientProtector.destroy();
serverProtector.destroy();
}
/** Ping pong test with exact frame size. */
public static void pingPongExactFrameSizeTest(Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
int frameSize =
EXAMPLE_MESSAGE1.getBytes(UTF_8).length
+ AltsTsiFrameProtector.getHeaderBytes()
+ FakeChannelCrypter.getTagBytes();
TsiFrameProtector clientProtector =
handshakers.getClient().createFrameProtector(frameSize, alloc);
TsiFrameProtector serverProtector =
handshakers.getServer().createFrameProtector(frameSize, alloc);
sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref);
sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE1, ref);
clientProtector.destroy();
serverProtector.destroy();
}
/** Ping pong test with small buffer size. */
public static void pingPongSmallBufferTest(Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc);
TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc);
sendMessage(clientProtector, serverProtector, 1, EXAMPLE_MESSAGE1, ref);
sendMessage(serverProtector, clientProtector, 1, EXAMPLE_MESSAGE2, ref);
clientProtector.destroy();
serverProtector.destroy();
}
/** Ping pong test with small frame size. */
public static void pingPongSmallFrameTest(
int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
// We send messages using small non-aligned buffers. We use 3 and 5, small primes.
TsiFrameProtector clientProtector =
handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc);
TsiFrameProtector serverProtector =
handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc);
sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
clientProtector.destroy();
serverProtector.destroy();
}
/** Ping pong test with small frame and small buffer. */
public static void pingPongSmallFrameSmallBufferTest(
int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
// We send messages using small non-aligned buffers. We use 3 and 5, small primes.
TsiFrameProtector clientProtector =
handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc);
TsiFrameProtector serverProtector =
handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc);
sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
clientProtector.destroy();
serverProtector.destroy();
}
/** Test corrupted counter. */
public static void corruptedCounterTest(Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
// Unprotect once to increase receiver counter.
receiver.unprotect(protect.slice(), unprotectOut, alloc);
assertThat(unprotectOut.size()).isEqualTo(1);
ref.register((ByteBuf) unprotectOut.get(0));
try {
receiver.unprotect(protect, unprotectOut, alloc);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
}
sender.destroy();
receiver.destroy();
}
/** Test corrupted ciphertext. */
public static void corruptedCiphertextTest(Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
int ciphertextIdx = protect.writerIndex() - FakeChannelCrypter.getTagBytes() - 2;
protect.setByte(ciphertextIdx, protect.getByte(ciphertextIdx) + 1);
try {
receiver.unprotect(protect, unprotectOut, alloc);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
}
sender.destroy();
receiver.destroy();
}
/** Test corrupted tag. */
public static void corruptedTagTest(Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
int tagIdx = protect.writerIndex() - 1;
protect.setByte(tagIdx, protect.getByte(tagIdx) + 1);
try {
receiver.unprotect(protect, unprotectOut, alloc);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
}
sender.destroy();
receiver.destroy();
}
/** Test reflected ciphertext. */
public static void reflectedCiphertextTest(Handshakers handshakers, RegisterRef ref)
throws GeneralSecurityException {
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
String message = "hello world";
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
List<ByteBuf> protectOut = new ArrayList<>();
List<Object> unprotectOut = new ArrayList<>();
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
assertThat(protectOut.size()).isEqualTo(1);
ByteBuf protect = ref.register(protectOut.get(0));
try {
sender.unprotect(protect.slice(), unprotectOut, alloc);
fail("Exception expected");
} catch (AEADBadTagException ex) {
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
}
sender.destroy();
receiver.destroy();
}
}

View File

@ -208,6 +208,7 @@ subprojects {
protobuf_nano: "com.google.protobuf.nano:protobuf-javanano:${protobufNanoVersion}",
protobuf_plugin: 'com.google.protobuf:protobuf-gradle-plugin:0.8.3',
protobuf_util: "com.google.protobuf:protobuf-java-util:${protobufVersion}",
lang: "org.apache.commons:commons-lang3:3.5",
netty: "io.netty:netty-codec-http2:[${nettyVersion}]",
netty_epoll: "io.netty:netty-transport-native-epoll:${nettyVersion}" + epoll_suffix,
@ -218,6 +219,7 @@ subprojects {
junit: 'junit:junit:4.12',
mockito: 'org.mockito:mockito-core:1.9.5',
truth: 'com.google.truth:truth:0.36',
guava_testlib: 'com.google.guava:guava-testlib:19.0',
// Benchmark dependencies
hdrhistogram: 'org.hdrhistogram:HdrHistogram:2.1.10',
@ -391,6 +393,8 @@ subprojects {
// Run with: ./gradlew japicmp --continue
def baselineGrpcVersion = '1.6.1'
def publicApiSubprojects = [
// TODO: uncomment after grpc-alts artifact is published.
// ':grpc-alts',
':grpc-auth',
':grpc-context',
':grpc-core',

View File

@ -71,6 +71,7 @@ java_library(
"@com_google_guava_guava//jar",
"@com_google_protobuf//:protobuf_java",
"@com_google_protobuf//:protobuf_java_util",
"@grpc_java//alts",
"@grpc_java//core",
"@grpc_java//netty",
"@grpc_java//protobuf",
@ -97,6 +98,29 @@ java_binary(
],
)
java_binary(
name = "hello-world-alts-client",
testonly = 1,
main_class = "io.grpc.examples.alts.HelloWorldAltsClient",
runtime_deps = [
":examples",
"@grpc_java//alts",
"@grpc_java//netty",
],
)
java_binary(
name = "hello-world-alts-server",
testonly = 1,
main_class = "io.grpc.examples.alts.HelloWorldAltsServer",
runtime_deps = [
":examples",
"@grpc_java//alts",
"@grpc_java//netty",
],
)
java_binary(
name = "route-guide-client",
testonly = 1,

View File

@ -27,6 +27,7 @@ def nettyTcNativeVersion = '2.0.7.Final'
dependencies {
compile "com.google.api.grpc:proto-google-common-protos:1.0.0"
compile "io.grpc:grpc-alts:${grpcVersion}"
compile "io.grpc:grpc-netty:${grpcVersion}"
compile "io.grpc:grpc-protobuf:${grpcVersion}"
compile "io.grpc:grpc-stub:${grpcVersion}"
@ -101,6 +102,20 @@ task helloWorldClient(type: CreateStartScripts) {
classpath = jar.outputs.files + project.configurations.runtime
}
task helloWorldAltsServer(type: CreateStartScripts) {
mainClassName = 'io.grpc.examples.alts.HelloWorldAltsServer'
applicationName = 'hello-world-alts-server'
outputDir = new File(project.buildDir, 'tmp')
classpath = jar.outputs.files + project.configurations.runtime
}
task helloWorldAltsClient(type: CreateStartScripts) {
mainClassName = 'io.grpc.examples.alts.HelloWorldAltsClient'
applicationName = 'hello-world-alts-client'
outputDir = new File(project.buildDir, 'tmp')
classpath = jar.outputs.files + project.configurations.runtime
}
task helloWorldTlsServer(type: CreateStartScripts) {
mainClassName = 'io.grpc.examples.helloworldtls.HelloWorldServerTls'
applicationName = 'hello-world-tls-server'
@ -127,6 +142,8 @@ applicationDistribution.into('bin') {
from(routeGuideClient)
from(helloWorldServer)
from(helloWorldClient)
from(helloWorldAltsServer)
from(helloWorldAltsClient)
from(helloWorldTlsServer)
from(helloWorldTlsClient)
from(compressingHelloWorldClient)

View File

@ -30,6 +30,11 @@
<artifactId>grpc-stub</artifactId>
<version>${grpc.version}</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-alts</artifactId>
<version>${grpc.version}</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-testing</artifactId>

View File

@ -0,0 +1,98 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.examples.alts;
import io.grpc.alts.AltsChannelBuilder;
import io.grpc.ManagedChannel;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* An example gRPC client that uses ALTS. Shows how to do a Unary RPC. This example can only be run
* on Google Cloud Platform.
*/
public final class HelloWorldAltsClient {
private static final Logger logger = Logger.getLogger(HelloWorldAltsClient.class.getName());
private String serverAddress = "localhost:10001";
public static void main(String[] args) throws InterruptedException {
new HelloWorldAltsClient().run(args);
}
private void parseArgs(String[] args) {
boolean usage = false;
for (String arg : args) {
if (!arg.startsWith("--")) {
System.err.println("All arguments must start with '--': " + arg);
usage = true;
break;
}
String[] parts = arg.substring(2).split("=", 2);
String key = parts[0];
if ("help".equals(key)) {
usage = true;
break;
}
if (parts.length != 2) {
System.err.println("All arguments must be of the form --arg=value");
usage = true;
break;
}
String value = parts[1];
if ("server".equals(key)) {
serverAddress = value;
} else {
System.err.println("Unknown argument: " + key);
usage = true;
break;
}
}
if (usage) {
HelloWorldAltsClient c = new HelloWorldAltsClient();
System.out.println(
"Usage: [ARGS...]"
+ "\n"
+ "\n --server=SERVER_ADDRESS Server address to connect to. Default "
+ c.serverAddress);
System.exit(1);
}
}
private void run(String[] args) throws InterruptedException {
parseArgs(args);
ExecutorService executor = Executors.newFixedThreadPool(1);
ManagedChannel channel = AltsChannelBuilder.forTarget(serverAddress).executor(executor).build();
try {
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel);
HelloReply resp = stub.sayHello(HelloRequest.newBuilder().setName("Waldo").build());
logger.log(Level.INFO, "Got {0}", resp);
} finally {
channel.shutdown();
channel.awaitTermination(1, TimeUnit.SECONDS);
// Wait until the channel has terminated, since tasks can be queued after the channel is
// shutdown.
executor.shutdown();
}
}
}

View File

@ -0,0 +1,99 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.examples.alts;
import io.grpc.alts.AltsServerBuilder;
import io.grpc.Server;
import io.grpc.examples.helloworld.GreeterGrpc.GreeterImplBase;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.concurrent.Executors;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* An example gRPC server that uses ALTS. Shows how to do a Unary RPC. This example can only be run
* on Google Cloud Platform.
*/
public final class HelloWorldAltsServer extends GreeterImplBase {
private static final Logger logger = Logger.getLogger(HelloWorldAltsServer.class.getName());
private Server server;
private int port = 10001;
public static void main(String[] args) throws IOException, InterruptedException {
new HelloWorldAltsServer().start(args);
}
private void parseArgs(String[] args) {
boolean usage = false;
for (String arg : args) {
if (!arg.startsWith("--")) {
System.err.println("All arguments must start with '--': " + arg);
usage = true;
break;
}
String[] parts = arg.substring(2).split("=", 2);
String key = parts[0];
if ("help".equals(key)) {
usage = true;
break;
}
if (parts.length != 2) {
System.err.println("All arguments must be of the form --arg=value");
usage = true;
break;
}
String value = parts[1];
if ("port".equals(key)) {
port = Integer.parseInt(value);
} else {
System.err.println("Unknown argument: " + key);
usage = true;
break;
}
}
if (usage) {
HelloWorldAltsServer s = new HelloWorldAltsServer();
System.out.println(
"Usage: [ARGS...]"
+ "\n"
+ "\n --port=PORT Server port to bind to. Default "
+ s.port);
System.exit(1);
}
}
private void start(String[] args) throws IOException, InterruptedException {
parseArgs(args);
server =
AltsServerBuilder.forPort(port)
.addService(this)
.executor(Executors.newFixedThreadPool(1))
.build();
server.start();
logger.log(Level.INFO, "Started on {0}", port);
server.awaitTermination();
}
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloReply> observer) {
observer.onNext(HelloReply.newBuilder().setMessage("Hello, " + request.getName()).build());
observer.onCompleted();
}
}

View File

@ -26,7 +26,8 @@ def grpc_java_repositories(
omit_io_netty_tcnative_boringssl_static=False,
omit_io_opencensus_api=False,
omit_io_opencensus_grpc_metrics=False,
omit_junit_junit=False):
omit_junit_junit=False,
omit_org_apache_commons_lang3=False):
"""Imports dependencies for grpc-java."""
if not omit_com_google_api_grpc_google_common_protos:
com_google_api_grpc_google_common_protos()
@ -80,6 +81,8 @@ def grpc_java_repositories(
io_opencensus_grpc_metrics()
if not omit_junit_junit:
junit_junit()
if not omit_org_apache_commons_lang3:
org_apache_commons_lang3()
native.bind(
name = "guava",
@ -268,3 +271,10 @@ def junit_junit():
artifact = "junit:junit:4.12",
sha1 = "2973d150c0dc1fefe998f834810d68f278ea58ec",
)
def org_apache_commons_lang3():
native.maven_jar(
name = "org_apache_commons_commons_lang3",
artifact = "org.apache.commons:commons-lang3:3.5",
sha1 = "6c6c702c89bfff3cd9e80b04d668c5e190d588c6"
)

View File

@ -16,6 +16,7 @@ include ":grpc-interop-testing"
include ":grpc-gae-interop-testing-jdk7"
include ":grpc-gae-interop-testing-jdk8"
include ":grpc-all"
include ":grpc-alts"
include ":grpc-benchmarks"
include ":grpc-services"
@ -36,6 +37,7 @@ project(':grpc-interop-testing').projectDir = "$rootDir/interop-testing" as File
project(':grpc-gae-interop-testing-jdk7').projectDir = "$rootDir/gae-interop-testing/gae-jdk7" as File
project(':grpc-gae-interop-testing-jdk8').projectDir = "$rootDir/gae-interop-testing/gae-jdk8" as File
project(':grpc-all').projectDir = "$rootDir/all" as File
project(':grpc-alts').projectDir = "$rootDir/alts" as File
project(':grpc-benchmarks').projectDir = "$rootDir/benchmarks" as File
project(':grpc-services').projectDir = "$rootDir/services" as File