mirror of https://github.com/grpc/grpc-java.git
alts: add gRPC ALTS
This commit is contained in:
parent
a45d07bcb5
commit
e7f2f1dedd
|
|
@ -0,0 +1,82 @@
|
|||
load("//:java_grpc_library.bzl", "java_grpc_library")
|
||||
|
||||
java_library(
|
||||
name = "alts_tsi",
|
||||
srcs = glob([
|
||||
"src/main/java/io/grpc/alts/transportsecurity/*.java",
|
||||
]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//core",
|
||||
"//core:internal",
|
||||
"//stub",
|
||||
"@com_google_code_findbugs_jsr305//jar",
|
||||
"@com_google_guava_guava//jar",
|
||||
"@com_google_protobuf//:protobuf_java",
|
||||
"@com_google_protobuf//:protobuf_java_util",
|
||||
"@io_netty_netty_buffer//jar",
|
||||
"@io_netty_netty_common//jar",
|
||||
":handshaker_java_proto",
|
||||
":handshaker_java_grpc",
|
||||
],
|
||||
)
|
||||
|
||||
java_library(
|
||||
name = "alts",
|
||||
srcs = glob([
|
||||
"src/main/java/io/grpc/alts/*.java",
|
||||
]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//core",
|
||||
"//core:internal",
|
||||
"//netty",
|
||||
"//stub",
|
||||
"@com_google_code_findbugs_jsr305//jar",
|
||||
"@com_google_guava_guava//jar",
|
||||
"@com_google_protobuf//:protobuf_java",
|
||||
"@com_google_protobuf//:protobuf_java_util",
|
||||
"@io_netty_netty_buffer//jar",
|
||||
"@io_netty_netty_codec//jar",
|
||||
"@io_netty_netty_common//jar",
|
||||
"@io_netty_netty_transport//jar",
|
||||
"@org_apache_commons_commons_lang3//jar",
|
||||
":alts_tsi",
|
||||
":handshaker_java_proto",
|
||||
":handshaker_java_grpc",
|
||||
],
|
||||
)
|
||||
|
||||
# bazel only accepts proto import with absolute path.
|
||||
genrule(
|
||||
name = "protobuf_imports",
|
||||
srcs = glob(["src/main/proto/*.proto"]),
|
||||
outs = [
|
||||
"protobuf_out/altscontext.proto",
|
||||
"protobuf_out/handshaker.proto",
|
||||
"protobuf_out/transport_security_common.proto",
|
||||
],
|
||||
cmd = "for fname in $(SRCS); do " +
|
||||
"sed 's,import \",import \"alts/protobuf_out/,g' $$fname > " +
|
||||
"$(@D)/protobuf_out/$$(basename $$fname); done",
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "handshaker_proto",
|
||||
srcs = [
|
||||
"protobuf_out/altscontext.proto",
|
||||
"protobuf_out/handshaker.proto",
|
||||
"protobuf_out/transport_security_common.proto",
|
||||
],
|
||||
)
|
||||
|
||||
java_proto_library(
|
||||
name = "handshaker_java_proto",
|
||||
deps = [":handshaker_proto"],
|
||||
)
|
||||
|
||||
java_grpc_library(
|
||||
name = "handshaker_java_grpc",
|
||||
srcs = [":handshaker_proto"],
|
||||
deps = [":handshaker_java_proto"],
|
||||
)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
description = "gRPC: ALTS"
|
||||
|
||||
sourceCompatibility = 1.8
|
||||
targetCompatibility = 1.8
|
||||
|
||||
buildscript {
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
dependencies {
|
||||
classpath libraries.protobuf_plugin
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compile project(':grpc-core'),
|
||||
project(':grpc-netty'),
|
||||
project(':grpc-protobuf'),
|
||||
project(':grpc-stub'),
|
||||
libraries.lang,
|
||||
libraries.protobuf
|
||||
testCompile libraries.guava_testlib,
|
||||
libraries.junit,
|
||||
libraries.mockito,
|
||||
libraries.truth
|
||||
}
|
||||
|
||||
configureProtoCompilation()
|
||||
|
||||
[compileJava, compileTestJava].each() {
|
||||
// ALTS retuns a lot of futures that we mostly don't care about.
|
||||
// protobuf calls valueof. Will be fixed in next release (google/protobuf#4046)
|
||||
it.options.compilerArgs += ["-Xlint:-deprecation", "-Xep:FutureReturnValueIgnored:OFF"]
|
||||
}
|
||||
|
||||
idea {
|
||||
module {
|
||||
sourceDirs += file("${projectDir}/src/generated/main/grpc");
|
||||
sourceDirs += file("${projectDir}/src/generated/main/java");
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
package io.grpc.alts;
|
||||
|
||||
import static io.grpc.MethodDescriptor.generateFullMethodName;
|
||||
import static io.grpc.stub.ClientCalls.asyncBidiStreamingCall;
|
||||
import static io.grpc.stub.ClientCalls.asyncClientStreamingCall;
|
||||
import static io.grpc.stub.ClientCalls.asyncServerStreamingCall;
|
||||
import static io.grpc.stub.ClientCalls.asyncUnaryCall;
|
||||
import static io.grpc.stub.ClientCalls.blockingServerStreamingCall;
|
||||
import static io.grpc.stub.ClientCalls.blockingUnaryCall;
|
||||
import static io.grpc.stub.ClientCalls.futureUnaryCall;
|
||||
import static io.grpc.stub.ServerCalls.asyncBidiStreamingCall;
|
||||
import static io.grpc.stub.ServerCalls.asyncClientStreamingCall;
|
||||
import static io.grpc.stub.ServerCalls.asyncServerStreamingCall;
|
||||
import static io.grpc.stub.ServerCalls.asyncUnaryCall;
|
||||
import static io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall;
|
||||
import static io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall;
|
||||
|
||||
/**
|
||||
*/
|
||||
@javax.annotation.Generated(
|
||||
value = "by gRPC proto compiler",
|
||||
comments = "Source: handshaker.proto")
|
||||
public final class HandshakerServiceGrpc {
|
||||
|
||||
private HandshakerServiceGrpc() {}
|
||||
|
||||
public static final String SERVICE_NAME = "grpc.gcp.HandshakerService";
|
||||
|
||||
// Static method descriptors that strictly reflect the proto.
|
||||
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
|
||||
@java.lang.Deprecated // Use {@link #getDoHandshakeMethod()} instead.
|
||||
public static final io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
|
||||
io.grpc.alts.Handshaker.HandshakerResp> METHOD_DO_HANDSHAKE = getDoHandshakeMethodHelper();
|
||||
|
||||
private static volatile io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
|
||||
io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethod;
|
||||
|
||||
@io.grpc.ExperimentalApi("https://github.com/grpc/grpc-java/issues/1901")
|
||||
public static io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
|
||||
io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethod() {
|
||||
return getDoHandshakeMethodHelper();
|
||||
}
|
||||
|
||||
private static io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq,
|
||||
io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethodHelper() {
|
||||
io.grpc.MethodDescriptor<io.grpc.alts.Handshaker.HandshakerReq, io.grpc.alts.Handshaker.HandshakerResp> getDoHandshakeMethod;
|
||||
if ((getDoHandshakeMethod = HandshakerServiceGrpc.getDoHandshakeMethod) == null) {
|
||||
synchronized (HandshakerServiceGrpc.class) {
|
||||
if ((getDoHandshakeMethod = HandshakerServiceGrpc.getDoHandshakeMethod) == null) {
|
||||
HandshakerServiceGrpc.getDoHandshakeMethod = getDoHandshakeMethod =
|
||||
io.grpc.MethodDescriptor.<io.grpc.alts.Handshaker.HandshakerReq, io.grpc.alts.Handshaker.HandshakerResp>newBuilder()
|
||||
.setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING)
|
||||
.setFullMethodName(generateFullMethodName(
|
||||
"grpc.gcp.HandshakerService", "DoHandshake"))
|
||||
.setSampledToLocalTracing(true)
|
||||
.setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
|
||||
io.grpc.alts.Handshaker.HandshakerReq.getDefaultInstance()))
|
||||
.setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
|
||||
io.grpc.alts.Handshaker.HandshakerResp.getDefaultInstance()))
|
||||
.setSchemaDescriptor(new HandshakerServiceMethodDescriptorSupplier("DoHandshake"))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
}
|
||||
return getDoHandshakeMethod;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new async stub that supports all call types for the service
|
||||
*/
|
||||
public static HandshakerServiceStub newStub(io.grpc.Channel channel) {
|
||||
return new HandshakerServiceStub(channel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new blocking-style stub that supports unary and streaming output calls on the service
|
||||
*/
|
||||
public static HandshakerServiceBlockingStub newBlockingStub(
|
||||
io.grpc.Channel channel) {
|
||||
return new HandshakerServiceBlockingStub(channel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new ListenableFuture-style stub that supports unary calls on the service
|
||||
*/
|
||||
public static HandshakerServiceFutureStub newFutureStub(
|
||||
io.grpc.Channel channel) {
|
||||
return new HandshakerServiceFutureStub(channel);
|
||||
}
|
||||
|
||||
/**
|
||||
*/
|
||||
public static abstract class HandshakerServiceImplBase implements io.grpc.BindableService {
|
||||
|
||||
/**
|
||||
* <pre>
|
||||
* Accepts a stream of handshaker request, returning a stream of handshaker
|
||||
* response.
|
||||
* </pre>
|
||||
*/
|
||||
public io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerReq> doHandshake(
|
||||
io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerResp> responseObserver) {
|
||||
return asyncUnimplementedStreamingCall(getDoHandshakeMethodHelper(), responseObserver);
|
||||
}
|
||||
|
||||
@java.lang.Override public final io.grpc.ServerServiceDefinition bindService() {
|
||||
return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor())
|
||||
.addMethod(
|
||||
getDoHandshakeMethodHelper(),
|
||||
asyncBidiStreamingCall(
|
||||
new MethodHandlers<
|
||||
io.grpc.alts.Handshaker.HandshakerReq,
|
||||
io.grpc.alts.Handshaker.HandshakerResp>(
|
||||
this, METHODID_DO_HANDSHAKE)))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*/
|
||||
public static final class HandshakerServiceStub extends io.grpc.stub.AbstractStub<HandshakerServiceStub> {
|
||||
private HandshakerServiceStub(io.grpc.Channel channel) {
|
||||
super(channel);
|
||||
}
|
||||
|
||||
private HandshakerServiceStub(io.grpc.Channel channel,
|
||||
io.grpc.CallOptions callOptions) {
|
||||
super(channel, callOptions);
|
||||
}
|
||||
|
||||
@java.lang.Override
|
||||
protected HandshakerServiceStub build(io.grpc.Channel channel,
|
||||
io.grpc.CallOptions callOptions) {
|
||||
return new HandshakerServiceStub(channel, callOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* <pre>
|
||||
* Accepts a stream of handshaker request, returning a stream of handshaker
|
||||
* response.
|
||||
* </pre>
|
||||
*/
|
||||
public io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerReq> doHandshake(
|
||||
io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerResp> responseObserver) {
|
||||
return asyncBidiStreamingCall(
|
||||
getChannel().newCall(getDoHandshakeMethodHelper(), getCallOptions()), responseObserver);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*/
|
||||
public static final class HandshakerServiceBlockingStub extends io.grpc.stub.AbstractStub<HandshakerServiceBlockingStub> {
|
||||
private HandshakerServiceBlockingStub(io.grpc.Channel channel) {
|
||||
super(channel);
|
||||
}
|
||||
|
||||
private HandshakerServiceBlockingStub(io.grpc.Channel channel,
|
||||
io.grpc.CallOptions callOptions) {
|
||||
super(channel, callOptions);
|
||||
}
|
||||
|
||||
@java.lang.Override
|
||||
protected HandshakerServiceBlockingStub build(io.grpc.Channel channel,
|
||||
io.grpc.CallOptions callOptions) {
|
||||
return new HandshakerServiceBlockingStub(channel, callOptions);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*/
|
||||
public static final class HandshakerServiceFutureStub extends io.grpc.stub.AbstractStub<HandshakerServiceFutureStub> {
|
||||
private HandshakerServiceFutureStub(io.grpc.Channel channel) {
|
||||
super(channel);
|
||||
}
|
||||
|
||||
private HandshakerServiceFutureStub(io.grpc.Channel channel,
|
||||
io.grpc.CallOptions callOptions) {
|
||||
super(channel, callOptions);
|
||||
}
|
||||
|
||||
@java.lang.Override
|
||||
protected HandshakerServiceFutureStub build(io.grpc.Channel channel,
|
||||
io.grpc.CallOptions callOptions) {
|
||||
return new HandshakerServiceFutureStub(channel, callOptions);
|
||||
}
|
||||
}
|
||||
|
||||
private static final int METHODID_DO_HANDSHAKE = 0;
|
||||
|
||||
private static final class MethodHandlers<Req, Resp> implements
|
||||
io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,
|
||||
io.grpc.stub.ServerCalls.ServerStreamingMethod<Req, Resp>,
|
||||
io.grpc.stub.ServerCalls.ClientStreamingMethod<Req, Resp>,
|
||||
io.grpc.stub.ServerCalls.BidiStreamingMethod<Req, Resp> {
|
||||
private final HandshakerServiceImplBase serviceImpl;
|
||||
private final int methodId;
|
||||
|
||||
MethodHandlers(HandshakerServiceImplBase serviceImpl, int methodId) {
|
||||
this.serviceImpl = serviceImpl;
|
||||
this.methodId = methodId;
|
||||
}
|
||||
|
||||
@java.lang.Override
|
||||
@java.lang.SuppressWarnings("unchecked")
|
||||
public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserver) {
|
||||
switch (methodId) {
|
||||
default:
|
||||
throw new AssertionError();
|
||||
}
|
||||
}
|
||||
|
||||
@java.lang.Override
|
||||
@java.lang.SuppressWarnings("unchecked")
|
||||
public io.grpc.stub.StreamObserver<Req> invoke(
|
||||
io.grpc.stub.StreamObserver<Resp> responseObserver) {
|
||||
switch (methodId) {
|
||||
case METHODID_DO_HANDSHAKE:
|
||||
return (io.grpc.stub.StreamObserver<Req>) serviceImpl.doHandshake(
|
||||
(io.grpc.stub.StreamObserver<io.grpc.alts.Handshaker.HandshakerResp>) responseObserver);
|
||||
default:
|
||||
throw new AssertionError();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static abstract class HandshakerServiceBaseDescriptorSupplier
|
||||
implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier {
|
||||
HandshakerServiceBaseDescriptorSupplier() {}
|
||||
|
||||
@java.lang.Override
|
||||
public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() {
|
||||
return io.grpc.alts.Handshaker.getDescriptor();
|
||||
}
|
||||
|
||||
@java.lang.Override
|
||||
public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() {
|
||||
return getFileDescriptor().findServiceByName("HandshakerService");
|
||||
}
|
||||
}
|
||||
|
||||
private static final class HandshakerServiceFileDescriptorSupplier
|
||||
extends HandshakerServiceBaseDescriptorSupplier {
|
||||
HandshakerServiceFileDescriptorSupplier() {}
|
||||
}
|
||||
|
||||
private static final class HandshakerServiceMethodDescriptorSupplier
|
||||
extends HandshakerServiceBaseDescriptorSupplier
|
||||
implements io.grpc.protobuf.ProtoMethodDescriptorSupplier {
|
||||
private final String methodName;
|
||||
|
||||
HandshakerServiceMethodDescriptorSupplier(String methodName) {
|
||||
this.methodName = methodName;
|
||||
}
|
||||
|
||||
@java.lang.Override
|
||||
public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() {
|
||||
return getServiceDescriptor().findMethodByName(methodName);
|
||||
}
|
||||
}
|
||||
|
||||
private static volatile io.grpc.ServiceDescriptor serviceDescriptor;
|
||||
|
||||
public static io.grpc.ServiceDescriptor getServiceDescriptor() {
|
||||
io.grpc.ServiceDescriptor result = serviceDescriptor;
|
||||
if (result == null) {
|
||||
synchronized (HandshakerServiceGrpc.class) {
|
||||
result = serviceDescriptor;
|
||||
if (result == null) {
|
||||
serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME)
|
||||
.setSchemaDescriptor(new HandshakerServiceFileDescriptorSupplier())
|
||||
.addMethod(getDoHandshakeMethodHelper())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,252 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.CallOptions;
|
||||
import io.grpc.ClientCall;
|
||||
import io.grpc.ConnectivityState;
|
||||
import io.grpc.ForwardingChannelBuilder;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.ManagedChannelBuilder;
|
||||
import io.grpc.MethodDescriptor;
|
||||
import io.grpc.alts.transportsecurity.AltsClientOptions;
|
||||
import io.grpc.alts.transportsecurity.AltsTsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
|
||||
import io.grpc.internal.GrpcUtil;
|
||||
import io.grpc.internal.ProxyParameters;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilter;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.SocketAddress;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* ALTS version of {@code ManagedChannelBuilder}. This class sets up a secure and authenticated
|
||||
* commmunication between two cloud VMs using ALTS.
|
||||
*/
|
||||
public final class AltsChannelBuilder extends ForwardingChannelBuilder<AltsChannelBuilder> {
|
||||
|
||||
private final NettyChannelBuilder delegate;
|
||||
private final AltsClientOptions.Builder handshakerOptionsBuilder =
|
||||
new AltsClientOptions.Builder();
|
||||
private TcpfFactory tcpfFactoryForTest;
|
||||
private boolean enableUntrustedAlts;
|
||||
|
||||
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
|
||||
public static final AltsChannelBuilder forTarget(String target) {
|
||||
return new AltsChannelBuilder(target);
|
||||
}
|
||||
|
||||
/** "Overrides" the static method in {@link ManagedChannelBuilder}. */
|
||||
public static AltsChannelBuilder forAddress(String name, int port) {
|
||||
return forTarget(GrpcUtil.authorityFromHostAndPort(name, port));
|
||||
}
|
||||
|
||||
private AltsChannelBuilder(String target) {
|
||||
delegate =
|
||||
NettyChannelBuilder.forTarget(target)
|
||||
.keepAliveTime(20, TimeUnit.SECONDS)
|
||||
.keepAliveTimeout(10, TimeUnit.SECONDS)
|
||||
.keepAliveWithoutCalls(true);
|
||||
handshakerOptionsBuilder.setRpcProtocolVersions(
|
||||
RpcProtocolVersionsUtil.getRpcProtocolVersions());
|
||||
}
|
||||
|
||||
/** The server service account name for secure name checking. */
|
||||
public AltsChannelBuilder withSecureNamingTarget(String targetName) {
|
||||
handshakerOptionsBuilder.setTargetName(targetName);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds an expected target service accounts. One of the added service accounts should match peer
|
||||
* service account in the handshaker result. Otherwise, the handshake fails.
|
||||
*/
|
||||
public AltsChannelBuilder addTargetServiceAccount(String targetServiceAccount) {
|
||||
handshakerOptionsBuilder.addTargetServiceAccount(targetServiceAccount);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enables untrusted ALTS for testing. If this function is called, we will not check whether ALTS
|
||||
* is running on Google Cloud Platform.
|
||||
*/
|
||||
public AltsChannelBuilder enableUntrustedAltsForTesting() {
|
||||
enableUntrustedAlts = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
/** Sets a new handshaker service address for testing. */
|
||||
public AltsChannelBuilder setHandshakerAddressForTesting(String handshakerAddress) {
|
||||
HandshakerServiceChannel.setHandshakerAddressForTesting(handshakerAddress);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NettyChannelBuilder delegate() {
|
||||
return delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ManagedChannel build() {
|
||||
CheckGcpEnvironment.check(enableUntrustedAlts);
|
||||
TcpfFactory tcpfFactory = new TcpfFactory();
|
||||
InternalNettyChannelBuilder.setDynamicTransportParamsFactory(delegate(), tcpfFactory);
|
||||
|
||||
tcpfFactoryForTest = tcpfFactory;
|
||||
|
||||
return new AltsChannel(delegate().build());
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
@Nullable
|
||||
TransportCreationParamsFilterFactory getTcpfFactoryForTest() {
|
||||
return tcpfFactoryForTest;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
@Nullable
|
||||
AltsClientOptions getAltsClientOptionsForTest() {
|
||||
if (tcpfFactoryForTest == null) {
|
||||
return null;
|
||||
}
|
||||
return tcpfFactoryForTest.handshakerOptions;
|
||||
}
|
||||
|
||||
private final class TcpfFactory implements TransportCreationParamsFilterFactory {
|
||||
final AltsClientOptions handshakerOptions = handshakerOptionsBuilder.build();
|
||||
|
||||
private final TsiHandshakerFactory altsHandshakerFactory =
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker() {
|
||||
// Used the shared grpc channel to connecting to the ALTS handshaker service.
|
||||
ManagedChannel channel = HandshakerServiceChannel.get();
|
||||
return AltsTsiHandshaker.newClient(
|
||||
HandshakerServiceGrpc.newStub(channel), handshakerOptions);
|
||||
}
|
||||
};
|
||||
|
||||
@Override
|
||||
public TransportCreationParamsFilter create(
|
||||
SocketAddress serverAddress,
|
||||
final String authority,
|
||||
final String userAgent,
|
||||
final ProxyParameters proxy) {
|
||||
checkArgument(
|
||||
serverAddress instanceof InetSocketAddress,
|
||||
"%s must be a InetSocketAddress",
|
||||
serverAddress);
|
||||
final AltsProtocolNegotiator negotiator =
|
||||
AltsProtocolNegotiator.create(altsHandshakerFactory);
|
||||
return new TransportCreationParamsFilter() {
|
||||
@Override
|
||||
public SocketAddress getTargetServerAddress() {
|
||||
return serverAddress;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getAuthority() {
|
||||
return authority;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getUserAgent() {
|
||||
return userAgent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AltsProtocolNegotiator getProtocolNegotiator() {
|
||||
return negotiator;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
static final class AltsChannel extends ManagedChannel {
|
||||
private final ManagedChannel delegate;
|
||||
|
||||
AltsChannel(ManagedChannel delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ConnectivityState getState(boolean requestConnection) {
|
||||
return delegate.getState(requestConnection);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) {
|
||||
delegate.notifyWhenStateChanged(source, callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AltsChannel shutdown() {
|
||||
delegate.shutdown();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isShutdown() {
|
||||
return delegate.isShutdown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTerminated() {
|
||||
return delegate.isTerminated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AltsChannel shutdownNow() {
|
||||
delegate.shutdownNow();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
|
||||
return delegate.awaitTermination(timeout, unit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
|
||||
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
|
||||
return delegate.newCall(methodDescriptor, callOptions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String authority() {
|
||||
return delegate.authority();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void resetConnectBackoff() {
|
||||
delegate.resetConnectBackoff();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void prepareToLoseNetwork() {
|
||||
delegate.prepareToLoseNetwork();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,142 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.grpc.Attributes;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent;
|
||||
import io.grpc.alts.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
|
||||
import io.grpc.alts.transportsecurity.AltsAuthContext;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
|
||||
import io.grpc.alts.transportsecurity.TsiPeer;
|
||||
import io.grpc.netty.GrpcHttp2ConnectionHandler;
|
||||
import io.grpc.netty.ProtocolNegotiator;
|
||||
import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.util.AsciiString;
|
||||
|
||||
/**
|
||||
* A client-side GRPC {@link ProtocolNegotiator} for ALTS. This class creates a Netty handler that
|
||||
* provides ALTS security on the wire, similar to Netty's {@code SslHandler}.
|
||||
*/
|
||||
public abstract class AltsProtocolNegotiator implements ProtocolNegotiator {
|
||||
|
||||
private static final Attributes.Key<TsiPeer> TSI_PEER_KEY = Attributes.Key.of("TSI_PEER");
|
||||
private static final Attributes.Key<AltsAuthContext> ALTS_CONTEXT_KEY =
|
||||
Attributes.Key.of("ALTS_CONTEXT_KEY");
|
||||
private static final AsciiString scheme = AsciiString.of("https");
|
||||
|
||||
public static Attributes.Key<TsiPeer> getTsiPeerAttributeKey() {
|
||||
return TSI_PEER_KEY;
|
||||
}
|
||||
|
||||
public static Attributes.Key<AltsAuthContext> getAltsAuthContextAttributeKey() {
|
||||
return ALTS_CONTEXT_KEY;
|
||||
}
|
||||
|
||||
/** Creates a negotiator used for ALTS. */
|
||||
public static AltsProtocolNegotiator create(TsiHandshakerFactory handshakerFactory) {
|
||||
return new AltsProtocolNegotiator() {
|
||||
@Override
|
||||
public Handler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
|
||||
return new BufferUntilAltsNegotiatedHandler(
|
||||
grpcHandler,
|
||||
new InternalTsiHandshakeHandler(
|
||||
new InternalNettyTsiHandshaker(handshakerFactory.newHandshaker())),
|
||||
new InternalTsiFrameHandler());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/** Buffers all writes until the ALTS handshake is complete. */
|
||||
@VisibleForTesting
|
||||
static class BufferUntilAltsNegotiatedHandler extends AbstractBufferingHandler
|
||||
implements ProtocolNegotiator.Handler {
|
||||
|
||||
private final GrpcHttp2ConnectionHandler grpcHandler;
|
||||
|
||||
BufferUntilAltsNegotiatedHandler(
|
||||
GrpcHttp2ConnectionHandler grpcHandler, ChannelHandler... negotiationhandlers) {
|
||||
super(negotiationhandlers);
|
||||
// Save the gRPC handler. The ALTS handler doesn't support buffering before the handshake
|
||||
// completes, so we wait until the handshake was successful before adding the grpc handler.
|
||||
this.grpcHandler = grpcHandler;
|
||||
}
|
||||
|
||||
// TODO: Remove this once https://github.com/grpc/grpc-java/pull/3715 is in.
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
fail(ctx, cause);
|
||||
ctx.fireExceptionCaught(cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AsciiString scheme() {
|
||||
return scheme;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
|
||||
if (evt instanceof TsiHandshakeCompletionEvent) {
|
||||
TsiHandshakeCompletionEvent altsEvt = (TsiHandshakeCompletionEvent) evt;
|
||||
if (altsEvt.isSuccess()) {
|
||||
// Add the gRPC handler just before this handler. We only allow the grpcHandler to be
|
||||
// null to support testing. In production, a grpc handler will always be provided.
|
||||
if (grpcHandler != null) {
|
||||
ctx.pipeline().addBefore(ctx.name(), null, grpcHandler);
|
||||
AltsAuthContext altsContext = (AltsAuthContext) altsEvt.context();
|
||||
Preconditions.checkNotNull(altsContext);
|
||||
// Checks peer Rpc Protocol Versions in the ALTS auth context. Fails the connection if
|
||||
// Rpc Protocol Versions mismatch.
|
||||
RpcVersionsCheckResult checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersionsUtil.getRpcProtocolVersions(),
|
||||
altsContext.getPeerRpcVersions());
|
||||
if (!checkResult.getResult()) {
|
||||
String errorMessage =
|
||||
"Local Rpc Protocol Versions "
|
||||
+ RpcProtocolVersionsUtil.getRpcProtocolVersions().toString()
|
||||
+ "are not compatible with peer Rpc Protocol Versions "
|
||||
+ altsContext.getPeerRpcVersions().toString();
|
||||
fail(ctx, Status.UNAVAILABLE.withDescription(errorMessage).asRuntimeException());
|
||||
}
|
||||
grpcHandler.handleProtocolNegotiationCompleted(
|
||||
Attributes.newBuilder()
|
||||
.set(TSI_PEER_KEY, altsEvt.peer())
|
||||
.set(ALTS_CONTEXT_KEY, altsContext)
|
||||
.set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
|
||||
.build());
|
||||
}
|
||||
|
||||
// Now write any buffered data and remove this handler.
|
||||
writeBufferedAndRemove(ctx);
|
||||
} else {
|
||||
fail(ctx, unavailableException("ALTS handshake failed", altsEvt.cause()));
|
||||
}
|
||||
}
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
|
||||
private static RuntimeException unavailableException(String msg, Throwable cause) {
|
||||
return Status.UNAVAILABLE.withCause(cause).withDescription(msg).asRuntimeException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import io.grpc.BindableService;
|
||||
import io.grpc.CompressorRegistry;
|
||||
import io.grpc.DecompressorRegistry;
|
||||
import io.grpc.HandlerRegistry;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.ServerBuilder;
|
||||
import io.grpc.ServerInterceptor;
|
||||
import io.grpc.ServerServiceDefinition;
|
||||
import io.grpc.ServerStreamTracer.Factory;
|
||||
import io.grpc.ServerTransportFilter;
|
||||
import io.grpc.alts.transportsecurity.AltsHandshakerOptions;
|
||||
import io.grpc.alts.transportsecurity.AltsTsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
|
||||
import io.grpc.netty.NettyServerBuilder;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* gRPC secure server builder used for ALTS. This class adds on the necessary ALTS support to create
|
||||
* a production server on Google Cloud Platform.
|
||||
*/
|
||||
public final class AltsServerBuilder extends ServerBuilder<AltsServerBuilder> {
|
||||
|
||||
final NettyServerBuilder delegate;
|
||||
private boolean enableUntrustedAlts;
|
||||
|
||||
private AltsServerBuilder(NettyServerBuilder nettyDelegate) {
|
||||
this.delegate = nettyDelegate;
|
||||
}
|
||||
|
||||
/** Creates a gRPC server builder for the given port. */
|
||||
public static AltsServerBuilder forPort(int port) {
|
||||
NettyServerBuilder nettyDelegate =
|
||||
NettyServerBuilder.forAddress(new InetSocketAddress(port))
|
||||
.maxConnectionIdle(1, TimeUnit.HOURS)
|
||||
.keepAliveTime(270, TimeUnit.SECONDS)
|
||||
.keepAliveTimeout(20, TimeUnit.SECONDS)
|
||||
.permitKeepAliveTime(10, TimeUnit.SECONDS)
|
||||
.permitKeepAliveWithoutCalls(true);
|
||||
return new AltsServerBuilder(nettyDelegate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Enables untrusted ALTS for testing. If this function is called, we will not check whether ALTS
|
||||
* is running on Google Cloud Platform.
|
||||
*/
|
||||
public AltsServerBuilder enableUntrustedAltsForTesting() {
|
||||
enableUntrustedAlts = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
/** Sets a new handshaker service address for testing. */
|
||||
public AltsServerBuilder setHandshakerAddressForTesting(String handshakerAddress) {
|
||||
HandshakerServiceChannel.setHandshakerAddressForTesting(handshakerAddress);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder handshakeTimeout(long timeout, TimeUnit unit) {
|
||||
delegate.handshakeTimeout(timeout, unit);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder directExecutor() {
|
||||
delegate.directExecutor();
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder addStreamTracerFactory(Factory factory) {
|
||||
delegate.addStreamTracerFactory(factory);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder addTransportFilter(ServerTransportFilter filter) {
|
||||
delegate.addTransportFilter(filter);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder executor(Executor executor) {
|
||||
delegate.executor(executor);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder addService(ServerServiceDefinition service) {
|
||||
delegate.addService(service);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder addService(BindableService bindableService) {
|
||||
delegate.addService(bindableService);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder fallbackHandlerRegistry(HandlerRegistry fallbackRegistry) {
|
||||
delegate.fallbackHandlerRegistry(fallbackRegistry);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder useTransportSecurity(File certChain, File privateKey) {
|
||||
throw new UnsupportedOperationException("Can't set TLS settings for ALTS");
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder decompressorRegistry(DecompressorRegistry registry) {
|
||||
delegate.decompressorRegistry(registry);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder compressorRegistry(CompressorRegistry registry) {
|
||||
delegate.compressorRegistry(registry);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public AltsServerBuilder intercept(ServerInterceptor interceptor) {
|
||||
delegate.intercept(interceptor);
|
||||
return this;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public Server build() {
|
||||
CheckGcpEnvironment.check(enableUntrustedAlts);
|
||||
delegate.protocolNegotiator(
|
||||
AltsProtocolNegotiator.create(
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker() {
|
||||
// Used the shared grpc channel to connecting to the ALTS handshaker service.
|
||||
return AltsTsiHandshaker.newServer(
|
||||
HandshakerServiceGrpc.newStub(HandshakerServiceChannel.get()),
|
||||
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
|
||||
}
|
||||
}));
|
||||
return new AltsServer(delegate.build());
|
||||
}
|
||||
|
||||
static final class AltsServer extends io.grpc.Server {
|
||||
private final Server delegate;
|
||||
|
||||
AltsServer(Server delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ServerServiceDefinition> getImmutableServices() {
|
||||
return delegate.getImmutableServices();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ServerServiceDefinition> getMutableServices() {
|
||||
return delegate.getMutableServices();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getPort() {
|
||||
return delegate.getPort();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ServerServiceDefinition> getServices() {
|
||||
return delegate.getServices();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Server start() throws IOException {
|
||||
delegate.start();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Server shutdown() {
|
||||
delegate.shutdown();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Server shutdownNow() {
|
||||
delegate.shutdownNow();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isShutdown() {
|
||||
return delegate.isShutdown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTerminated() {
|
||||
return delegate.isTerminated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
|
||||
return delegate.awaitTermination(timeout, unit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void awaitTermination() throws InterruptedException {
|
||||
delegate.awaitTermination();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
/** Unwraps {@link ByteBuf}s into {@link ByteBuffer}s. */
|
||||
final class BufUnwrapper implements AutoCloseable {
|
||||
|
||||
private final ByteBuffer[] singleReadBuffer = new ByteBuffer[1];
|
||||
private final ByteBuffer[] singleWriteBuffer = new ByteBuffer[1];
|
||||
|
||||
/**
|
||||
* Called to get access to the underlying NIO buffers for a {@link ByteBuf} that will be used for
|
||||
* writing.
|
||||
*/
|
||||
ByteBuffer[] writableNioBuffers(ByteBuf buf) {
|
||||
// Set the writer index to the capacity to guarantee that the returned NIO buffers will have
|
||||
// the capacity available.
|
||||
int readerIndex = buf.readerIndex();
|
||||
int writerIndex = buf.writerIndex();
|
||||
buf.readerIndex(writerIndex);
|
||||
buf.writerIndex(buf.capacity());
|
||||
|
||||
try {
|
||||
return nioBuffers(buf, singleWriteBuffer);
|
||||
} finally {
|
||||
// Restore the writer index before returning.
|
||||
buf.readerIndex(readerIndex);
|
||||
buf.writerIndex(writerIndex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Called to get access to the underlying NIO buffers for a {@link ByteBuf} that will be used for
|
||||
* reading.
|
||||
*/
|
||||
ByteBuffer[] readableNioBuffers(ByteBuf buf) {
|
||||
return nioBuffers(buf, singleReadBuffer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
singleReadBuffer[0] = null;
|
||||
singleWriteBuffer[0] = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimized accessor for obtaining the underlying NIO buffers for a Netty {@link ByteBuf}. Based
|
||||
* on code from Netty's {@code SslHandler}. This method returns NIO buffers that span the readable
|
||||
* region of the {@link ByteBuf}.
|
||||
*/
|
||||
private static ByteBuffer[] nioBuffers(ByteBuf buf, ByteBuffer[] singleBuffer) {
|
||||
// As CompositeByteBuf.nioBufferCount() can be expensive (as it needs to check all composed
|
||||
// ByteBuf to calculate the count) we will just assume a CompositeByteBuf contains more than 1
|
||||
// ByteBuf. The worst that can happen is that we allocate an extra ByteBuffer[] in
|
||||
// CompositeByteBuf.nioBuffers() which is better than walking the composed ByteBuf in most
|
||||
// cases.
|
||||
if (!(buf instanceof CompositeByteBuf) && buf.nioBufferCount() == 1) {
|
||||
// We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object
|
||||
// allocation to a minimum.
|
||||
singleBuffer[0] = buf.internalNioBuffer(buf.readerIndex(), buf.readableBytes());
|
||||
return singleBuffer;
|
||||
}
|
||||
|
||||
return buf.nioBuffers();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
import org.apache.commons.lang3.SystemUtils;
|
||||
|
||||
/** Class for checking if the system is running on Google Cloud Platform (GCP). */
|
||||
final class CheckGcpEnvironment {
|
||||
private static final Logger logger = Logger.getLogger(CheckGcpEnvironment.class.getName());
|
||||
private static final String DMI_PRODUCT_NAME = "/sys/class/dmi/id/product_name";
|
||||
private static final String WINDOWS_COMMAND = "powershell.exe";
|
||||
private static Boolean cachedResult = null;
|
||||
|
||||
// Construct me not!
|
||||
private CheckGcpEnvironment() {}
|
||||
|
||||
public static void check(boolean enableUntrustedAlts) {
|
||||
if (enableUntrustedAlts) {
|
||||
logger.log(
|
||||
Level.WARNING,
|
||||
"Untrusted ALTS mode is enabled and we cannot guarantee the trustworthiness of the ALTS "
|
||||
+ "handshaker service.");
|
||||
} else if (!isOnGcp()) {
|
||||
throw new RuntimeException("ALTS is only allowed to run on Google Cloud Platform.");
|
||||
}
|
||||
}
|
||||
|
||||
private static synchronized boolean isOnGcp() {
|
||||
if (cachedResult == null) {
|
||||
cachedResult = isRunningOnGcp();
|
||||
}
|
||||
return cachedResult;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static boolean checkProductNameOnLinux(BufferedReader reader) throws IOException {
|
||||
String name = reader.readLine().trim();
|
||||
return name.equals("Google") || name.equals("Google Compute Engine");
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static boolean checkBiosDataOnWindows(BufferedReader reader) throws IOException {
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (line.startsWith("Manufacturer")) {
|
||||
String name = line.substring(line.indexOf(':') + 1).trim();
|
||||
return name.equals("Google");
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static boolean isRunningOnGcp() {
|
||||
try {
|
||||
if (SystemUtils.IS_OS_LINUX) {
|
||||
// Checks GCE residency on Linux platform.
|
||||
return checkProductNameOnLinux(Files.newBufferedReader(Paths.get(DMI_PRODUCT_NAME), UTF_8));
|
||||
} else if (SystemUtils.IS_OS_WINDOWS) {
|
||||
// Checks GCE residency on Windows platform.
|
||||
Process p =
|
||||
new ProcessBuilder()
|
||||
.command(WINDOWS_COMMAND, "Get-WmiObject", "-Class", "Win32_BIOS")
|
||||
.start();
|
||||
return checkBiosDataOnWindows(
|
||||
new BufferedReader(new InputStreamReader(p.getInputStream(), UTF_8)));
|
||||
}
|
||||
} catch (IOException e) {
|
||||
logger.log(Level.WARNING, "Fail to read platform information: ", e);
|
||||
return false;
|
||||
}
|
||||
// Platforms other than Linux and Windows are not supported.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.netty.NettyChannelBuilder;
|
||||
import io.netty.channel.nio.NioEventLoopGroup;
|
||||
import io.netty.util.concurrent.DefaultThreadFactory;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
|
||||
/**
|
||||
* Class for creating a single shared grpc channel to the ALTS Handshaker Service. The channel to
|
||||
* the handshaker service is local and is over plaintext. Each application will have at most one
|
||||
* connection to the handshaker service.
|
||||
*
|
||||
* <p>TODO: Release the channel if it is not used.
|
||||
*/
|
||||
final class HandshakerServiceChannel {
|
||||
// Default handshaker service address.
|
||||
private static String handshakerAddress = "metadata.google.internal:8080";
|
||||
// Shared channel to ALTS handshaker service.
|
||||
private static ManagedChannel channel = null;
|
||||
|
||||
// Construct me not!
|
||||
private HandshakerServiceChannel() {}
|
||||
|
||||
// Sets handshaker service address for testing and creates the channel to the handshaker service.
|
||||
public static synchronized void setHandshakerAddressForTesting(String handshakerAddress) {
|
||||
Preconditions.checkState(
|
||||
channel == null || HandshakerServiceChannel.handshakerAddress.equals(handshakerAddress),
|
||||
"HandshakerServiceChannel already created with a different handshakerAddress");
|
||||
HandshakerServiceChannel.handshakerAddress = handshakerAddress;
|
||||
if (channel == null) {
|
||||
channel = createChannel();
|
||||
}
|
||||
}
|
||||
|
||||
/** Create a new channel to ALTS handshaker service, if it has not been created yet. */
|
||||
private static ManagedChannel createChannel() {
|
||||
/* Use its own event loop thread pool to avoid blocking. */
|
||||
ThreadFactory clientThreadFactory = new DefaultThreadFactory("handshaker pool", true);
|
||||
ManagedChannel channel =
|
||||
NettyChannelBuilder.forTarget(handshakerAddress)
|
||||
.directExecutor()
|
||||
.eventLoopGroup(new NioEventLoopGroup(1, clientThreadFactory))
|
||||
.usePlaintext(true)
|
||||
.build();
|
||||
return channel;
|
||||
}
|
||||
|
||||
public static synchronized ManagedChannel get() {
|
||||
if (channel == null) {
|
||||
channel = createChannel();
|
||||
}
|
||||
return channel;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import io.grpc.Internal;
|
||||
import io.grpc.alts.transportsecurity.TsiFrameProtector;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiPeer;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
|
||||
/** A wrapper for a {@link TsiHandshaker} that accepts netty {@link ByteBuf}s. */
|
||||
@Internal
|
||||
public final class InternalNettyTsiHandshaker {
|
||||
|
||||
private BufUnwrapper unwrapper = new BufUnwrapper();
|
||||
private final TsiHandshaker internalHandshaker;
|
||||
|
||||
public InternalNettyTsiHandshaker(TsiHandshaker handshaker) {
|
||||
internalHandshaker = checkNotNull(handshaker);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets data that is ready to be sent to the to the remote peer. This should be called in a loop
|
||||
* until no bytes are written to the output buffer.
|
||||
*
|
||||
* @param out the buffer to receive the bytes.
|
||||
*/
|
||||
void getBytesToSendToPeer(ByteBuf out) throws GeneralSecurityException {
|
||||
checkState(unwrapper != null, "protector already created");
|
||||
try (BufUnwrapper unwrapper = this.unwrapper) {
|
||||
// Write as many bytes as possible into the buffer.
|
||||
int bytesWritten = 0;
|
||||
for (ByteBuffer nioBuffer : unwrapper.writableNioBuffers(out)) {
|
||||
if (!nioBuffer.hasRemaining()) {
|
||||
// This buffer doesn't have any more space to write, go to the next buffer.
|
||||
continue;
|
||||
}
|
||||
|
||||
int prevPos = nioBuffer.position();
|
||||
internalHandshaker.getBytesToSendToPeer(nioBuffer);
|
||||
bytesWritten += nioBuffer.position() - prevPos;
|
||||
|
||||
// If the buffer position was not changed, the frame has been completely read into the
|
||||
// buffers.
|
||||
if (nioBuffer.position() == prevPos) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
out.writerIndex(out.writerIndex() + bytesWritten);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process handshake data received from the remote peer.
|
||||
*
|
||||
* @return {@code true}, if the handshake has all the data it needs to process and {@code false},
|
||||
* if the method must be called again to complete processing.
|
||||
*/
|
||||
boolean processBytesFromPeer(ByteBuf data) throws GeneralSecurityException {
|
||||
checkState(unwrapper != null, "protector already created");
|
||||
try (BufUnwrapper unwrapper = this.unwrapper) {
|
||||
int bytesRead = 0;
|
||||
boolean done = false;
|
||||
for (ByteBuffer nioBuffer : unwrapper.readableNioBuffers(data)) {
|
||||
if (!nioBuffer.hasRemaining()) {
|
||||
// This buffer has been fully read, continue to the next buffer.
|
||||
continue;
|
||||
}
|
||||
|
||||
int prevPos = nioBuffer.position();
|
||||
done = internalHandshaker.processBytesFromPeer(nioBuffer);
|
||||
bytesRead += nioBuffer.position() - prevPos;
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data.readerIndex(data.readerIndex() + bytesRead);
|
||||
return done;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if and only if the handshake is still in progress
|
||||
*
|
||||
* @return true, if the handshake is still in progress, false otherwise.
|
||||
*/
|
||||
boolean isInProgress() {
|
||||
return internalHandshaker.isInProgress();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the peer extracted from a completed handshake.
|
||||
*
|
||||
* @return the extracted peer.
|
||||
*/
|
||||
TsiPeer extractPeer() throws GeneralSecurityException {
|
||||
checkState(!internalHandshaker.isInProgress());
|
||||
return internalHandshaker.extractPeer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the peer extracted from a completed handshake.
|
||||
*
|
||||
* @return the extracted peer.
|
||||
*/
|
||||
Object extractPeerObject() throws GeneralSecurityException {
|
||||
checkState(!internalHandshaker.isInProgress());
|
||||
return internalHandshaker.extractPeerObject();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a frame protector from a completed handshake. No other methods may be called after the
|
||||
* frame protector is created.
|
||||
*
|
||||
* @param maxFrameSize the requested max frame size, the callee is free to ignore.
|
||||
* @return a new {@link TsiFrameProtector}.
|
||||
*/
|
||||
TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
|
||||
unwrapper = null;
|
||||
return internalHandshaker.createFrameProtector(maxFrameSize, alloc);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a frame protector from a completed handshake. No other methods may be called after the
|
||||
* frame protector is created.
|
||||
*
|
||||
* @return a new {@link TsiFrameProtector}.
|
||||
*/
|
||||
TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
|
||||
unwrapper = null;
|
||||
return internalHandshaker.createFrameProtector(alloc);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.Internal;
|
||||
import io.grpc.alts.InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent;
|
||||
import io.grpc.alts.transportsecurity.TsiFrameProtector;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.channel.ChannelException;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelOutboundHandler;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.PendingWriteQueue;
|
||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||
import java.net.SocketAddress;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
/**
|
||||
* Encrypts and decrypts TSI Frames. Writes are buffered here until {@link #flush} is called. Writes
|
||||
* must not be made before the TSI handshake is complete.
|
||||
*/
|
||||
@Internal
|
||||
public final class InternalTsiFrameHandler extends ByteToMessageDecoder
|
||||
implements ChannelOutboundHandler {
|
||||
|
||||
private TsiFrameProtector protector;
|
||||
private PendingWriteQueue pendingUnprotectedWrites;
|
||||
|
||||
public InternalTsiFrameHandler() {}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
|
||||
super.handlerAdded(ctx);
|
||||
assert pendingUnprotectedWrites == null;
|
||||
pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception {
|
||||
if (event instanceof TsiHandshakeCompletionEvent) {
|
||||
TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event;
|
||||
if (tsiEvent.isSuccess()) {
|
||||
setProtector(tsiEvent.protector());
|
||||
}
|
||||
// Ignore errors. Another handler in the pipeline must handle TSI Errors.
|
||||
}
|
||||
// Keep propagating the message, as others may want to read it.
|
||||
super.userEventTriggered(ctx, event);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void setProtector(TsiFrameProtector protector) {
|
||||
checkState(this.protector == null);
|
||||
this.protector = checkNotNull(protector);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
|
||||
checkState(protector != null, "Cannot read frames while the TSI handshake is in progress");
|
||||
protector.unprotect(in, out, ctx.alloc());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)
|
||||
throws Exception {
|
||||
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
||||
ByteBuf msg = (ByteBuf) message;
|
||||
if (!msg.isReadable()) {
|
||||
// Nothing to encode.
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError = promise.setSuccess();
|
||||
return;
|
||||
}
|
||||
|
||||
// Just add the message to the pending queue. We'll write it on the next flush.
|
||||
pendingUnprotectedWrites.add(msg, promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||
if (!pendingUnprotectedWrites.isEmpty()) {
|
||||
pendingUnprotectedWrites.removeAndFailAll(
|
||||
new ChannelException("Pending write on removal of TSI handler"));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
pendingUnprotectedWrites.removeAndFailAll(cause);
|
||||
super.exceptionCaught(ctx, cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
|
||||
ctx.bind(localAddress, promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void connect(
|
||||
ChannelHandlerContext ctx,
|
||||
SocketAddress remoteAddress,
|
||||
SocketAddress localAddress,
|
||||
ChannelPromise promise) {
|
||||
ctx.connect(remoteAddress, localAddress, promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
ctx.disconnect(promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
ctx.close(promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
|
||||
ctx.deregister(promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void read(ChannelHandlerContext ctx) {
|
||||
ctx.read();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flush(ChannelHandlerContext ctx) throws GeneralSecurityException {
|
||||
checkState(protector != null, "Cannot write frames while the TSI handshake is in progress");
|
||||
ProtectedPromise aggregatePromise =
|
||||
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
|
||||
|
||||
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
|
||||
|
||||
if (pendingUnprotectedWrites.isEmpty()) {
|
||||
// Return early if there's nothing to write. Otherwise protector.protectFlush() below may
|
||||
// not check for "no-data" and go on writing the 0-byte "data" to the socket with the
|
||||
// protection framing.
|
||||
return;
|
||||
}
|
||||
// Drain the unprotected writes.
|
||||
while (!pendingUnprotectedWrites.isEmpty()) {
|
||||
ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current();
|
||||
bufs.add(in.retain());
|
||||
// Remove and release the buffer and add its promise to the aggregate.
|
||||
aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove());
|
||||
}
|
||||
|
||||
protector.protectFlush(
|
||||
bufs, b -> ctx.writeAndFlush(b, aggregatePromise.newPromise()), ctx.alloc());
|
||||
|
||||
// We're done writing, start the flow of promise events.
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError = aggregatePromise.doneAllocatingPromises();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkNotNull;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.Internal;
|
||||
import io.grpc.alts.transportsecurity.TsiFrameProtector;
|
||||
import io.grpc.alts.transportsecurity.TsiPeer;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Future;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/**
|
||||
* Performs The TSI Handshake. When the handshake is complete, it fires a user event with a {@link
|
||||
* TsiHandshakeCompletionEvent} indicating the result of the handshake.
|
||||
*/
|
||||
@Internal
|
||||
public final class InternalTsiHandshakeHandler extends ByteToMessageDecoder {
|
||||
private static final int HANDSHAKE_FRAME_SIZE = 1024;
|
||||
|
||||
private final InternalNettyTsiHandshaker handshaker;
|
||||
private boolean started;
|
||||
|
||||
/**
|
||||
* This buffer doesn't store any state. We just hold onto it in case we end up allocating a buffer
|
||||
* that ends up being unused.
|
||||
*/
|
||||
private ByteBuf buffer;
|
||||
|
||||
public InternalTsiHandshakeHandler(InternalNettyTsiHandshaker handshaker) {
|
||||
this.handshaker = checkNotNull(handshaker);
|
||||
}
|
||||
|
||||
/**
|
||||
* Event that is fired once the TSI handshake is complete, which may be because it was successful
|
||||
* or there was an error.
|
||||
*/
|
||||
public static final class TsiHandshakeCompletionEvent {
|
||||
|
||||
private final Throwable cause;
|
||||
private final TsiPeer peer;
|
||||
private final Object context;
|
||||
private final TsiFrameProtector protector;
|
||||
|
||||
/** Creates a new event that indicates a successful handshake. */
|
||||
@VisibleForTesting
|
||||
TsiHandshakeCompletionEvent(
|
||||
TsiFrameProtector protector, TsiPeer peer, @Nullable Object peerObject) {
|
||||
this.cause = null;
|
||||
this.peer = checkNotNull(peer);
|
||||
this.protector = checkNotNull(protector);
|
||||
this.context = peerObject;
|
||||
}
|
||||
|
||||
/** Creates a new event that indicates an unsuccessful handshake/. */
|
||||
TsiHandshakeCompletionEvent(Throwable cause) {
|
||||
this.cause = checkNotNull(cause);
|
||||
this.peer = null;
|
||||
this.protector = null;
|
||||
this.context = null;
|
||||
}
|
||||
|
||||
/** Return {@code true} if the handshake was successful. */
|
||||
public boolean isSuccess() {
|
||||
return cause == null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the {@link Throwable} if {@link #isSuccess()} returns {@code false} and so the
|
||||
* handshake failed.
|
||||
*/
|
||||
@Nullable
|
||||
public Throwable cause() {
|
||||
return cause;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public TsiPeer peer() {
|
||||
return peer;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public Object context() {
|
||||
return context;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
TsiFrameProtector protector() {
|
||||
return protector;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
|
||||
maybeStart(ctx);
|
||||
super.handlerAdded(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx) throws Exception {
|
||||
maybeStart(ctx);
|
||||
super.channelActive(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||
close();
|
||||
super.handlerRemoved0(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
ctx.fireUserEventTriggered(new TsiHandshakeCompletionEvent(cause));
|
||||
super.exceptionCaught(ctx, cause);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
|
||||
throws Exception {
|
||||
// TODO: Not sure why override is needed. Investigate if it can be removed.
|
||||
decode(ctx, in, out);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
|
||||
// Process the data. If we need to send more data, do so now.
|
||||
if (handshaker.processBytesFromPeer(in) && handshaker.isInProgress()) {
|
||||
sendHandshake(ctx);
|
||||
}
|
||||
|
||||
// If the handshake is complete, transition to the framing state.
|
||||
if (!handshaker.isInProgress()) {
|
||||
try {
|
||||
ctx.pipeline().remove(this);
|
||||
ctx.fireUserEventTriggered(
|
||||
new TsiHandshakeCompletionEvent(
|
||||
handshaker.createFrameProtector(ctx.alloc()),
|
||||
handshaker.extractPeer(),
|
||||
handshaker.extractPeerObject()));
|
||||
// No need to do anything with the in buffer, it will be re added to the pipeline when this
|
||||
// handler is removed.
|
||||
} finally {
|
||||
close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void maybeStart(ChannelHandlerContext ctx) {
|
||||
if (!started && ctx.channel().isActive()) {
|
||||
started = true;
|
||||
sendHandshake(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
/** Sends as many bytes as are available from the handshaker to the remote peer. */
|
||||
private void sendHandshake(ChannelHandlerContext ctx) {
|
||||
boolean needToFlush = false;
|
||||
|
||||
// Iterate until there is nothing left to write.
|
||||
while (true) {
|
||||
buffer = getOrCreateBuffer(ctx.alloc());
|
||||
try {
|
||||
handshaker.getBytesToSendToPeer(buffer);
|
||||
} catch (GeneralSecurityException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
if (!buffer.isReadable()) {
|
||||
break;
|
||||
}
|
||||
|
||||
needToFlush = true;
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError = ctx.write(buffer);
|
||||
buffer = null;
|
||||
}
|
||||
|
||||
// If something was written, flush.
|
||||
if (needToFlush) {
|
||||
ctx.flush();
|
||||
}
|
||||
}
|
||||
|
||||
private ByteBuf getOrCreateBuffer(ByteBufAllocator alloc) {
|
||||
if (buffer == null) {
|
||||
buffer = alloc.buffer(HANDSHAKE_FRAME_SIZE);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
private void close() {
|
||||
ReferenceCountUtil.safeRelease(buffer);
|
||||
buffer = null;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.DefaultChannelPromise;
|
||||
import io.netty.util.concurrent.EventExecutor;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Promise used when flushing the {@code pendingUnprotectedWrites} queue. It manages the many-to
|
||||
* many relationship between pending unprotected messages and the individual writes. Each protected
|
||||
* frame will be written using the same instance of this promise and it will accumulate the results.
|
||||
* Once all frames have been successfully written (or any failed), all of the promises for the
|
||||
* pending unprotected writes are notified.
|
||||
*
|
||||
* <p>NOTE: this code is based on code in Netty's {@code Http2CodecUtil}.
|
||||
*/
|
||||
final class ProtectedPromise extends DefaultChannelPromise {
|
||||
private final List<ChannelPromise> unprotectedPromises;
|
||||
private int expectedCount;
|
||||
private int successfulCount;
|
||||
private int failureCount;
|
||||
private boolean doneAllocating;
|
||||
|
||||
ProtectedPromise(Channel channel, EventExecutor executor, int numUnprotectedPromises) {
|
||||
super(channel, executor);
|
||||
unprotectedPromises = new ArrayList<>(numUnprotectedPromises);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a promise for a pending unprotected write. This will be notified after all of the writes
|
||||
* complete.
|
||||
*/
|
||||
void addUnprotectedPromise(ChannelPromise promise) {
|
||||
unprotectedPromises.add(promise);
|
||||
}
|
||||
|
||||
/**
|
||||
* Allocate a new promise for the write of a protected frame. This will be used to aggregate the
|
||||
* overall success of the unprotected promises.
|
||||
*
|
||||
* @return {@code this} promise.
|
||||
*/
|
||||
ChannelPromise newPromise() {
|
||||
checkState(!doneAllocating, "Done allocating. No more promises can be allocated.");
|
||||
expectedCount++;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Signify that no more {@link #newPromise()} allocations will be made. The aggregation can not be
|
||||
* successful until this method is called.
|
||||
*
|
||||
* @return {@code this} promise.
|
||||
*/
|
||||
ChannelPromise doneAllocatingPromises() {
|
||||
if (!doneAllocating) {
|
||||
doneAllocating = true;
|
||||
if (successfulCount == expectedCount) {
|
||||
trySuccessInternal(null);
|
||||
return super.setSuccess(null);
|
||||
}
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean tryFailure(Throwable cause) {
|
||||
if (awaitingPromises()) {
|
||||
++failureCount;
|
||||
if (failureCount == 1) {
|
||||
tryFailureInternal(cause);
|
||||
return super.tryFailure(cause);
|
||||
}
|
||||
// TODO: We break the interface a bit here.
|
||||
// Multiple failure events can be processed without issue because this is an aggregation.
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fail this object if it has not already been failed.
|
||||
*
|
||||
* <p>This method will NOT throw an {@link IllegalStateException} if called multiple times because
|
||||
* that may be expected.
|
||||
*/
|
||||
@Override
|
||||
public ChannelPromise setFailure(Throwable cause) {
|
||||
tryFailure(cause);
|
||||
return this;
|
||||
}
|
||||
|
||||
private boolean awaitingPromises() {
|
||||
return successfulCount + failureCount < expectedCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChannelPromise setSuccess(Void result) {
|
||||
trySuccess(result);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean trySuccess(Void result) {
|
||||
if (awaitingPromises()) {
|
||||
++successfulCount;
|
||||
if (successfulCount == expectedCount && doneAllocating) {
|
||||
trySuccessInternal(result);
|
||||
return super.trySuccess(result);
|
||||
}
|
||||
// TODO: We break the interface a bit here.
|
||||
// Multiple success events can be processed without issue because this is an aggregation.
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private void trySuccessInternal(Void result) {
|
||||
for (int i = 0; i < unprotectedPromises.size(); ++i) {
|
||||
unprotectedPromises.get(i).trySuccess(result);
|
||||
}
|
||||
}
|
||||
|
||||
private void tryFailureInternal(Throwable cause) {
|
||||
for (int i = 0; i < unprotectedPromises.size(); ++i) {
|
||||
unprotectedPromises.get(i).tryFailure(cause);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions.Version;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** Utility class for Rpc Protocol Versions. */
|
||||
final class RpcProtocolVersionsUtil {
|
||||
|
||||
private static final int MAX_RPC_VERSION_MAJOR = 2;
|
||||
private static final int MAX_RPC_VERSION_MINOR = 1;
|
||||
private static final int MIN_RPC_VERSION_MAJOR = 2;
|
||||
private static final int MIN_RPC_VERSION_MINOR = 1;
|
||||
private static final RpcProtocolVersions RPC_PROTOCOL_VERSIONS =
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder()
|
||||
.setMajor(MAX_RPC_VERSION_MAJOR)
|
||||
.setMinor(MAX_RPC_VERSION_MINOR)
|
||||
.build())
|
||||
.setMinRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder()
|
||||
.setMajor(MIN_RPC_VERSION_MAJOR)
|
||||
.setMinor(MIN_RPC_VERSION_MINOR)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
/** Returns default Rpc Protocol Versions. */
|
||||
static RpcProtocolVersions getRpcProtocolVersions() {
|
||||
return RPC_PROTOCOL_VERSIONS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if first Rpc Protocol Version is greater than or equal to the second one. Returns
|
||||
* false otherwise.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
static boolean isGreaterThanOrEqualTo(Version first, Version second) {
|
||||
if ((first.getMajor() > second.getMajor())
|
||||
|| (first.getMajor() == second.getMajor() && first.getMinor() >= second.getMinor())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs check between local and peer Rpc Protocol Versions. This function returns true and the
|
||||
* highest common version if there exists a common Rpc Protocol Version to use, and returns false
|
||||
* and null otherwise.
|
||||
*/
|
||||
static RpcVersionsCheckResult checkRpcProtocolVersions(
|
||||
RpcProtocolVersions localVersions, RpcProtocolVersions peerVersions) {
|
||||
Version maxCommonVersion;
|
||||
Version minCommonVersion;
|
||||
// maxCommonVersion is MIN(local.max, peer.max)
|
||||
if (isGreaterThanOrEqualTo(localVersions.getMaxRpcVersion(), peerVersions.getMaxRpcVersion())) {
|
||||
maxCommonVersion = peerVersions.getMaxRpcVersion();
|
||||
} else {
|
||||
maxCommonVersion = localVersions.getMaxRpcVersion();
|
||||
}
|
||||
// minCommonVersion is MAX(local.min, peer.min)
|
||||
if (isGreaterThanOrEqualTo(localVersions.getMinRpcVersion(), peerVersions.getMinRpcVersion())) {
|
||||
minCommonVersion = localVersions.getMinRpcVersion();
|
||||
} else {
|
||||
minCommonVersion = peerVersions.getMinRpcVersion();
|
||||
}
|
||||
if (isGreaterThanOrEqualTo(maxCommonVersion, minCommonVersion)) {
|
||||
return new RpcVersionsCheckResult.Builder()
|
||||
.setResult(true)
|
||||
.setHighestCommonVersion(maxCommonVersion)
|
||||
.build();
|
||||
}
|
||||
return new RpcVersionsCheckResult.Builder().setResult(false).build();
|
||||
}
|
||||
|
||||
/** Wrapper class that stores results of Rpc Protocol Versions check. */
|
||||
static final class RpcVersionsCheckResult {
|
||||
private final boolean result;
|
||||
@Nullable private final Version highestCommonVersion;
|
||||
|
||||
private RpcVersionsCheckResult(Builder builder) {
|
||||
result = builder.result;
|
||||
highestCommonVersion = builder.highestCommonVersion;
|
||||
}
|
||||
|
||||
boolean getResult() {
|
||||
return result;
|
||||
}
|
||||
|
||||
Version getHighestCommonVersion() {
|
||||
return highestCommonVersion;
|
||||
}
|
||||
|
||||
static final class Builder {
|
||||
private boolean result;
|
||||
@Nullable private Version highestCommonVersion = null;
|
||||
|
||||
public Builder setResult(boolean result) {
|
||||
this.result = result;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setHighestCommonVersion(Version highestCommonVersion) {
|
||||
this.highestCommonVersion = highestCommonVersion;
|
||||
return this;
|
||||
}
|
||||
|
||||
public RpcVersionsCheckResult build() {
|
||||
return new RpcVersionsCheckResult(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* {@code AeadCrypter} performs authenticated encryption and decryption for a fixed key given unique
|
||||
* nonces. Authenticated additional data is supported.
|
||||
*/
|
||||
interface AeadCrypter {
|
||||
/**
|
||||
* Encrypt plaintext into ciphertext buffer using the given nonce.
|
||||
*
|
||||
* @param ciphertext the encrypted plaintext and the tag will be written into this buffer.
|
||||
* @param plaintext the input that should be encrypted.
|
||||
* @param nonce the unique nonce used for the encryption.
|
||||
* @throws GeneralSecurityException if ciphertext buffer is short or the nonce does not have the
|
||||
* expected size.
|
||||
*/
|
||||
void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, byte[] nonce)
|
||||
throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Encrypt plaintext into ciphertext buffer using the given nonce with authenticated data.
|
||||
*
|
||||
* @param ciphertext the encrypted plaintext and the tag will be written into this buffer.
|
||||
* @param plaintext the input that should be encrypted.
|
||||
* @param aad additional data that should be authenticated, but not encrypted.
|
||||
* @param nonce the unique nonce used for the encryption.
|
||||
* @throws GeneralSecurityException if ciphertext buffer is short or the nonce does not have the
|
||||
* expected size.
|
||||
*/
|
||||
void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Decrypt ciphertext into plaintext buffer using the given nonce.
|
||||
*
|
||||
* @param plaintext the decrypted plaintext will be written into this buffer.
|
||||
* @param ciphertext the ciphertext and tag that should be decrypted.
|
||||
* @param nonce the nonce that was used for the encryption.
|
||||
* @throws GeneralSecurityException if the tag is invalid or any of the inputs do not have the
|
||||
* expected size.
|
||||
*/
|
||||
void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, byte[] nonce)
|
||||
throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Decrypt ciphertext into plaintext buffer using the given nonce.
|
||||
*
|
||||
* @param plaintext the decrypted plaintext will be written into this buffer.
|
||||
* @param ciphertext the ciphertext and tag that should be decrypted.
|
||||
* @param aad additional data that is checked for authenticity.
|
||||
* @param nonce the nonce that was used for the encryption.
|
||||
* @throws GeneralSecurityException if the tag is invalid or any of the inputs do not have the
|
||||
* expected size.
|
||||
*/
|
||||
void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException;
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.crypto.Cipher;
|
||||
import javax.crypto.spec.GCMParameterSpec;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
|
||||
/** AES128-GCM implementation of {@link AeadCrypter} that uses default JCE provider. */
|
||||
final class AesGcmAeadCrypter implements AeadCrypter {
|
||||
private static final int KEY_LENGTH = 16;
|
||||
private static final int TAG_LENGTH = 16;
|
||||
static final int NONCE_LENGTH = 12;
|
||||
|
||||
private static final String AES = "AES";
|
||||
private static final String AES_GCM = AES + "/GCM/NoPadding";
|
||||
|
||||
private final byte[] key;
|
||||
private final Cipher cipher;
|
||||
|
||||
AesGcmAeadCrypter(byte[] key) throws GeneralSecurityException {
|
||||
checkArgument(key.length == KEY_LENGTH);
|
||||
this.key = key;
|
||||
cipher = Cipher.getInstance(AES_GCM);
|
||||
}
|
||||
|
||||
private int encryptAad(
|
||||
ByteBuffer ciphertext, ByteBuffer plaintext, @Nullable ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
checkArgument(nonce.length == NONCE_LENGTH);
|
||||
cipher.init(
|
||||
Cipher.ENCRYPT_MODE,
|
||||
new SecretKeySpec(this.key, AES),
|
||||
new GCMParameterSpec(TAG_LENGTH * 8, nonce));
|
||||
if (aad != null) {
|
||||
cipher.updateAAD(aad);
|
||||
}
|
||||
return cipher.doFinal(plaintext, ciphertext);
|
||||
}
|
||||
|
||||
private void decryptAad(
|
||||
ByteBuffer plaintext, ByteBuffer ciphertext, @Nullable ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
checkArgument(nonce.length == NONCE_LENGTH);
|
||||
cipher.init(
|
||||
Cipher.DECRYPT_MODE,
|
||||
new SecretKeySpec(this.key, AES),
|
||||
new GCMParameterSpec(TAG_LENGTH * 8, nonce));
|
||||
if (aad != null) {
|
||||
cipher.updateAAD(aad);
|
||||
}
|
||||
cipher.doFinal(ciphertext, plaintext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
encryptAad(ciphertext, plaintext, null, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
encryptAad(ciphertext, plaintext, aad, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
decryptAad(plaintext, ciphertext, null, nonce);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
decryptAad(plaintext, ciphertext, aad, nonce);
|
||||
}
|
||||
|
||||
static int getKeyLength() {
|
||||
return KEY_LENGTH;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.Arrays;
|
||||
import javax.crypto.Mac;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
|
||||
/**
|
||||
* {@link AeadCrypter} implementation based on {@link AesGcmAeadCrypter} with nonce-based rekeying
|
||||
* using HKDF-expand and random nonce-mask that is XORed with the given nonce/counter. The AES-GCM
|
||||
* key is computed as HKDF-expand(kdfKey, nonce[2..7]), i.e., the first 2 bytes are ignored to
|
||||
* require rekeying only after 2^16 operations and the last 4 bytes (including the direction bit)
|
||||
* are ignored to allow for optimizations (use same AEAD context for both directions, store counter
|
||||
* as unsigned long and boolean for direction).
|
||||
*/
|
||||
final class AesGcmHkdfAeadCrypter implements AeadCrypter {
|
||||
private static final int KDF_KEY_LENGTH = 32;
|
||||
// Rekey after 2^(2*8) = 2^16 operations by ignoring the first 2 nonce bytes for key derivation.
|
||||
private static final int KDF_COUNTER_OFFSET = 2;
|
||||
// Use remaining bytes of 64-bit counter included in nonce for key derivation.
|
||||
private static final int KDF_COUNTER_LENGTH = 6;
|
||||
private static final int NONCE_LENGTH = AesGcmAeadCrypter.NONCE_LENGTH;
|
||||
private static final int KEY_LENGTH = KDF_KEY_LENGTH + NONCE_LENGTH;
|
||||
|
||||
private final byte[] kdfKey;
|
||||
private final byte[] kdfCounter = new byte[KDF_COUNTER_LENGTH];
|
||||
private final byte[] nonceMask;
|
||||
private final byte[] nonceBuffer = new byte[NONCE_LENGTH];
|
||||
|
||||
private AeadCrypter aeadCrypter;
|
||||
|
||||
AesGcmHkdfAeadCrypter(byte[] key) {
|
||||
checkArgument(key.length == KEY_LENGTH);
|
||||
this.kdfKey = Arrays.copyOf(key, KDF_KEY_LENGTH);
|
||||
this.nonceMask = Arrays.copyOfRange(key, KDF_KEY_LENGTH, KDF_KEY_LENGTH + NONCE_LENGTH);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
maybeRekey(nonce);
|
||||
maskNonce(nonceBuffer, nonceMask, nonce);
|
||||
aeadCrypter.encrypt(ciphertext, plaintext, nonceBuffer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void encrypt(ByteBuffer ciphertext, ByteBuffer plaintext, ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
maybeRekey(nonce);
|
||||
maskNonce(nonceBuffer, nonceMask, nonce);
|
||||
aeadCrypter.encrypt(ciphertext, plaintext, aad, nonceBuffer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
maybeRekey(nonce);
|
||||
maskNonce(nonceBuffer, nonceMask, nonce);
|
||||
aeadCrypter.decrypt(plaintext, ciphertext, nonceBuffer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decrypt(ByteBuffer plaintext, ByteBuffer ciphertext, ByteBuffer aad, byte[] nonce)
|
||||
throws GeneralSecurityException {
|
||||
maybeRekey(nonce);
|
||||
maskNonce(nonceBuffer, nonceMask, nonce);
|
||||
aeadCrypter.decrypt(plaintext, ciphertext, aad, nonceBuffer);
|
||||
}
|
||||
|
||||
private void maybeRekey(byte[] nonce) throws GeneralSecurityException {
|
||||
if (aeadCrypter != null
|
||||
&& arrayEqualOn(nonce, KDF_COUNTER_OFFSET, kdfCounter, 0, KDF_COUNTER_LENGTH)) {
|
||||
return;
|
||||
}
|
||||
System.arraycopy(nonce, KDF_COUNTER_OFFSET, kdfCounter, 0, KDF_COUNTER_LENGTH);
|
||||
int aeKeyLen = AesGcmAeadCrypter.getKeyLength();
|
||||
byte[] aeKey = Arrays.copyOf(hkdfExpandSha256(kdfKey, kdfCounter), aeKeyLen);
|
||||
aeadCrypter = new AesGcmAeadCrypter(aeKey);
|
||||
}
|
||||
|
||||
private static void maskNonce(byte[] nonceBuffer, byte[] nonceMask, byte[] nonce) {
|
||||
checkArgument(nonce.length == NONCE_LENGTH);
|
||||
for (int i = 0; i < NONCE_LENGTH; i++) {
|
||||
nonceBuffer[i] = (byte) (nonceMask[i] ^ nonce[i]);
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] hkdfExpandSha256(byte[] key, byte[] info) throws GeneralSecurityException {
|
||||
Mac mac = Mac.getInstance("HMACSHA256");
|
||||
mac.init(new SecretKeySpec(key, mac.getAlgorithm()));
|
||||
mac.update(info);
|
||||
mac.update((byte) 0x01);
|
||||
return mac.doFinal();
|
||||
}
|
||||
|
||||
private static boolean arrayEqualOn(byte[] a, int aPos, byte[] b, int bPos, int length) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (a[aPos + i] != b[bPos + i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static int getKeyLength() {
|
||||
return KEY_LENGTH;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.grpc.alts.Altscontext.AltsContext;
|
||||
import io.grpc.alts.Handshaker.HandshakerResult;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import io.grpc.alts.TransportSecurityCommon.SecurityLevel;
|
||||
|
||||
/** AltsAuthContext contains security-related context information about an ALTs connection. */
|
||||
public final class AltsAuthContext {
|
||||
final AltsContext context;
|
||||
|
||||
/** Create a new AltsAuthContext. */
|
||||
public AltsAuthContext(HandshakerResult result) {
|
||||
context =
|
||||
AltsContext.newBuilder()
|
||||
.setApplicationProtocol(result.getApplicationProtocol())
|
||||
.setRecordProtocol(result.getRecordProtocol())
|
||||
// TODO: Set security level based on the handshaker result.
|
||||
.setSecurityLevel(SecurityLevel.INTEGRITY_AND_PRIVACY)
|
||||
.setPeerServiceAccount(result.getPeerIdentity().getServiceAccount())
|
||||
.setLocalServiceAccount(result.getLocalIdentity().getServiceAccount())
|
||||
.setPeerRpcVersions(result.getPeerRpcVersions())
|
||||
.build();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public static AltsAuthContext getDefaultInstance() {
|
||||
return new AltsAuthContext(HandshakerResult.newBuilder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get application protocol.
|
||||
*
|
||||
* @return the context's application protocol.
|
||||
*/
|
||||
public String getApplicationProtocol() {
|
||||
return context.getApplicationProtocol();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get negotiated record protocol.
|
||||
*
|
||||
* @return the context's negotiated record protocol.
|
||||
*/
|
||||
public String getRecordProtocol() {
|
||||
return context.getRecordProtocol();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get security level.
|
||||
*
|
||||
* @return the context's security level.
|
||||
*/
|
||||
public SecurityLevel getSecurityLevel() {
|
||||
return context.getSecurityLevel();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get peer service account.
|
||||
*
|
||||
* @return the context's peer service account.
|
||||
*/
|
||||
public String getPeerServiceAccount() {
|
||||
return context.getPeerServiceAccount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get local service account.
|
||||
*
|
||||
* @return the context's local service account.
|
||||
*/
|
||||
public String getLocalServiceAccount() {
|
||||
return context.getLocalServiceAccount();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get peer RPC versions.
|
||||
*
|
||||
* @return the context's peer RPC versions.
|
||||
*/
|
||||
public RpcProtocolVersions getPeerRpcVersions() {
|
||||
return context.getPeerRpcVersions();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
import static com.google.common.base.Verify.verify;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.List;
|
||||
|
||||
/** Performs encryption and decryption with AES-GCM using JCE. All methods are thread-compatible. */
|
||||
final class AltsChannelCrypter implements ChannelCrypterNetty {
|
||||
private static final int KEY_LENGTH = AesGcmHkdfAeadCrypter.getKeyLength();
|
||||
private static final int COUNTER_LENGTH = 12;
|
||||
// The counter will overflow after 2^64 operations and encryption/decryption will stop working.
|
||||
private static final int COUNTER_OVERFLOW_LENGTH = 8;
|
||||
private static final int TAG_LENGTH = 16;
|
||||
|
||||
private final AeadCrypter aeadCrypter;
|
||||
|
||||
private final byte[] outCounter = new byte[COUNTER_LENGTH];
|
||||
private final byte[] inCounter = new byte[COUNTER_LENGTH];
|
||||
private final byte[] oldCounter = new byte[COUNTER_LENGTH];
|
||||
|
||||
AltsChannelCrypter(byte[] key, boolean isClient) {
|
||||
checkArgument(key.length == KEY_LENGTH);
|
||||
byte[] counter = isClient ? inCounter : outCounter;
|
||||
counter[counter.length - 1] = (byte) 0x80;
|
||||
this.aeadCrypter = new AesGcmHkdfAeadCrypter(key);
|
||||
}
|
||||
|
||||
static int getKeyLength() {
|
||||
return KEY_LENGTH;
|
||||
}
|
||||
|
||||
static int getCounterLength() {
|
||||
return COUNTER_LENGTH;
|
||||
}
|
||||
|
||||
@SuppressWarnings("BetaApi") // verify is stable in Guava
|
||||
@Override
|
||||
public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
|
||||
checkArgument(outBuf.nioBufferCount() == 1);
|
||||
// Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
|
||||
ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes());
|
||||
plainBuf.writerIndex(0);
|
||||
for (ByteBuf inBuf : plainBufs) {
|
||||
plainBuf.writeBytes(inBuf);
|
||||
}
|
||||
|
||||
verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH);
|
||||
ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes());
|
||||
ByteBuffer plain = out.duplicate();
|
||||
plain.limit(out.limit() - TAG_LENGTH);
|
||||
|
||||
byte[] counter = incrementOutCounter();
|
||||
int outPosition = out.position();
|
||||
aeadCrypter.encrypt(out, plain, counter);
|
||||
int bytesWritten = out.position() - outPosition;
|
||||
outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
|
||||
verify(!outBuf.isWritable());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)
|
||||
throws GeneralSecurityException {
|
||||
|
||||
ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes());
|
||||
cipherTextAndTag.writerIndex(0);
|
||||
|
||||
for (ByteBuf inBuf : ciphertextBufs) {
|
||||
cipherTextAndTag.writeBytes(inBuf);
|
||||
}
|
||||
cipherTextAndTag.writeBytes(tag);
|
||||
|
||||
decrypt(out, cipherTextAndTag);
|
||||
}
|
||||
|
||||
@SuppressWarnings("BetaApi") // verify is stable in Guava
|
||||
@Override
|
||||
public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
|
||||
int bytesRead = ciphertextAndTag.readableBytes();
|
||||
checkArgument(bytesRead == out.writableBytes());
|
||||
|
||||
checkArgument(out.nioBufferCount() == 1);
|
||||
ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes());
|
||||
|
||||
checkArgument(ciphertextAndTag.nioBufferCount() == 1);
|
||||
ByteBuffer ciphertextAndTagBuffer =
|
||||
ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);
|
||||
|
||||
byte[] counter = incrementInCounter();
|
||||
int outPosition = outBuffer.position();
|
||||
aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter);
|
||||
int bytesWritten = outBuffer.position() - outPosition;
|
||||
out.writerIndex(out.writerIndex() + bytesWritten);
|
||||
ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead);
|
||||
verify(out.writableBytes() == TAG_LENGTH);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSuffixLength() {
|
||||
return TAG_LENGTH;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
// no destroy required
|
||||
}
|
||||
|
||||
/** Increments {@code counter}, store the unincremented value in {@code oldCounter}. */
|
||||
static void incrementCounter(byte[] counter, byte[] oldCounter) throws GeneralSecurityException {
|
||||
System.arraycopy(counter, 0, oldCounter, 0, counter.length);
|
||||
int i = 0;
|
||||
for (; i < COUNTER_OVERFLOW_LENGTH; i++) {
|
||||
counter[i]++;
|
||||
if (counter[i] != (byte) 0x00) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (i == COUNTER_OVERFLOW_LENGTH) {
|
||||
// Restore old counter value to ensure that encrypt and decrypt keep failing.
|
||||
System.arraycopy(oldCounter, 0, counter, 0, counter.length);
|
||||
throw new GeneralSecurityException("Counter has overflowed.");
|
||||
}
|
||||
}
|
||||
|
||||
/** Increments the input counter, returning the previous (unincremented) value. */
|
||||
private byte[] incrementInCounter() throws GeneralSecurityException {
|
||||
incrementCounter(inCounter, oldCounter);
|
||||
return oldCounter;
|
||||
}
|
||||
|
||||
/** Increments the output counter, returning the previous (unincremented) value. */
|
||||
private byte[] incrementOutCounter() throws GeneralSecurityException {
|
||||
incrementCounter(outCounter, oldCounter);
|
||||
return oldCounter;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void incrementInCounterForTesting(int n) throws GeneralSecurityException {
|
||||
for (int i = 0; i < n; i++) {
|
||||
incrementInCounter();
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void incrementOutCounterForTesting(int n) throws GeneralSecurityException {
|
||||
for (int i = 0; i < n; i++) {
|
||||
incrementOutCounter();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** Handshaker options for creating ALTS client channel. */
|
||||
public final class AltsClientOptions extends AltsHandshakerOptions {
|
||||
// targetName is the server service account name for secure name checking. This field is not yet
|
||||
// supported.
|
||||
@Nullable private final String targetName;
|
||||
// targetServiceAccounts contains a list of expected target service accounts. One of these service
|
||||
// accounts should match peer service account in the handshaker result. Otherwise, the handshake
|
||||
// fails.
|
||||
private final List<String> targetServiceAccounts;
|
||||
|
||||
private AltsClientOptions(Builder builder) {
|
||||
super(builder.rpcProtocolVersions);
|
||||
targetName = builder.targetName;
|
||||
targetServiceAccounts =
|
||||
Collections.unmodifiableList(new ArrayList<String>(builder.targetServiceAccounts));
|
||||
}
|
||||
|
||||
public String getTargetName() {
|
||||
return targetName;
|
||||
}
|
||||
|
||||
public List<String> getTargetServiceAccounts() {
|
||||
return targetServiceAccounts;
|
||||
}
|
||||
|
||||
/** Builder for AltsClientOptions. */
|
||||
public static final class Builder {
|
||||
@Nullable private String targetName;
|
||||
@Nullable private RpcProtocolVersions rpcProtocolVersions;
|
||||
private ArrayList<String> targetServiceAccounts = new ArrayList<String>();
|
||||
|
||||
public Builder setTargetName(String targetName) {
|
||||
this.targetName = targetName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setRpcProtocolVersions(RpcProtocolVersions rpcProtocolVersions) {
|
||||
this.rpcProtocolVersions = rpcProtocolVersions;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder addTargetServiceAccount(String targetServiceAccount) {
|
||||
targetServiceAccounts.add(targetServiceAccount);
|
||||
return this;
|
||||
}
|
||||
|
||||
public AltsClientOptions build() {
|
||||
return new AltsClientOptions(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,365 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.security.GeneralSecurityException;
|
||||
|
||||
/** Framing and deframing methods and classes used by handshaker. */
|
||||
public final class AltsFraming {
|
||||
// The size of the frame field. Must correspond to the size of int, 4 bytes.
|
||||
// Left package-private for testing.
|
||||
private static final int FRAME_LENGTH_HEADER_SIZE = 4;
|
||||
private static final int FRAME_MESSAGE_TYPE_HEADER_SIZE = 4;
|
||||
private static final int MAX_DATA_LENGTH = 1024 * 1024;
|
||||
private static final int INITIAL_BUFFER_CAPACITY = 1024 * 64;
|
||||
|
||||
// TODO: Make this the responsibility of the caller.
|
||||
private static final int MESSAGE_TYPE = 6;
|
||||
|
||||
private AltsFraming() {}
|
||||
|
||||
static int getFrameLengthHeaderSize() {
|
||||
return FRAME_LENGTH_HEADER_SIZE;
|
||||
}
|
||||
|
||||
static int getFrameMessageTypeHeaderSize() {
|
||||
return FRAME_MESSAGE_TYPE_HEADER_SIZE;
|
||||
}
|
||||
|
||||
static int getMaxDataLength() {
|
||||
return MAX_DATA_LENGTH;
|
||||
}
|
||||
|
||||
static int getFramingOverhead() {
|
||||
return FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a frame of length dataSize + FRAME_HEADER_SIZE using the input bytes, if dataSize <=
|
||||
* input.remaining(). Otherwise, a frame of length input.remaining() + FRAME_HEADER_SIZE is
|
||||
* created.
|
||||
*/
|
||||
static ByteBuffer toFrame(ByteBuffer input, int dataSize) throws GeneralSecurityException {
|
||||
Preconditions.checkNotNull(input);
|
||||
if (dataSize > input.remaining()) {
|
||||
dataSize = input.remaining();
|
||||
}
|
||||
Producer producer = new Producer();
|
||||
ByteBuffer inputAlias = input.duplicate();
|
||||
inputAlias.limit(input.position() + dataSize);
|
||||
producer.readBytes(inputAlias);
|
||||
producer.flush();
|
||||
input.position(inputAlias.position());
|
||||
ByteBuffer output = producer.getRawFrame();
|
||||
return output;
|
||||
}
|
||||
|
||||
/**
|
||||
* A helper class to write a frame.
|
||||
*
|
||||
* <p>This class guarantees that one of the following is true:
|
||||
*
|
||||
* <ul>
|
||||
* <li>readBytes will read from the input
|
||||
* <li>writeBytes will write to the output
|
||||
* </ul>
|
||||
*
|
||||
* <p>Sample usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* Producer producer = new Producer();
|
||||
* ByteBuffer inputBuffer = readBytesFromMyStream();
|
||||
* ByteBuffer outputBuffer = writeBytesToMyStream();
|
||||
* while (inputBuffer.hasRemaining() || outputBuffer.hasRemaining()) {
|
||||
* producer.readBytes(inputBuffer);
|
||||
* producer.writeBytes(outputBuffer);
|
||||
* }
|
||||
* }</pre>
|
||||
*
|
||||
* <p>Alternatively, this class guarantees that one of the following is true:
|
||||
*
|
||||
* <ul>
|
||||
* <li>readBytes will read from the input
|
||||
* <li>{@code isComplete()} returns true and {@code getByteBuffer()} returns the contents of a
|
||||
* processed frame.
|
||||
* </ul>
|
||||
*
|
||||
* <p>Sample usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* Producer producer = new Producer();
|
||||
* while (!producer.isComplete()) {
|
||||
* ByteBuffer inputBuffer = readBytesFromMyStream();
|
||||
* producer.readBytes(inputBuffer);
|
||||
* }
|
||||
* producer.flush();
|
||||
* ByteBuffer outputBuffer = producer.getRawFrame();
|
||||
* }</pre>
|
||||
*/
|
||||
static final class Producer {
|
||||
private ByteBuffer buffer;
|
||||
private boolean isComplete;
|
||||
|
||||
Producer(int maxFrameSize) {
|
||||
buffer = ByteBuffer.allocate(maxFrameSize);
|
||||
reset();
|
||||
Preconditions.checkArgument(maxFrameSize > getFramePrefixLength() + getFrameSuffixLength());
|
||||
}
|
||||
|
||||
Producer() {
|
||||
this(INITIAL_BUFFER_CAPACITY);
|
||||
}
|
||||
|
||||
/** The length of the frame prefix data, including the message length/type fields. */
|
||||
int getFramePrefixLength() {
|
||||
int result = FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE;
|
||||
return result;
|
||||
}
|
||||
|
||||
int getFrameSuffixLength() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads bytes from input, parsing them into a frame. Returns false if and only if more data is
|
||||
* needed. To obtain a full frame this method must be called repeatedly until it returns true.
|
||||
*/
|
||||
boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
|
||||
Preconditions.checkNotNull(input);
|
||||
if (isComplete) {
|
||||
return true;
|
||||
}
|
||||
copy(buffer, input);
|
||||
if (!buffer.hasRemaining()) {
|
||||
flush();
|
||||
}
|
||||
return isComplete;
|
||||
}
|
||||
|
||||
/**
|
||||
* Completes the current frame, signaling that no further data is available to be passed to
|
||||
* readBytes and that the client requires writeBytes to start returning data. isComplete() is
|
||||
* guaranteed to return true after this call.
|
||||
*/
|
||||
void flush() throws GeneralSecurityException {
|
||||
if (isComplete) {
|
||||
return;
|
||||
}
|
||||
// Get the length of the complete frame.
|
||||
int frameLength = buffer.position() + getFrameSuffixLength();
|
||||
|
||||
// Set the limit and move to the start.
|
||||
buffer.flip();
|
||||
|
||||
// Advance the limit to allow a crypto suffix.
|
||||
buffer.limit(buffer.limit() + getFrameSuffixLength());
|
||||
|
||||
// Write the data length and the message type.
|
||||
int dataLength = frameLength - FRAME_LENGTH_HEADER_SIZE;
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(dataLength);
|
||||
buffer.putInt(MESSAGE_TYPE);
|
||||
|
||||
// Move the position back to 0, the frame is ready.
|
||||
buffer.position(0);
|
||||
isComplete = true;
|
||||
}
|
||||
|
||||
/** Resets the state, preparing to construct a new frame. Must be called between frames. */
|
||||
private void reset() {
|
||||
buffer.clear();
|
||||
|
||||
// Save some space for framing, we'll fill that in later.
|
||||
buffer.position(getFramePrefixLength());
|
||||
buffer.limit(buffer.limit() - getFrameSuffixLength());
|
||||
|
||||
isComplete = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a ByteBuffer containing a complete raw frame, if it's available. Should only be
|
||||
* called when isComplete() returns true, otherwise null is returned. The returned object
|
||||
* aliases the internal buffer, that is, it shares memory with the internal buffer. No further
|
||||
* operations are permitted on this object until the caller has processed the data it needs from
|
||||
* the returned byte buffer.
|
||||
*/
|
||||
ByteBuffer getRawFrame() {
|
||||
if (!isComplete) {
|
||||
return null;
|
||||
}
|
||||
ByteBuffer result = buffer.duplicate();
|
||||
reset();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A helper class to read a frame.
|
||||
*
|
||||
* <p>This class guarantees that one of the following is true:
|
||||
*
|
||||
* <ul>
|
||||
* <li>readBytes will read from the input
|
||||
* <li>writeBytes will write to the output
|
||||
* </ul>
|
||||
*
|
||||
* <p>Sample usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* Parser parser = new Parser();
|
||||
* ByteBuffer inputBuffer = readBytesFromMyStream();
|
||||
* ByteBuffer outputBuffer = writeBytesToMyStream();
|
||||
* while (inputBuffer.hasRemaining() || outputBuffer.hasRemaining()) {
|
||||
* parser.readBytes(inputBuffer);
|
||||
* parser.writeBytes(outputBuffer); }
|
||||
* }</pre>
|
||||
*
|
||||
* <p>Alternatively, this class guarantees that one of the following is true:
|
||||
*
|
||||
* <ul>
|
||||
* <li>readBytes will read from the input
|
||||
* <li>{@code isComplete()} returns true and {@code getByteBuffer()} returns the contents of a
|
||||
* processed frame.
|
||||
* </ul>
|
||||
*
|
||||
* <p>Sample usage:
|
||||
*
|
||||
* <pre>{@code
|
||||
* Parser parser = new Parser();
|
||||
* while (!parser.isComplete()) {
|
||||
* ByteBuffer inputBuffer = readBytesFromMyStream();
|
||||
* parser.readBytes(inputBuffer);
|
||||
* }
|
||||
* ByteBuffer outputBuffer = parser.getRawFrame();
|
||||
* }</pre>
|
||||
*/
|
||||
public static final class Parser {
|
||||
private ByteBuffer buffer = ByteBuffer.allocate(INITIAL_BUFFER_CAPACITY);
|
||||
private boolean isComplete = false;
|
||||
|
||||
public Parser() {
|
||||
Preconditions.checkArgument(
|
||||
INITIAL_BUFFER_CAPACITY > getFramePrefixLength() + getFrameSuffixLength());
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads bytes from input, parsing them into a frame. Returns false if and only if more data is
|
||||
* needed. To obtain a full frame this method must be called repeatedly until it returns true.
|
||||
*/
|
||||
public boolean readBytes(ByteBuffer input) throws GeneralSecurityException {
|
||||
Preconditions.checkNotNull(input);
|
||||
|
||||
if (isComplete) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Read enough bytes to determine the length
|
||||
while (buffer.position() < FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) {
|
||||
buffer.put(input.get());
|
||||
}
|
||||
|
||||
// If we have enough bytes to determine the length, read the length and ensure that our
|
||||
// internal buffer is large enough.
|
||||
if (buffer.position() == FRAME_LENGTH_HEADER_SIZE && input.hasRemaining()) {
|
||||
ByteBuffer bufferAlias = buffer.duplicate();
|
||||
bufferAlias.flip();
|
||||
bufferAlias.order(ByteOrder.LITTLE_ENDIAN);
|
||||
int dataLength = bufferAlias.getInt();
|
||||
if (dataLength < FRAME_MESSAGE_TYPE_HEADER_SIZE || dataLength > MAX_DATA_LENGTH) {
|
||||
throw new IllegalArgumentException("Invalid frame length " + dataLength);
|
||||
}
|
||||
// Maybe resize the buffer
|
||||
int frameLength = dataLength + FRAME_LENGTH_HEADER_SIZE;
|
||||
if (buffer.capacity() < frameLength) {
|
||||
buffer = ByteBuffer.allocate(frameLength);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(dataLength);
|
||||
}
|
||||
buffer.limit(frameLength);
|
||||
}
|
||||
|
||||
// TODO: Similarly extract and check message type.
|
||||
|
||||
// Read the remaining data into the internal buffer.
|
||||
copy(buffer, input);
|
||||
if (!buffer.hasRemaining()) {
|
||||
buffer.flip();
|
||||
isComplete = true;
|
||||
}
|
||||
return isComplete;
|
||||
}
|
||||
|
||||
/** The length of the frame prefix data, including the message length/type fields. */
|
||||
int getFramePrefixLength() {
|
||||
int result = FRAME_LENGTH_HEADER_SIZE + FRAME_MESSAGE_TYPE_HEADER_SIZE;
|
||||
return result;
|
||||
}
|
||||
|
||||
int getFrameSuffixLength() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/** Returns true if we've parsed a complete frame. */
|
||||
public boolean isComplete() {
|
||||
return isComplete;
|
||||
}
|
||||
|
||||
/** Resets the state, preparing to parse a new frame. Must be called between frames. */
|
||||
private void reset() {
|
||||
buffer.clear();
|
||||
isComplete = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a ByteBuffer containing a complete raw frame, if it's available. Should only be
|
||||
* called when isComplete() returns true, otherwise null is returned. The returned object
|
||||
* aliases the internal buffer, that is, it shares memory with the internal buffer. No further
|
||||
* operations are permitted on this object until the caller has processed the data it needs from
|
||||
* the returned byte buffer.
|
||||
*/
|
||||
public ByteBuffer getRawFrame() {
|
||||
if (!isComplete) {
|
||||
return null;
|
||||
}
|
||||
ByteBuffer result = buffer.duplicate();
|
||||
reset();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy as much as possible to dst from src. Unlike {@link ByteBuffer#put(ByteBuffer)}, this stops
|
||||
* early if there is no room left in dst.
|
||||
*/
|
||||
private static void copy(ByteBuffer dst, ByteBuffer src) {
|
||||
if (dst.hasRemaining() && src.hasRemaining()) {
|
||||
// Avoid an allocation if possible.
|
||||
if (dst.remaining() >= src.remaining()) {
|
||||
dst.put(src);
|
||||
} else {
|
||||
int count = Math.min(dst.remaining(), src.remaining());
|
||||
ByteBuffer slice = src.slice();
|
||||
slice.limit(count);
|
||||
dst.put(slice);
|
||||
src.position(src.position() + count);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,245 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.Handshaker.HandshakeProtocol;
|
||||
import io.grpc.alts.Handshaker.HandshakerReq;
|
||||
import io.grpc.alts.Handshaker.HandshakerResp;
|
||||
import io.grpc.alts.Handshaker.HandshakerResult;
|
||||
import io.grpc.alts.Handshaker.HandshakerStatus;
|
||||
import io.grpc.alts.Handshaker.NextHandshakeMessageReq;
|
||||
import io.grpc.alts.Handshaker.ServerHandshakeParameters;
|
||||
import io.grpc.alts.Handshaker.StartClientHandshakeReq;
|
||||
import io.grpc.alts.Handshaker.StartServerHandshakeReq;
|
||||
import io.grpc.alts.HandshakerServiceGrpc.HandshakerServiceStub;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/** An API for conducting handshakes via ALTS handshaker service. */
|
||||
class AltsHandshakerClient {
|
||||
private static final Logger logger = Logger.getLogger(AltsHandshakerClient.class.getName());
|
||||
|
||||
private static final String APPLICATION_PROTOCOL = "grpc";
|
||||
private static final String RECORD_PROTOCOL = "ALTSRP_GCM_AES128_REKEY";
|
||||
private static final int KEY_LENGTH = AltsChannelCrypter.getKeyLength();
|
||||
|
||||
private final AltsHandshakerStub handshakerStub;
|
||||
private final AltsHandshakerOptions handshakerOptions;
|
||||
private HandshakerResult result;
|
||||
private HandshakerStatus status;
|
||||
|
||||
/** Starts a new handshake interacting with the handshaker service. */
|
||||
AltsHandshakerClient(HandshakerServiceStub stub, AltsHandshakerOptions options) {
|
||||
handshakerStub = new AltsHandshakerStub(stub);
|
||||
handshakerOptions = options;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
AltsHandshakerClient(AltsHandshakerStub handshakerStub, AltsHandshakerOptions options) {
|
||||
this.handshakerStub = handshakerStub;
|
||||
handshakerOptions = options;
|
||||
}
|
||||
|
||||
static String getApplicationProtocol() {
|
||||
return APPLICATION_PROTOCOL;
|
||||
}
|
||||
|
||||
static String getRecordProtocol() {
|
||||
return RECORD_PROTOCOL;
|
||||
}
|
||||
|
||||
/** Sets the start client fields for the passed handshake request. */
|
||||
private void setStartClientFields(HandshakerReq.Builder req) {
|
||||
// Sets the default values.
|
||||
StartClientHandshakeReq.Builder startClientReq =
|
||||
StartClientHandshakeReq.newBuilder()
|
||||
.setHandshakeSecurityProtocol(HandshakeProtocol.ALTS)
|
||||
.addApplicationProtocols(APPLICATION_PROTOCOL)
|
||||
.addRecordProtocols(RECORD_PROTOCOL);
|
||||
// Sets handshaker options.
|
||||
if (handshakerOptions.getRpcProtocolVersions() != null) {
|
||||
startClientReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
|
||||
}
|
||||
if (handshakerOptions instanceof AltsClientOptions) {
|
||||
AltsClientOptions clientOptions = (AltsClientOptions) handshakerOptions;
|
||||
if (!Strings.isNullOrEmpty(clientOptions.getTargetName())) {
|
||||
startClientReq.setTargetName(clientOptions.getTargetName());
|
||||
}
|
||||
for (String serviceAccount : clientOptions.getTargetServiceAccounts()) {
|
||||
startClientReq.addTargetIdentitiesBuilder().setServiceAccount(serviceAccount);
|
||||
}
|
||||
}
|
||||
req.setClientStart(startClientReq);
|
||||
}
|
||||
|
||||
/** Sets the start server fields for the passed handshake request. */
|
||||
private void setStartServerFields(HandshakerReq.Builder req, ByteBuffer inBytes) {
|
||||
ServerHandshakeParameters serverParameters =
|
||||
ServerHandshakeParameters.newBuilder().addRecordProtocols(RECORD_PROTOCOL).build();
|
||||
StartServerHandshakeReq.Builder startServerReq =
|
||||
StartServerHandshakeReq.newBuilder()
|
||||
.addApplicationProtocols(APPLICATION_PROTOCOL)
|
||||
.putHandshakeParameters(HandshakeProtocol.ALTS.getNumber(), serverParameters)
|
||||
.setInBytes(ByteString.copyFrom(inBytes.duplicate()));
|
||||
if (handshakerOptions.getRpcProtocolVersions() != null) {
|
||||
startServerReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
|
||||
}
|
||||
req.setServerStart(startServerReq);
|
||||
}
|
||||
|
||||
/** Returns true if the handshake is complete. */
|
||||
public boolean isFinished() {
|
||||
// If we have a HandshakeResult, we are done.
|
||||
if (result != null) {
|
||||
return true;
|
||||
}
|
||||
// If we have an error status, we are done.
|
||||
if (status != null && status.getCode() != Status.Code.OK.value()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Returns the handshake status. */
|
||||
public HandshakerStatus getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
/** Returns the result data of the handshake, if the handshake is completed. */
|
||||
public HandshakerResult getResult() {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the resulting key of the handshake, if the handshake is completed. Note that the key
|
||||
* data returned from the handshake may be more than the key length required for the record
|
||||
* protocol, thus we need to truncate to the right size.
|
||||
*/
|
||||
public byte[] getKey() {
|
||||
if (result == null) {
|
||||
return null;
|
||||
}
|
||||
if (result.getKeyData().size() < KEY_LENGTH) {
|
||||
throw new IllegalStateException("Could not get enough key data from the handshake.");
|
||||
}
|
||||
byte[] key = new byte[KEY_LENGTH];
|
||||
result.getKeyData().copyTo(key, 0, 0, KEY_LENGTH);
|
||||
return key;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a handshake response, setting the status, result, and closing the handshaker, as needed.
|
||||
*/
|
||||
private void handleResponse(HandshakerResp resp) throws GeneralSecurityException {
|
||||
status = resp.getStatus();
|
||||
if (resp.hasResult()) {
|
||||
result = resp.getResult();
|
||||
close();
|
||||
}
|
||||
if (status.getCode() != Status.Code.OK.value()) {
|
||||
String error = "Handshaker service error: " + status.getDetails();
|
||||
logger.log(Level.INFO, error);
|
||||
close();
|
||||
throw new GeneralSecurityException(error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a client handshake. A GeneralSecurityException is thrown if the handshaker service is
|
||||
* interrupted or fails. Note that isFinished() must be false before this function is called.
|
||||
*
|
||||
* @return the frame to give to the peer.
|
||||
* @throws GeneralSecurityException or IllegalStateException
|
||||
*/
|
||||
public ByteBuffer startClientHandshake() throws GeneralSecurityException {
|
||||
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
|
||||
HandshakerReq.Builder req = HandshakerReq.newBuilder();
|
||||
setStartClientFields(req);
|
||||
HandshakerResp resp;
|
||||
try {
|
||||
resp = handshakerStub.send(req.build());
|
||||
} catch (IOException | InterruptedException e) {
|
||||
throw new GeneralSecurityException(e);
|
||||
}
|
||||
handleResponse(resp);
|
||||
return resp.getOutFrames().asReadOnlyByteBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a server handshake. A GeneralSecurityException is thrown if the handshaker service is
|
||||
* interrupted or fails. Note that isFinished() must be false before this function is called.
|
||||
*
|
||||
* @param inBytes the bytes received from the peer.
|
||||
* @return the frame to give to the peer.
|
||||
* @throws GeneralSecurityException or IllegalStateException
|
||||
*/
|
||||
public ByteBuffer startServerHandshake(ByteBuffer inBytes) throws GeneralSecurityException {
|
||||
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
|
||||
HandshakerReq.Builder req = HandshakerReq.newBuilder();
|
||||
setStartServerFields(req, inBytes);
|
||||
HandshakerResp resp;
|
||||
try {
|
||||
resp = handshakerStub.send(req.build());
|
||||
} catch (IOException | InterruptedException e) {
|
||||
throw new GeneralSecurityException(e);
|
||||
}
|
||||
handleResponse(resp);
|
||||
inBytes.position(inBytes.position() + resp.getBytesConsumed());
|
||||
return resp.getOutFrames().asReadOnlyByteBuffer();
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes the next bytes in a handshake. A GeneralSecurityException is thrown if the handshaker
|
||||
* service is interrupted or fails. Note that isFinished() must be false before this function is
|
||||
* called.
|
||||
*
|
||||
* @param inBytes the bytes received from the peer.
|
||||
* @return the frame to give to the peer.
|
||||
* @throws GeneralSecurityException or IllegalStateException
|
||||
*/
|
||||
public ByteBuffer next(ByteBuffer inBytes) throws GeneralSecurityException {
|
||||
Preconditions.checkState(!isFinished(), "Handshake has already finished.");
|
||||
HandshakerReq.Builder req =
|
||||
HandshakerReq.newBuilder()
|
||||
.setNext(
|
||||
NextHandshakeMessageReq.newBuilder()
|
||||
.setInBytes(ByteString.copyFrom(inBytes.duplicate()))
|
||||
.build());
|
||||
HandshakerResp resp;
|
||||
try {
|
||||
resp = handshakerStub.send(req.build());
|
||||
} catch (IOException | InterruptedException e) {
|
||||
throw new GeneralSecurityException(e);
|
||||
}
|
||||
handleResponse(resp);
|
||||
inBytes.position(inBytes.position() + resp.getBytesConsumed());
|
||||
return resp.getOutFrames().asReadOnlyByteBuffer();
|
||||
}
|
||||
|
||||
/** Closes the connection. */
|
||||
public void close() {
|
||||
handshakerStub.close();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** Handshaker options for creating ALTS channel. */
|
||||
public class AltsHandshakerOptions {
|
||||
@Nullable private final RpcProtocolVersions rpcProtocolVersions;
|
||||
|
||||
public AltsHandshakerOptions(RpcProtocolVersions rpcProtocolVersions) {
|
||||
this.rpcProtocolVersions = rpcProtocolVersions;
|
||||
}
|
||||
|
||||
public RpcProtocolVersions getRpcProtocolVersions() {
|
||||
return rpcProtocolVersions;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Optional;
|
||||
import io.grpc.alts.Handshaker.HandshakerReq;
|
||||
import io.grpc.alts.Handshaker.HandshakerResp;
|
||||
import io.grpc.alts.HandshakerServiceGrpc.HandshakerServiceStub;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.ArrayBlockingQueue;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/** An interface to the ALTS handshaker service. */
|
||||
class AltsHandshakerStub {
|
||||
private final StreamObserver<HandshakerResp> reader = new Reader();
|
||||
private final StreamObserver<HandshakerReq> writer;
|
||||
private final ArrayBlockingQueue<Optional<HandshakerResp>> responseQueue =
|
||||
new ArrayBlockingQueue<Optional<HandshakerResp>>(1);
|
||||
private final AtomicReference<String> exceptionMessage = new AtomicReference<>();
|
||||
|
||||
AltsHandshakerStub(HandshakerServiceStub serviceStub) {
|
||||
this.writer = serviceStub.doHandshake(this.reader);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
AltsHandshakerStub() {
|
||||
writer = null;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
AltsHandshakerStub(StreamObserver<HandshakerReq> writer) {
|
||||
this.writer = writer;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
StreamObserver<HandshakerResp> getReaderForTest() {
|
||||
return reader;
|
||||
}
|
||||
|
||||
/** Send a handshaker request and return the handshaker response. */
|
||||
public HandshakerResp send(HandshakerReq req) throws InterruptedException, IOException {
|
||||
maybeThrowIoException();
|
||||
if (!responseQueue.isEmpty()) {
|
||||
throw new IOException("Received an unexpected response.");
|
||||
}
|
||||
writer.onNext(req);
|
||||
Optional<HandshakerResp> result = responseQueue.take();
|
||||
if (!result.isPresent()) {
|
||||
maybeThrowIoException();
|
||||
}
|
||||
return result.get();
|
||||
}
|
||||
|
||||
/** Throw exception if there is an outstanding exception. */
|
||||
private void maybeThrowIoException() throws IOException {
|
||||
if (exceptionMessage.get() != null) {
|
||||
throw new IOException(exceptionMessage.get());
|
||||
}
|
||||
}
|
||||
|
||||
/** Close the connection. */
|
||||
public void close() {
|
||||
writer.onCompleted();
|
||||
}
|
||||
|
||||
private class Reader implements StreamObserver<HandshakerResp> {
|
||||
/** Receive a handshaker response from the server. */
|
||||
@Override
|
||||
public void onNext(HandshakerResp resp) {
|
||||
try {
|
||||
AltsHandshakerStub.this.responseQueue.add(Optional.of(resp));
|
||||
} catch (IllegalStateException e) {
|
||||
AltsHandshakerStub.this.exceptionMessage.compareAndSet(
|
||||
null, "Received an unexpected response.");
|
||||
AltsHandshakerStub.this.close();
|
||||
}
|
||||
}
|
||||
|
||||
/** Receive an error from the server. */
|
||||
@Override
|
||||
public void onError(Throwable t) {
|
||||
AltsHandshakerStub.this.exceptionMessage.compareAndSet(
|
||||
null, "Received a terminating error: " + t.toString());
|
||||
// Trigger the release of any blocked send.
|
||||
Optional<HandshakerResp> result = Optional.absent();
|
||||
AltsHandshakerStub.this.responseQueue.offer(result);
|
||||
}
|
||||
|
||||
/** Receive the closing message from the server. */
|
||||
@Override
|
||||
public void onCompleted() {
|
||||
AltsHandshakerStub.this.exceptionMessage.compareAndSet(null, "Response stream closed.");
|
||||
// Trigger the release of any blocked send.
|
||||
Optional<HandshakerResp> result = Optional.absent();
|
||||
AltsHandshakerStub.this.responseQueue.offer(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,404 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkArgument;
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
import static com.google.common.base.Verify.verify;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/** Frame protector that uses the ALTS framing. */
|
||||
public final class AltsTsiFrameProtector implements TsiFrameProtector {
|
||||
private static final int HEADER_LEN_FIELD_BYTES = 4;
|
||||
private static final int HEADER_TYPE_FIELD_BYTES = 4;
|
||||
private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES;
|
||||
private static final int HEADER_TYPE_DEFAULT = 6;
|
||||
// Total frame size including full header and tag.
|
||||
private static final int MAX_ALLOWED_FRAME_BYTES = 16 * 1024;
|
||||
private static final int LIMIT_MAX_ALLOWED_FRAME_BYTES = 1024 * 1024;
|
||||
|
||||
private final Protector protector;
|
||||
private final Unprotector unprotector;
|
||||
|
||||
/** Create a new AltsTsiFrameProtector. */
|
||||
public AltsTsiFrameProtector(
|
||||
int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
|
||||
checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength());
|
||||
maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_BYTES, maxProtectedFrameBytes);
|
||||
protector = new Protector(maxProtectedFrameBytes, crypter);
|
||||
unprotector = new Unprotector(crypter, alloc);
|
||||
}
|
||||
|
||||
static int getHeaderLenFieldBytes() {
|
||||
return HEADER_LEN_FIELD_BYTES;
|
||||
}
|
||||
|
||||
static int getHeaderTypeFieldBytes() {
|
||||
return HEADER_TYPE_FIELD_BYTES;
|
||||
}
|
||||
|
||||
public static int getHeaderBytes() {
|
||||
return HEADER_BYTES;
|
||||
}
|
||||
|
||||
static int getHeaderTypeDefault() {
|
||||
return HEADER_TYPE_DEFAULT;
|
||||
}
|
||||
|
||||
public static int getMaxAllowedFrameBytes() {
|
||||
return MAX_ALLOWED_FRAME_BYTES;
|
||||
}
|
||||
|
||||
static int getLimitMaxAllowedFrameBytes() {
|
||||
return LIMIT_MAX_ALLOWED_FRAME_BYTES;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void protectFlush(
|
||||
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
protector.protectFlush(unprotectedBufs, ctxWrite, alloc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
unprotector.unprotect(in, out, alloc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
try {
|
||||
unprotector.destroy();
|
||||
} finally {
|
||||
protector.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
static final class Protector {
|
||||
private final int maxUnprotectedBytesPerFrame;
|
||||
private final int suffixBytes;
|
||||
private ChannelCrypterNetty crypter;
|
||||
|
||||
Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter) {
|
||||
this.suffixBytes = crypter.getSuffixLength();
|
||||
this.maxUnprotectedBytesPerFrame = maxProtectedFrameBytes - HEADER_BYTES - suffixBytes;
|
||||
this.crypter = crypter;
|
||||
}
|
||||
|
||||
void destroy() {
|
||||
// Shared with Unprotector and destroyed there.
|
||||
crypter = null;
|
||||
}
|
||||
|
||||
void protectFlush(
|
||||
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
checkState(crypter != null, "Cannot protectFlush after destroy.");
|
||||
ByteBuf protectedBuf;
|
||||
try {
|
||||
protectedBuf = handleUnprotected(unprotectedBufs, alloc);
|
||||
} finally {
|
||||
for (ByteBuf buf : unprotectedBufs) {
|
||||
buf.release();
|
||||
}
|
||||
}
|
||||
if (protectedBuf != null) {
|
||||
ctxWrite.accept(protectedBuf);
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("BetaApi") // verify is stable in Guava
|
||||
private ByteBuf handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
long unprotectedBytes = 0;
|
||||
for (ByteBuf buf : unprotectedBufs) {
|
||||
unprotectedBytes += buf.readableBytes();
|
||||
}
|
||||
// Empty plaintext not allowed since this should be handled as no-op in layer above.
|
||||
checkArgument(unprotectedBytes > 0);
|
||||
|
||||
// Compute number of frames and allocate a single buffer for all frames.
|
||||
long frameNum = unprotectedBytes / maxUnprotectedBytesPerFrame + 1;
|
||||
int lastFrameUnprotectedBytes = (int) (unprotectedBytes % maxUnprotectedBytesPerFrame);
|
||||
if (lastFrameUnprotectedBytes == 0) {
|
||||
frameNum--;
|
||||
lastFrameUnprotectedBytes = maxUnprotectedBytesPerFrame;
|
||||
}
|
||||
long protectedBytes = frameNum * (HEADER_BYTES + suffixBytes) + unprotectedBytes;
|
||||
|
||||
ByteBuf protectedBuf = alloc.directBuffer(Math.toIntExact(protectedBytes));
|
||||
try {
|
||||
int bufferIdx = 0;
|
||||
for (int frameIdx = 0; frameIdx < frameNum; ++frameIdx) {
|
||||
int unprotectedBytesLeft =
|
||||
(frameIdx == frameNum - 1) ? lastFrameUnprotectedBytes : maxUnprotectedBytesPerFrame;
|
||||
// Write header (at most LIMIT_MAX_ALLOWED_FRAME_BYTES).
|
||||
protectedBuf.writeIntLE(unprotectedBytesLeft + HEADER_TYPE_FIELD_BYTES + suffixBytes);
|
||||
protectedBuf.writeIntLE(HEADER_TYPE_DEFAULT);
|
||||
|
||||
// Ownership of the backing buffer remains with protectedBuf.
|
||||
ByteBuf frameOut = writeSlice(protectedBuf, unprotectedBytesLeft + suffixBytes);
|
||||
List<ByteBuf> framePlain = new ArrayList<>();
|
||||
while (unprotectedBytesLeft > 0) {
|
||||
// Ownership of the buffer backing in remains with unprotectedBufs.
|
||||
ByteBuf in = unprotectedBufs.get(bufferIdx);
|
||||
if (in.readableBytes() <= unprotectedBytesLeft) {
|
||||
// The complete buffer belongs to this frame.
|
||||
framePlain.add(in);
|
||||
unprotectedBytesLeft -= in.readableBytes();
|
||||
bufferIdx++;
|
||||
} else {
|
||||
// The remainder of in will be part of the next frame.
|
||||
framePlain.add(in.readSlice(unprotectedBytesLeft));
|
||||
unprotectedBytesLeft = 0;
|
||||
}
|
||||
}
|
||||
crypter.encrypt(frameOut, framePlain);
|
||||
verify(!frameOut.isWritable());
|
||||
}
|
||||
protectedBuf.readerIndex(0);
|
||||
protectedBuf.writerIndex(protectedBuf.capacity());
|
||||
return protectedBuf.retain();
|
||||
} finally {
|
||||
protectedBuf.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static final class Unprotector {
|
||||
private final int suffixBytes;
|
||||
private final ChannelCrypterNetty crypter;
|
||||
|
||||
private DeframerState state = DeframerState.READ_HEADER;
|
||||
private int requiredProtectedBytes;
|
||||
private ByteBuf header;
|
||||
private ByteBuf firstFrameTag;
|
||||
private int unhandledIdx = 0;
|
||||
private long unhandledBytes = 0;
|
||||
private List<ByteBuf> unhandledBufs = new ArrayList<>(16);
|
||||
|
||||
Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
|
||||
this.crypter = crypter;
|
||||
this.suffixBytes = crypter.getSuffixLength();
|
||||
this.header = alloc.directBuffer(HEADER_BYTES);
|
||||
this.firstFrameTag = alloc.directBuffer(suffixBytes);
|
||||
}
|
||||
|
||||
private void addUnhandled(ByteBuf in) {
|
||||
if (in.isReadable()) {
|
||||
ByteBuf buf = in.readRetainedSlice(in.readableBytes());
|
||||
unhandledBufs.add(buf);
|
||||
unhandledBytes += buf.readableBytes();
|
||||
}
|
||||
}
|
||||
|
||||
void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
checkState(header != null, "Cannot unprotect after destroy.");
|
||||
addUnhandled(in);
|
||||
decodeFrame(alloc, out);
|
||||
}
|
||||
|
||||
@SuppressWarnings("fallthrough")
|
||||
private void decodeFrame(ByteBufAllocator alloc, List<Object> out)
|
||||
throws GeneralSecurityException {
|
||||
switch (state) {
|
||||
case READ_HEADER:
|
||||
if (unhandledBytes < HEADER_BYTES) {
|
||||
return;
|
||||
}
|
||||
handleHeader();
|
||||
// fall through
|
||||
case READ_PROTECTED_PAYLOAD:
|
||||
if (unhandledBytes < requiredProtectedBytes) {
|
||||
return;
|
||||
}
|
||||
ByteBuf unprotectedBuf;
|
||||
try {
|
||||
unprotectedBuf = handlePayload(alloc);
|
||||
} finally {
|
||||
clearState();
|
||||
}
|
||||
if (unprotectedBuf != null) {
|
||||
out.add(unprotectedBuf);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new AssertionError("impossible enum value");
|
||||
}
|
||||
}
|
||||
|
||||
private void handleHeader() {
|
||||
while (header.isWritable()) {
|
||||
ByteBuf in = unhandledBufs.get(unhandledIdx);
|
||||
int headerBytesToRead = Math.min(in.readableBytes(), header.writableBytes());
|
||||
header.writeBytes(in, headerBytesToRead);
|
||||
unhandledBytes -= headerBytesToRead;
|
||||
if (!in.isReadable()) {
|
||||
unhandledIdx++;
|
||||
}
|
||||
}
|
||||
requiredProtectedBytes = header.readIntLE() - HEADER_TYPE_FIELD_BYTES;
|
||||
checkArgument(
|
||||
requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small");
|
||||
checkArgument(
|
||||
requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_BYTES - HEADER_BYTES,
|
||||
"Invalid header field: frame size too large");
|
||||
int frameType = header.readIntLE();
|
||||
checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type");
|
||||
state = DeframerState.READ_PROTECTED_PAYLOAD;
|
||||
}
|
||||
|
||||
@SuppressWarnings("BetaApi") // verify is stable in Guava
|
||||
private ByteBuf handlePayload(ByteBufAllocator alloc) throws GeneralSecurityException {
|
||||
int requiredCiphertextBytes = requiredProtectedBytes - suffixBytes;
|
||||
int firstFrameUnprotectedLen = requiredCiphertextBytes;
|
||||
|
||||
// We get the ciphertexts of the first frame and copy over the tag into a single buffer.
|
||||
List<ByteBuf> firstFrameCiphertext = new ArrayList<>();
|
||||
while (requiredCiphertextBytes > 0) {
|
||||
ByteBuf buf = unhandledBufs.get(unhandledIdx);
|
||||
if (buf.readableBytes() <= requiredCiphertextBytes) {
|
||||
// We use the whole buffer.
|
||||
firstFrameCiphertext.add(buf);
|
||||
requiredCiphertextBytes -= buf.readableBytes();
|
||||
unhandledIdx++;
|
||||
} else {
|
||||
firstFrameCiphertext.add(buf.readSlice(requiredCiphertextBytes));
|
||||
requiredCiphertextBytes = 0;
|
||||
}
|
||||
}
|
||||
int requiredSuffixBytes = suffixBytes;
|
||||
while (true) {
|
||||
ByteBuf buf = unhandledBufs.get(unhandledIdx);
|
||||
if (buf.readableBytes() <= requiredSuffixBytes) {
|
||||
// We use the whole buffer.
|
||||
requiredSuffixBytes -= buf.readableBytes();
|
||||
firstFrameTag.writeBytes(buf);
|
||||
if (requiredSuffixBytes == 0) {
|
||||
break;
|
||||
}
|
||||
unhandledIdx++;
|
||||
} else {
|
||||
firstFrameTag.writeBytes(buf, requiredSuffixBytes);
|
||||
break;
|
||||
}
|
||||
}
|
||||
verify(unhandledIdx == unhandledBufs.size() - 1);
|
||||
ByteBuf lastBuf = unhandledBufs.get(unhandledIdx);
|
||||
|
||||
// We get the remaining ciphertexts and tags contained in the last buffer.
|
||||
List<ByteBuf> ciphertextsAndTags = new ArrayList<>();
|
||||
List<Integer> unprotectedLens = new ArrayList<>();
|
||||
long requiredUnprotectedBytesCompleteFrames = firstFrameUnprotectedLen;
|
||||
while (lastBuf.readableBytes() >= HEADER_BYTES + suffixBytes) {
|
||||
// Read frame size.
|
||||
int frameSize = lastBuf.readIntLE();
|
||||
int payloadSize = frameSize - HEADER_TYPE_FIELD_BYTES - suffixBytes;
|
||||
// Break and undo read if we don't have the complete frame yet.
|
||||
if (lastBuf.readableBytes() < frameSize) {
|
||||
lastBuf.readerIndex(lastBuf.readerIndex() - HEADER_LEN_FIELD_BYTES);
|
||||
break;
|
||||
}
|
||||
// Check the type header.
|
||||
checkArgument(lastBuf.readIntLE() == 6);
|
||||
// Create a new frame (except for out buffer).
|
||||
ciphertextsAndTags.add(lastBuf.readSlice(payloadSize + suffixBytes));
|
||||
// Update sizes for frame.
|
||||
requiredUnprotectedBytesCompleteFrames += payloadSize;
|
||||
unprotectedLens.add(payloadSize);
|
||||
}
|
||||
|
||||
// We leave space for suffixBytes to allow for in-place encryption. This allows for calling
|
||||
// doFinal in the JCE implementation which can be optimized better than update and doFinal.
|
||||
ByteBuf unprotectedBuf =
|
||||
alloc.directBuffer(Math.toIntExact(requiredUnprotectedBytesCompleteFrames + suffixBytes));
|
||||
try {
|
||||
|
||||
ByteBuf out = writeSlice(unprotectedBuf, firstFrameUnprotectedLen + suffixBytes);
|
||||
crypter.decrypt(out, firstFrameTag, firstFrameCiphertext);
|
||||
verify(out.writableBytes() == suffixBytes);
|
||||
unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
|
||||
|
||||
for (int frameIdx = 0; frameIdx < ciphertextsAndTags.size(); ++frameIdx) {
|
||||
out = writeSlice(unprotectedBuf, unprotectedLens.get(frameIdx) + suffixBytes);
|
||||
crypter.decrypt(out, ciphertextsAndTags.get(frameIdx));
|
||||
verify(out.writableBytes() == suffixBytes);
|
||||
unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
|
||||
}
|
||||
return unprotectedBuf.retain();
|
||||
} finally {
|
||||
unprotectedBuf.release();
|
||||
}
|
||||
}
|
||||
|
||||
private void clearState() {
|
||||
int bufsSize = unhandledBufs.size();
|
||||
ByteBuf lastBuf = unhandledBufs.get(bufsSize - 1);
|
||||
boolean keepLast = lastBuf.isReadable();
|
||||
for (int bufIdx = 0; bufIdx < (keepLast ? bufsSize - 1 : bufsSize); ++bufIdx) {
|
||||
unhandledBufs.get(bufIdx).release();
|
||||
}
|
||||
unhandledBufs.clear();
|
||||
unhandledBytes = 0;
|
||||
unhandledIdx = 0;
|
||||
if (keepLast) {
|
||||
unhandledBufs.add(lastBuf);
|
||||
unhandledBytes = lastBuf.readableBytes();
|
||||
}
|
||||
state = DeframerState.READ_HEADER;
|
||||
requiredProtectedBytes = 0;
|
||||
header.clear();
|
||||
firstFrameTag.clear();
|
||||
}
|
||||
|
||||
void destroy() {
|
||||
for (ByteBuf unhandledBuf : unhandledBufs) {
|
||||
unhandledBuf.release();
|
||||
}
|
||||
unhandledBufs.clear();
|
||||
if (header != null) {
|
||||
header.release();
|
||||
header = null;
|
||||
}
|
||||
if (firstFrameTag != null) {
|
||||
firstFrameTag.release();
|
||||
firstFrameTag = null;
|
||||
}
|
||||
crypter.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
private enum DeframerState {
|
||||
READ_HEADER,
|
||||
READ_PROTECTED_PAYLOAD
|
||||
}
|
||||
|
||||
private static ByteBuf writeSlice(ByteBuf in, int len) {
|
||||
checkArgument(len <= in.writableBytes());
|
||||
ByteBuf out = in.slice(in.writerIndex(), len);
|
||||
in.writerIndex(in.writerIndex() + len);
|
||||
return out.writerIndex(0);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.grpc.alts.HandshakerServiceGrpc.HandshakerServiceStub;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Negotiates a grpc channel key to be used by the TsiFrameProtector, using ALTs handshaker service.
|
||||
*/
|
||||
public final class AltsTsiHandshaker implements TsiHandshaker {
|
||||
public static final String TSI_SERVICE_ACCOUNT_PEER_PROPERTY = "service_account";
|
||||
|
||||
private final boolean isClient;
|
||||
private final AltsHandshakerClient handshaker;
|
||||
|
||||
private ByteBuffer outputFrame;
|
||||
|
||||
/** Starts a new TSI handshaker with client options. */
|
||||
private AltsTsiHandshaker(
|
||||
boolean isClient, HandshakerServiceStub stub, AltsHandshakerOptions options) {
|
||||
this.isClient = isClient;
|
||||
handshaker = new AltsHandshakerClient(stub, options);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
AltsTsiHandshaker(boolean isClient, AltsHandshakerClient handshaker) {
|
||||
this.isClient = isClient;
|
||||
this.handshaker = handshaker;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process the bytes received from the peer.
|
||||
*
|
||||
* @param bytes The buffer containing the handshake bytes from the peer.
|
||||
* @return true, if the handshake has all the data it needs to process and false, if the method
|
||||
* must be called again to complete processing.
|
||||
*/
|
||||
@Override
|
||||
public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
|
||||
// If we're the client and we haven't given an output frame, we shouldn't be processing any
|
||||
// bytes.
|
||||
if (outputFrame == null && isClient) {
|
||||
return true;
|
||||
}
|
||||
// If we already have bytes to write, just return.
|
||||
if (outputFrame != null && outputFrame.hasRemaining()) {
|
||||
return true;
|
||||
}
|
||||
int remaining = bytes.remaining();
|
||||
// Call handshaker service to proceess the bytes.
|
||||
if (outputFrame == null) {
|
||||
checkState(!isClient, "Client handshaker should not process any frame at the beginning.");
|
||||
outputFrame = handshaker.startServerHandshake(bytes);
|
||||
} else {
|
||||
outputFrame = handshaker.next(bytes);
|
||||
}
|
||||
// If handshake has finished or we already have bytes to write, just return true.
|
||||
if (handshaker.isFinished() || outputFrame.hasRemaining()) {
|
||||
return true;
|
||||
}
|
||||
// We have done processing input bytes, but no bytes to write. Thus we need more data.
|
||||
if (!bytes.hasRemaining()) {
|
||||
return false;
|
||||
}
|
||||
// There are still remaining bytes. Thus we need to continue processing the bytes.
|
||||
// Prevent infinite loop by checking some bytes are consumed by handshaker.
|
||||
checkState(bytes.remaining() < remaining, "Handshaker did not consume any bytes.");
|
||||
return processBytesFromPeer(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the peer extracted from a completed handshake.
|
||||
*
|
||||
* @return the extracted peer.
|
||||
*/
|
||||
@Override
|
||||
public TsiPeer extractPeer() throws GeneralSecurityException {
|
||||
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
|
||||
List<TsiPeer.Property<?>> peerProperties = new ArrayList<TsiPeer.Property<?>>();
|
||||
peerProperties.add(
|
||||
new TsiPeer.StringProperty(
|
||||
TSI_SERVICE_ACCOUNT_PEER_PROPERTY,
|
||||
handshaker.getResult().getPeerIdentity().getServiceAccount()));
|
||||
return new TsiPeer(peerProperties);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the peer extracted from a completed handshake.
|
||||
*
|
||||
* @return the extracted peer.
|
||||
*/
|
||||
@Override
|
||||
public Object extractPeerObject() throws GeneralSecurityException {
|
||||
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
|
||||
return new AltsAuthContext(handshaker.getResult());
|
||||
}
|
||||
|
||||
/** Creates a new TsiHandshaker for use by the client. */
|
||||
public static TsiHandshaker newClient(HandshakerServiceStub stub, AltsHandshakerOptions options) {
|
||||
return new AltsTsiHandshaker(true, stub, options);
|
||||
}
|
||||
|
||||
/** Creates a new TsiHandshaker for use by the server. */
|
||||
public static TsiHandshaker newServer(HandshakerServiceStub stub, AltsHandshakerOptions options) {
|
||||
return new AltsTsiHandshaker(false, stub, options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets bytes that need to be sent to the peer.
|
||||
*
|
||||
* @param bytes The buffer to put handshake bytes.
|
||||
*/
|
||||
@Override
|
||||
public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
|
||||
if (outputFrame == null) { // A null outputFrame indicates we haven't started the handshake.
|
||||
if (isClient) {
|
||||
outputFrame = handshaker.startClientHandshake();
|
||||
} else {
|
||||
// The server needs bytes to process before it can start the handshake.
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Write as many bytes as we are able.
|
||||
ByteBuffer outputFrameAlias = outputFrame;
|
||||
if (outputFrame.remaining() > bytes.remaining()) {
|
||||
outputFrameAlias = outputFrame.duplicate();
|
||||
outputFrameAlias.limit(outputFrameAlias.position() + bytes.remaining());
|
||||
}
|
||||
bytes.put(outputFrameAlias);
|
||||
outputFrame.position(outputFrameAlias.position());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if and only if the handshake is still in progress
|
||||
*
|
||||
* @return true, if the handshake is still in progress, false otherwise.
|
||||
*/
|
||||
@Override
|
||||
public boolean isInProgress() {
|
||||
return !handshaker.isFinished() || outputFrame.hasRemaining();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a frame protector from a completed handshake. No other methods may be called after the
|
||||
* frame protector is created.
|
||||
*
|
||||
* @param maxFrameSize the requested max frame size, the callee is free to ignore.
|
||||
* @param alloc used for allocating ByteBufs.
|
||||
* @return a new TsiFrameProtector.
|
||||
*/
|
||||
@Override
|
||||
public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
|
||||
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
|
||||
|
||||
byte[] key = handshaker.getKey();
|
||||
Preconditions.checkState(key.length == AltsChannelCrypter.getKeyLength(), "Bad key length.");
|
||||
|
||||
return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a frame protector from a completed handshake. No other methods may be called after the
|
||||
* frame protector is created.
|
||||
*
|
||||
* @param alloc used for allocating ByteBufs.
|
||||
* @return a new TsiFrameProtector.
|
||||
*/
|
||||
@Override
|
||||
public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
|
||||
return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A @{code ChannelCrypterNetty} performs stateful encryption and decryption of independent input
|
||||
* and output streams. Both decrypt and encrypt gather their input from a list of Netty @{link
|
||||
* ByteBuf} instances.
|
||||
*
|
||||
* <p>Note that we provide implementations of this interface that provide integrity only and
|
||||
* implementations that provide privacy and integrity. All methods should be thread-compatible.
|
||||
*/
|
||||
public interface ChannelCrypterNetty {
|
||||
|
||||
/**
|
||||
* Encrypt plaintext into output buffer.
|
||||
*
|
||||
* @param out the protected input will be written into this buffer. The buffer must be direct and
|
||||
* have enough space to hold all input buffers and the tag. Encrypt does not take ownership of
|
||||
* this buffer.
|
||||
* @param plain the input buffers that should be protected. Encrypt does not modify or take
|
||||
* ownership of these buffers.
|
||||
*/
|
||||
void encrypt(ByteBuf out, List<ByteBuf> plain) throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Decrypt ciphertext into the given output buffer and check tag.
|
||||
*
|
||||
* @param out the unprotected input will be written into this buffer. The buffer must be direct
|
||||
* and have enough space to hold all ciphertext buffers and the tag, i.e., it must have
|
||||
* additional space for the tag, even though this space will be unused in the final result.
|
||||
* Decrypt does not take ownership of this buffer.
|
||||
* @param tag the tag appended to the ciphertext. Decrypt does not modify or take ownership of
|
||||
* this buffer.
|
||||
* @param ciphertext the buffers that should be unprotected (excluding the tag). Decrypt does not
|
||||
* modify or take ownership of these buffers.
|
||||
*/
|
||||
void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertext) throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Decrypt ciphertext into the given output buffer and check tag.
|
||||
*
|
||||
* @param out the unprotected input will be written into this buffer. The buffer must be direct
|
||||
* and have enough space to hold all ciphertext buffers and the tag, i.e., it must have
|
||||
* additional space for the tag, even though this space will be unused in the final result.
|
||||
* Decrypt does not take ownership of this buffer.
|
||||
* @param ciphertextAndTag single buffer containing ciphertext and tag that should be unprotected.
|
||||
* The buffer must be direct and either completely overlap with {@code out} or not overlap at
|
||||
* all.
|
||||
*/
|
||||
void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException;
|
||||
|
||||
/** Returns the length of the tag in bytes. */
|
||||
int getSuffixLength();
|
||||
|
||||
/** Must be called to release all associated resources (instance cannot be used afterwards). */
|
||||
void destroy();
|
||||
}
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* This object protects and unprotects netty buffers once the handshake is done.
|
||||
*
|
||||
* <p>Implementations of this object must be thread compatible.
|
||||
*/
|
||||
public interface TsiFrameProtector {
|
||||
|
||||
/**
|
||||
* Protects the buffers by performing framing and encrypting/appending MACs.
|
||||
*
|
||||
* @param unprotectedBufs contain the payload that will be protected
|
||||
* @param ctxWrite is called with buffers containing protected frames and must release the given
|
||||
* buffers
|
||||
* @param alloc is used to allocate new buffers for the protected frames
|
||||
*/
|
||||
void protectFlush(
|
||||
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Unprotects the buffers by removing the framing and decrypting/checking MACs.
|
||||
*
|
||||
* @param in contains (partial) protected frames
|
||||
* @param out is only used to append unprotected payload buffers
|
||||
* @param alloc is used to allocate new buffers for the unprotected frames
|
||||
*/
|
||||
void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException;
|
||||
|
||||
/** Must be called to release all associated resources (instance cannot be used afterwards). */
|
||||
void destroy();
|
||||
}
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* This object protects and unprotects buffers once the handshake is done.
|
||||
*
|
||||
* <p>A typical usage of this object would be:
|
||||
*
|
||||
* <pre>{@code
|
||||
* ByteBuffer buffer = allocateDirect(ALLOCATE_SIZE);
|
||||
* while (true) {
|
||||
* while (true) {
|
||||
* tsiHandshaker.getBytesToSendToPeer(buffer.clear());
|
||||
* if (!buffer.hasRemaining()) break;
|
||||
* yourTransportSendMethod(buffer.flip());
|
||||
* assert(!buffer.hasRemaining()); // Guaranteed by yourTransportReceiveMethod(...)
|
||||
* }
|
||||
* if (!tsiHandshaker.isInProgress()) break;
|
||||
* while (true) {
|
||||
* assert(!buffer.hasRemaining());
|
||||
* yourTransportReceiveMethod(buffer.clear());
|
||||
* if (tsiHandshaker.processBytesFromPeer(buffer.flip())) break;
|
||||
* }
|
||||
* if (!tsiHandshaker.isInProgress()) break;
|
||||
* assert(!buffer.hasRemaining());
|
||||
* }
|
||||
* yourCheckPeerMethod(tsiHandshaker.extractPeer());
|
||||
* TsiFrameProtector tsiFrameProtector = tsiHandshaker.createFrameProtector(MAX_FRAME_SIZE);
|
||||
* if (buffer.hasRemaining()) tsiFrameProtector.unprotect(buffer, messageBuffer);
|
||||
* }</pre>
|
||||
*
|
||||
* <p>Implementations of this object must be thread compatible.
|
||||
*/
|
||||
public interface TsiHandshaker {
|
||||
/**
|
||||
* Gets bytes that need to be sent to the peer.
|
||||
*
|
||||
* @param bytes The buffer to put handshake bytes.
|
||||
*/
|
||||
void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Process the bytes received from the peer.
|
||||
*
|
||||
* @param bytes The buffer containing the handshake bytes from the peer.
|
||||
* @return true, if the handshake has all the data it needs to process and false, if the method
|
||||
* must be called again to complete processing.
|
||||
*/
|
||||
boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Returns true if and only if the handshake is still in progress
|
||||
*
|
||||
* @return true, if the handshake is still in progress, false otherwise.
|
||||
*/
|
||||
boolean isInProgress();
|
||||
|
||||
/**
|
||||
* Returns the peer extracted from a completed handshake.
|
||||
*
|
||||
* @return the extracted peer.
|
||||
*/
|
||||
TsiPeer extractPeer() throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Returns the peer extracted from a completed handshake.
|
||||
*
|
||||
* @return the extracted peer.
|
||||
*/
|
||||
public Object extractPeerObject() throws GeneralSecurityException;
|
||||
|
||||
/**
|
||||
* Creates a frame protector from a completed handshake. No other methods may be called after the
|
||||
* frame protector is created.
|
||||
*
|
||||
* @param maxFrameSize the requested max frame size, the callee is free to ignore.
|
||||
* @param alloc used for allocating ByteBufs.
|
||||
* @return a new TsiFrameProtector.
|
||||
*/
|
||||
TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc);
|
||||
|
||||
/**
|
||||
* Creates a frame protector from a completed handshake. No other methods may be called after the
|
||||
* frame protector is created.
|
||||
*
|
||||
* @param alloc used for allocating ByteBufs.
|
||||
* @return a new TsiFrameProtector.
|
||||
*/
|
||||
TsiFrameProtector createFrameProtector(ByteBufAllocator alloc);
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
/** Factory that manufactures instances of {@link TsiHandshaker}. */
|
||||
public interface TsiHandshakerFactory {
|
||||
|
||||
/** Creates a new handshaker. */
|
||||
TsiHandshaker newHandshaker();
|
||||
}
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import javax.annotation.Nonnull;
|
||||
|
||||
/** A set of peer properties. */
|
||||
public final class TsiPeer {
|
||||
private final List<Property<?>> properties;
|
||||
|
||||
public TsiPeer(List<Property<?>> properties) {
|
||||
this.properties = Collections.unmodifiableList(properties);
|
||||
}
|
||||
|
||||
public List<Property<?>> getProperties() {
|
||||
return properties;
|
||||
}
|
||||
|
||||
/** Get peer property. */
|
||||
public Property<?> getProperty(String name) {
|
||||
for (Property<?> property : properties) {
|
||||
if (property.getName().equals(name)) {
|
||||
return property;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return new ArrayList<>(properties).toString();
|
||||
}
|
||||
|
||||
/** A peer property. */
|
||||
public abstract static class Property<T> {
|
||||
private final String name;
|
||||
private final T value;
|
||||
|
||||
public Property(@Nonnull String name, @Nonnull T value) {
|
||||
this.name = name;
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public final T getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
public final String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format("%s=%s", name, value);
|
||||
}
|
||||
}
|
||||
|
||||
/** A peer property corresponding to a signed 64-bit integer. */
|
||||
public static final class SignedInt64Property extends Property<Long> {
|
||||
public SignedInt64Property(@Nonnull String name, @Nonnull Long value) {
|
||||
super(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
/** A peer property corresponding to an unsigned 64-bit integer. */
|
||||
public static final class UnsignedInt64Property extends Property<BigInteger> {
|
||||
public UnsignedInt64Property(@Nonnull String name, @Nonnull BigInteger value) {
|
||||
super(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
/** A peer property corresponding to a double. */
|
||||
public static final class DoubleProperty extends Property<Double> {
|
||||
public DoubleProperty(@Nonnull String name, @Nonnull Double value) {
|
||||
super(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
/** A peer property corresponding to a string. */
|
||||
public static final class StringProperty extends Property<String> {
|
||||
public StringProperty(@Nonnull String name, @Nonnull String value) {
|
||||
super(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
/** A peer property corresponding to a list of peer properties. */
|
||||
public static final class PropertyList extends Property<List<Property<?>>> {
|
||||
public PropertyList(@Nonnull String name, @Nonnull List<Property<?>> value) {
|
||||
super(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
syntax = "proto3";
|
||||
|
||||
import "transport_security_common.proto";
|
||||
|
||||
package grpc.gcp;
|
||||
|
||||
option java_package = "io.grpc.alts";
|
||||
|
||||
message AltsContext {
|
||||
// The application protocol negotiated for this connection.
|
||||
string application_protocol = 1;
|
||||
|
||||
// The record protocol negotiated for this connection.
|
||||
string record_protocol = 2;
|
||||
|
||||
// The security level of the created secure channel.
|
||||
SecurityLevel security_level = 3;
|
||||
|
||||
// The peer service account.
|
||||
string peer_service_account = 4;
|
||||
|
||||
// The local service account.
|
||||
string local_service_account = 5;
|
||||
|
||||
// The RPC protocol versions supported by the peer.
|
||||
RpcProtocolVersions peer_rpc_versions = 6;
|
||||
}
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
syntax = "proto3";
|
||||
|
||||
import "transport_security_common.proto";
|
||||
|
||||
package grpc.gcp;
|
||||
|
||||
option java_package = "io.grpc.alts";
|
||||
|
||||
enum HandshakeProtocol {
|
||||
// Default value.
|
||||
HANDSHAKE_PROTOCOL_UNSPECIFIED = 0;
|
||||
|
||||
// TLS handshake protocol.
|
||||
TLS = 1;
|
||||
|
||||
// Application Layer Transport Security handshake protocol.
|
||||
ALTS = 2;
|
||||
}
|
||||
|
||||
enum NetworkProtocol {
|
||||
NETWORK_PROTOCOL_UNSPECIFIED = 0;
|
||||
TCP = 1;
|
||||
UDP = 2;
|
||||
}
|
||||
|
||||
message Endpoint {
|
||||
// IP address. It should contain an IPv4 or IPv6 string literal, e.g.
|
||||
// "192.168.0.1" or "2001:db8::1".
|
||||
string ip_address = 1;
|
||||
|
||||
// Port number.
|
||||
int32 port = 2;
|
||||
|
||||
// Network protocol (e.g., TCP, UDP) associated with this endpoint.
|
||||
NetworkProtocol protocol = 3;
|
||||
}
|
||||
|
||||
message Identity {
|
||||
oneof identity_oneof {
|
||||
// Service account of a connection endpoint.
|
||||
string service_account = 1;
|
||||
|
||||
// Hostname of a connection endpoint.
|
||||
string hostname = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message StartClientHandshakeReq {
|
||||
// Handshake security protocol requested by the client.
|
||||
HandshakeProtocol handshake_security_protocol = 1;
|
||||
|
||||
// The application protocols supported by the client, e.g., "h2" (for http2),
|
||||
// "grpc".
|
||||
repeated string application_protocols = 2;
|
||||
|
||||
// The record protocols supported by the client, e.g.,
|
||||
// "ALTSRP_GCM_AES128".
|
||||
repeated string record_protocols = 3;
|
||||
|
||||
// (Optional) Describes which server identities are acceptable by the client.
|
||||
// If target identities are provided and none of them matches the peer
|
||||
// identity of the server, handshake will fail.
|
||||
repeated Identity target_identities = 4;
|
||||
|
||||
// (Optional) Application may specify a local identity. Otherwise, the
|
||||
// handshaker chooses a default local identity.
|
||||
Identity local_identity = 5;
|
||||
|
||||
// (Optional) Local endpoint information of the connection to the server,
|
||||
// such as local IP address, port number, and network protocol.
|
||||
Endpoint local_endpoint = 6;
|
||||
|
||||
// (Optional) Endpoint information of the remote server, such as IP address,
|
||||
// port number, and network protocool.
|
||||
Endpoint remote_endpoint = 7;
|
||||
|
||||
// (Optional) If target name is provided, a secure naming check is performed
|
||||
// to verify that the peer authenticated identity is indeed authorized to run
|
||||
// the target name.
|
||||
string target_name = 8;
|
||||
|
||||
// (Optional) RPC protocol versions supported by the client.
|
||||
RpcProtocolVersions rpc_versions = 9;
|
||||
}
|
||||
|
||||
message ServerHandshakeParameters {
|
||||
// The record protocols supported by the server, e.g.,
|
||||
// "ALTSRP_GCM_AES128".
|
||||
repeated string record_protocols = 1;
|
||||
|
||||
// (Optional) A list of local identities supported by the server, if
|
||||
// specified. Otherwise, the handshaker chooses a default local identity.
|
||||
repeated Identity local_identities = 2;
|
||||
}
|
||||
|
||||
message StartServerHandshakeReq {
|
||||
// The application protocols supported by the server, e.g., "h2" (for http2),
|
||||
// "grpc".
|
||||
repeated string application_protocols = 1;
|
||||
|
||||
// Handshake parameters (record protocols and local identities supported by
|
||||
// the server) mapped by the handshake protocol. Each handshake security
|
||||
// protocol (e.g., TLS or ALTS) has its own set of record protocols and local
|
||||
// identities. Since protobuf does not support enum as key to the map, the key
|
||||
// to handshake_parameters is the integer value of HandshakeProtocol enum.
|
||||
map<int32, ServerHandshakeParameters> handshake_parameters = 2;
|
||||
|
||||
// Bytes in out_frames returned from the peer's HandshakerResp. It is possible
|
||||
// that the peer's out_frames are split into multiple HandshakReq messages.
|
||||
bytes in_bytes = 3;
|
||||
|
||||
// (Optional) Local endpoint information of the connection to the client,
|
||||
// such as local IP address, port number, and network protocol.
|
||||
Endpoint local_endpoint = 4;
|
||||
|
||||
// (Optional) Endpoint information of the remote client, such as IP address,
|
||||
// port number, and network protocool.
|
||||
Endpoint remote_endpoint = 5;
|
||||
|
||||
// (Optional) RPC protocol versions supported by the server.
|
||||
RpcProtocolVersions rpc_versions = 6;
|
||||
}
|
||||
|
||||
message NextHandshakeMessageReq {
|
||||
// Bytes in out_frames returned from the peer's HandshakerResp. It is possible
|
||||
// that the peer's out_frames are split into multiple NextHandshakerMessageReq
|
||||
// messages.
|
||||
bytes in_bytes = 1;
|
||||
}
|
||||
|
||||
message HandshakerReq {
|
||||
oneof req_oneof {
|
||||
// The start client handshake request message.
|
||||
StartClientHandshakeReq client_start = 1;
|
||||
|
||||
// The start server handshake request message.
|
||||
StartServerHandshakeReq server_start = 2;
|
||||
|
||||
// The next handshake request message.
|
||||
NextHandshakeMessageReq next = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message HandshakerResult {
|
||||
// The application protocol negotiated for this connection.
|
||||
string application_protocol = 1;
|
||||
|
||||
// The record protocol negotiated for this connection.
|
||||
string record_protocol = 2;
|
||||
|
||||
// Cryptographic key data. The key data may be more than the key length
|
||||
// required for the record protocol, thus the client of the handshaker
|
||||
// service needs to truncate the key data into the right key length.
|
||||
bytes key_data = 3;
|
||||
|
||||
// The authenticated identity of the peer.
|
||||
Identity peer_identity = 4;
|
||||
|
||||
// The local identity used in the handshake.
|
||||
Identity local_identity = 5;
|
||||
|
||||
// Indicate whether the handshaker service client should keep the channel
|
||||
// between the handshaker service open, e.g., in order to handle
|
||||
// post-handshake messages in the future.
|
||||
bool keep_channel_open = 6;
|
||||
|
||||
// The RPC protocol versions supported by the peer.
|
||||
RpcProtocolVersions peer_rpc_versions = 7;
|
||||
}
|
||||
|
||||
message HandshakerStatus {
|
||||
// The status code. This could be the gRPC status code.
|
||||
uint32 code = 1;
|
||||
|
||||
// The status details.
|
||||
string details = 2;
|
||||
}
|
||||
|
||||
message HandshakerResp {
|
||||
// Frames to be given to the peer for the NextHandshakeMessageReq. May be
|
||||
// empty if no out_frames have to be sent to the peer or if in_bytes in the
|
||||
// HandshakerReq are incomplete. All the non-empty out frames must be sent to
|
||||
// the peer even if the handshaker status is not OK as these frames may
|
||||
// contain the alert frames.
|
||||
bytes out_frames = 1;
|
||||
|
||||
// Number of bytes in the in_bytes consumed by the handshaker. It is possible
|
||||
// that part of in_bytes in HandshakerReq was unrelated to the handshake
|
||||
// process.
|
||||
uint32 bytes_consumed = 2;
|
||||
|
||||
// This is set iff the handshake was successful. out_frames may still be set
|
||||
// to frames that needs to be forwarded to the peer.
|
||||
HandshakerResult result = 3;
|
||||
|
||||
// Status of the handshaker.
|
||||
HandshakerStatus status = 4;
|
||||
}
|
||||
|
||||
service HandshakerService {
|
||||
// Accepts a stream of handshaker request, returning a stream of handshaker
|
||||
// response.
|
||||
rpc DoHandshake(stream HandshakerReq)
|
||||
returns (stream HandshakerResp) {
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package grpc.gcp;
|
||||
|
||||
option java_package = "io.grpc.alts";
|
||||
|
||||
// The security level of the created channel. The list is sorted in increasing
|
||||
// level of security. This order must always be maintained.
|
||||
enum SecurityLevel {
|
||||
SECURITY_NONE = 0;
|
||||
INTEGRITY_ONLY = 1;
|
||||
INTEGRITY_AND_PRIVACY = 2;
|
||||
}
|
||||
|
||||
// Max and min supported RPC protocol versions.
|
||||
message RpcProtocolVersions {
|
||||
// RPC version contains a major version and a minor version.
|
||||
message Version {
|
||||
uint32 major = 1;
|
||||
uint32 minor = 2;
|
||||
}
|
||||
// Maximum supported RPC version.
|
||||
Version max_rpc_version = 1;
|
||||
// Minimum supported RPC version.
|
||||
Version min_rpc_version = 2;
|
||||
}
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
import com.google.common.base.Defaults;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.alts.AltsChannelBuilder.AltsChannel;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import io.grpc.alts.transportsecurity.AltsClientOptions;
|
||||
import io.grpc.netty.InternalNettyChannelBuilder.TransportCreationParamsFilterFactory;
|
||||
import io.grpc.netty.ProtocolNegotiator;
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.net.InetSocketAddress;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public final class AltsChannelBuilderTest {
|
||||
|
||||
@Test
|
||||
public void buildsNettyChannel() throws Exception {
|
||||
AltsChannelBuilder builder =
|
||||
AltsChannelBuilder.forTarget("localhost:8080").enableUntrustedAltsForTesting();
|
||||
|
||||
TransportCreationParamsFilterFactory tcpfFactory = builder.getTcpfFactoryForTest();
|
||||
AltsClientOptions altsClientOptions = builder.getAltsClientOptionsForTest();
|
||||
|
||||
assertThat(tcpfFactory).isNull();
|
||||
assertThat(altsClientOptions).isNull();
|
||||
|
||||
ManagedChannel channel = builder.build();
|
||||
assertThat(channel).isInstanceOf(AltsChannel.class);
|
||||
|
||||
tcpfFactory = builder.getTcpfFactoryForTest();
|
||||
altsClientOptions = builder.getAltsClientOptionsForTest();
|
||||
|
||||
assertThat(tcpfFactory).isNotNull();
|
||||
ProtocolNegotiator protocolNegotiator =
|
||||
tcpfFactory
|
||||
.create(new InetSocketAddress(8080), "fakeAuthority", "fakeUserAgent", null)
|
||||
.getProtocolNegotiator();
|
||||
assertThat(protocolNegotiator).isInstanceOf(AltsProtocolNegotiator.class);
|
||||
|
||||
assertThat(altsClientOptions).isNotNull();
|
||||
RpcProtocolVersions expectedVersions =
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build();
|
||||
assertThat(altsClientOptions.getRpcProtocolVersions()).isEqualTo(expectedVersions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void allAltsChannelMethodsForward() throws Exception {
|
||||
ManagedChannel mockDelegate = mock(ManagedChannel.class);
|
||||
AltsChannel altsChannel = new AltsChannel(mockDelegate);
|
||||
|
||||
for (Method method : ManagedChannel.class.getDeclaredMethods()) {
|
||||
if (Modifier.isStatic(method.getModifiers()) || Modifier.isPrivate(method.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
Class<?>[] argTypes = method.getParameterTypes();
|
||||
Object[] args = new Object[argTypes.length];
|
||||
for (int i = 0; i < argTypes.length; i++) {
|
||||
args[i] = Defaults.defaultValue(argTypes[i]);
|
||||
}
|
||||
|
||||
method.invoke(altsChannel, args);
|
||||
method.invoke(verify(mockDelegate), args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,494 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import io.grpc.Attributes;
|
||||
import io.grpc.Grpc;
|
||||
import io.grpc.alts.Handshaker.HandshakerResult;
|
||||
import io.grpc.alts.transportsecurity.AltsAuthContext;
|
||||
import io.grpc.alts.transportsecurity.FakeTsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiFrameProtector;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshakerFactory;
|
||||
import io.grpc.alts.transportsecurity.TsiPeer;
|
||||
import io.grpc.netty.GrpcHttp2ConnectionHandler;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandler;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http2.DefaultHttp2Connection;
|
||||
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
|
||||
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
|
||||
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
|
||||
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
|
||||
import io.netty.handler.codec.http2.Http2Connection;
|
||||
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
|
||||
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
|
||||
import io.netty.handler.codec.http2.Http2FrameReader;
|
||||
import io.netty.handler.codec.http2.Http2FrameWriter;
|
||||
import io.netty.handler.codec.http2.Http2Settings;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.Consumer;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Tests for {@link AltsProtocolNegotiator}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class AltsProtocolNegotiatorTest {
|
||||
private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
|
||||
|
||||
private final List<ReferenceCounted> references = new ArrayList<>();
|
||||
private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>();
|
||||
|
||||
private EmbeddedChannel channel;
|
||||
private Throwable caughtException;
|
||||
|
||||
private volatile InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
|
||||
private ChannelHandler handler;
|
||||
|
||||
private TsiPeer mockedTsiPeer = new TsiPeer(Collections.emptyList());
|
||||
private AltsAuthContext mockedAltsContext =
|
||||
new AltsAuthContext(
|
||||
HandshakerResult.newBuilder()
|
||||
.setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
|
||||
.build());
|
||||
private final TsiHandshaker mockHandshaker =
|
||||
new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
|
||||
@Override
|
||||
public TsiPeer extractPeer() throws GeneralSecurityException {
|
||||
return mockedTsiPeer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object extractPeerObject() throws GeneralSecurityException {
|
||||
return mockedAltsContext;
|
||||
}
|
||||
};
|
||||
private final InternalNettyTsiHandshaker serverHandshaker =
|
||||
new InternalNettyTsiHandshaker(mockHandshaker);
|
||||
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
ChannelHandler userEventHandler =
|
||||
new ChannelDuplexHandler() {
|
||||
@Override
|
||||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
|
||||
if (evt instanceof InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent) {
|
||||
tsiEvent = (InternalTsiHandshakeHandler.TsiHandshakeCompletionEvent) evt;
|
||||
} else {
|
||||
super.userEventTriggered(ctx, evt);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
ChannelHandler uncaughtExceptionHandler =
|
||||
new ChannelDuplexHandler() {
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
|
||||
caughtException = cause;
|
||||
super.exceptionCaught(ctx, cause);
|
||||
}
|
||||
};
|
||||
|
||||
TsiHandshakerFactory handshakerFactory =
|
||||
new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker() {
|
||||
return new DelegatingTsiHandshaker(super.newHandshaker()) {
|
||||
@Override
|
||||
public TsiPeer extractPeer() throws GeneralSecurityException {
|
||||
return mockedTsiPeer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object extractPeerObject() throws GeneralSecurityException {
|
||||
return mockedAltsContext;
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
handler = AltsProtocolNegotiator.create(handshakerFactory).newHandler(grpcHandler);
|
||||
channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
|
||||
}
|
||||
|
||||
@After
|
||||
public void teardown() throws Exception {
|
||||
if (channel != null) {
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError = channel.close();
|
||||
}
|
||||
|
||||
for (ReferenceCounted reference : references) {
|
||||
ReferenceCountUtil.safeRelease(reference);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handshakeShouldBeSuccessful() throws Exception {
|
||||
doHandshake();
|
||||
}
|
||||
|
||||
@Test
|
||||
@SuppressWarnings("unchecked") // List cast
|
||||
public void protectShouldRoundtrip() throws Exception {
|
||||
// Write the message 1 character at a time. The message should be buffered
|
||||
// and not interfere with the handshake.
|
||||
final AtomicInteger writeCount = new AtomicInteger();
|
||||
String message = "hello";
|
||||
for (int ix = 0; ix < message.length(); ++ix) {
|
||||
ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError =
|
||||
channel
|
||||
.write(in)
|
||||
.addListener(
|
||||
new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
if (future.isSuccess()) {
|
||||
writeCount.incrementAndGet();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
channel.flush();
|
||||
|
||||
// Now do the handshake. The buffered message will automatically be protected
|
||||
// and sent.
|
||||
doHandshake();
|
||||
|
||||
// Capture the protected data written to the wire.
|
||||
assertEquals(1, channel.outboundMessages().size());
|
||||
ByteBuf protectedData = channel.<ByteBuf>readOutbound();
|
||||
assertEquals(message.length(), writeCount.get());
|
||||
|
||||
// Read the protected message at the server and verify it matches the original message.
|
||||
TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(channel.alloc());
|
||||
List<ByteBuf> unprotected = new ArrayList<>();
|
||||
serverProtector.unprotect(protectedData, (List<Object>) (List<?>) unprotected, channel.alloc());
|
||||
// We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it
|
||||
// a settings frame is written (and an HTTP2 preface). This is hard coded into Netty, so we
|
||||
// have to remove it here. See {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}.
|
||||
int settingsFrameLength = 9;
|
||||
|
||||
CompositeByteBuf unprotectedAll =
|
||||
new CompositeByteBuf(channel.alloc(), false, unprotected.size() + 1, unprotected);
|
||||
ByteBuf unprotectedData = unprotectedAll.slice(settingsFrameLength, message.length());
|
||||
assertEquals(message, unprotectedData.toString(UTF_8));
|
||||
|
||||
// Protect the same message at the server.
|
||||
AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
|
||||
serverProtector.protectFlush(
|
||||
Collections.singletonList(unprotectedData),
|
||||
b -> newlyProtectedData.set(b),
|
||||
channel.alloc());
|
||||
|
||||
// Read the protected message at the client and verify that it matches the original message.
|
||||
channel.writeInbound(newlyProtectedData.get());
|
||||
assertEquals(1, channel.inboundMessages().size());
|
||||
assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void unprotectLargeIncomingFrame() throws Exception {
|
||||
|
||||
// We use a server frameprotector with twice the standard frame size.
|
||||
int serverFrameSize = 4096 * 2;
|
||||
// This should fit into one frame.
|
||||
byte[] unprotectedBytes = new byte[serverFrameSize - 500];
|
||||
Arrays.fill(unprotectedBytes, (byte) 7);
|
||||
ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes);
|
||||
unprotectedData.writerIndex(unprotectedBytes.length);
|
||||
|
||||
// Perform handshake.
|
||||
doHandshake();
|
||||
|
||||
// Protect the message on the server.
|
||||
TsiFrameProtector serverProtector =
|
||||
serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
|
||||
serverProtector.protectFlush(
|
||||
Collections.singletonList(unprotectedData), b -> channel.writeInbound(b), channel.alloc());
|
||||
channel.flushInbound();
|
||||
|
||||
// Read the protected message at the client and verify that it matches the original message.
|
||||
assertEquals(1, channel.inboundMessages().size());
|
||||
|
||||
ByteBuf receivedData1 = channel.<ByteBuf>readInbound();
|
||||
int receivedLen1 = receivedData1.readableBytes();
|
||||
byte[] receivedBytes = new byte[receivedLen1];
|
||||
receivedData1.readBytes(receivedBytes, 0, receivedLen1);
|
||||
|
||||
assertThat(unprotectedBytes.length).isEqualTo(receivedBytes.length);
|
||||
assertThat(unprotectedBytes).isEqualTo(receivedBytes);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void flushShouldFailAllPromises() throws Exception {
|
||||
doHandshake();
|
||||
|
||||
channel
|
||||
.pipeline()
|
||||
.addFirst(
|
||||
new ChannelDuplexHandler() {
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
|
||||
throws Exception {
|
||||
throw new Exception("Fake exception");
|
||||
}
|
||||
});
|
||||
|
||||
// Write the message 1 character at a time.
|
||||
String message = "hello";
|
||||
final AtomicInteger failures = new AtomicInteger();
|
||||
for (int ix = 0; ix < message.length(); ++ix) {
|
||||
ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
|
||||
@SuppressWarnings("unused") // go/futurereturn-lsc
|
||||
Future<?> possiblyIgnoredError =
|
||||
channel
|
||||
.write(in)
|
||||
.addListener(
|
||||
new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
if (!future.isSuccess()) {
|
||||
failures.incrementAndGet();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
channel.flush();
|
||||
|
||||
// Verify that the promises fail.
|
||||
assertEquals(message.length(), failures.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doNotFlushEmptyBuffer() throws Exception {
|
||||
doHandshake();
|
||||
assertEquals(1, protectors.size());
|
||||
InterceptingProtector protector = protectors.poll();
|
||||
|
||||
String message = "hello";
|
||||
ByteBuf in = Unpooled.copiedBuffer(message, UTF_8);
|
||||
|
||||
assertEquals(0, protector.flushes.get());
|
||||
Future<?> done = channel.write(in);
|
||||
channel.flush();
|
||||
done.get(5, TimeUnit.SECONDS);
|
||||
assertEquals(1, protector.flushes.get());
|
||||
|
||||
done = channel.write(Unpooled.EMPTY_BUFFER);
|
||||
channel.flush();
|
||||
done.get(5, TimeUnit.SECONDS);
|
||||
assertEquals(1, protector.flushes.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void peerPropagated() throws Exception {
|
||||
doHandshake();
|
||||
|
||||
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getTsiPeerAttributeKey()))
|
||||
.isEqualTo(mockedTsiPeer);
|
||||
assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getAltsAuthContextAttributeKey()))
|
||||
.isEqualTo(mockedAltsContext);
|
||||
assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
|
||||
.isEqualTo("embedded");
|
||||
}
|
||||
|
||||
private void doHandshake() throws Exception {
|
||||
// Capture the client frame and add to the server.
|
||||
assertEquals(1, channel.outboundMessages().size());
|
||||
ByteBuf clientFrame = channel.<ByteBuf>readOutbound();
|
||||
assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
|
||||
|
||||
// Get the server response handshake frames.
|
||||
ByteBuf serverFrame = channel.alloc().buffer();
|
||||
serverHandshaker.getBytesToSendToPeer(serverFrame);
|
||||
channel.writeInbound(serverFrame);
|
||||
|
||||
// Capture the next client frame and add to the server.
|
||||
assertEquals(1, channel.outboundMessages().size());
|
||||
clientFrame = channel.<ByteBuf>readOutbound();
|
||||
assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
|
||||
|
||||
// Get the server response handshake frames.
|
||||
serverFrame = channel.alloc().buffer();
|
||||
serverHandshaker.getBytesToSendToPeer(serverFrame);
|
||||
channel.writeInbound(serverFrame);
|
||||
|
||||
// Ensure that both sides have confirmed that the handshake has completed.
|
||||
assertFalse(serverHandshaker.isInProgress());
|
||||
|
||||
if (caughtException != null) {
|
||||
throw new RuntimeException(caughtException);
|
||||
}
|
||||
assertNotNull(tsiEvent);
|
||||
}
|
||||
|
||||
private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
|
||||
// Netty Boilerplate. We don't really need any of this, but there is a tight coupling
|
||||
// between a Http2ConnectionHandler and its dependencies.
|
||||
Http2Connection connection = new DefaultHttp2Connection(true);
|
||||
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
|
||||
Http2FrameReader frameReader = new DefaultHttp2FrameReader(false);
|
||||
DefaultHttp2ConnectionEncoder encoder =
|
||||
new DefaultHttp2ConnectionEncoder(connection, frameWriter);
|
||||
DefaultHttp2ConnectionDecoder decoder =
|
||||
new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader);
|
||||
|
||||
return new CapturingGrpcHttp2ConnectionHandler(decoder, encoder, new Http2Settings());
|
||||
}
|
||||
|
||||
private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
|
||||
private Attributes attrs;
|
||||
|
||||
private CapturingGrpcHttp2ConnectionHandler(
|
||||
Http2ConnectionDecoder decoder,
|
||||
Http2ConnectionEncoder encoder,
|
||||
Http2Settings initialSettings) {
|
||||
super(null, decoder, encoder, initialSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleProtocolNegotiationCompleted(Attributes attrs) {
|
||||
// If we are added to the pipeline, we need to remove ourselves. The HTTP2 handler
|
||||
channel.pipeline().remove(this);
|
||||
this.attrs = attrs;
|
||||
}
|
||||
}
|
||||
|
||||
private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory {
|
||||
|
||||
private TsiHandshakerFactory delegate;
|
||||
|
||||
DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker() {
|
||||
return delegate.newHandshaker();
|
||||
}
|
||||
}
|
||||
|
||||
private class DelegatingTsiHandshaker implements TsiHandshaker {
|
||||
|
||||
private final TsiHandshaker delegate;
|
||||
|
||||
DelegatingTsiHandshaker(TsiHandshaker delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
|
||||
delegate.getBytesToSendToPeer(bytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
|
||||
return delegate.processBytesFromPeer(bytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isInProgress() {
|
||||
return delegate.isInProgress();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TsiPeer extractPeer() throws GeneralSecurityException {
|
||||
return delegate.extractPeer();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object extractPeerObject() throws GeneralSecurityException {
|
||||
return delegate.extractPeerObject();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
|
||||
InterceptingProtector protector =
|
||||
new InterceptingProtector(delegate.createFrameProtector(alloc));
|
||||
protectors.add(protector);
|
||||
return protector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
|
||||
InterceptingProtector protector =
|
||||
new InterceptingProtector(delegate.createFrameProtector(maxFrameSize, alloc));
|
||||
protectors.add(protector);
|
||||
return protector;
|
||||
}
|
||||
}
|
||||
|
||||
private static class InterceptingProtector implements TsiFrameProtector {
|
||||
private final TsiFrameProtector delegate;
|
||||
final AtomicInteger flushes = new AtomicInteger();
|
||||
|
||||
InterceptingProtector(TsiFrameProtector delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void protectFlush(
|
||||
List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
flushes.incrementAndGet();
|
||||
delegate.protectFlush(unprotectedBufs, ctxWrite, alloc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
|
||||
throws GeneralSecurityException {
|
||||
delegate.unprotect(in, out, alloc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
delegate.destroy();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public final class AltsServerBuilderTest {
|
||||
|
||||
@Test
|
||||
public void buildsNettyServer() throws Exception {
|
||||
AltsServerBuilder.forPort(1234).enableUntrustedAltsForTesting().build();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import com.google.common.truth.Truth;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import io.netty.buffer.UnpooledByteBufAllocator;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public class BufUnwrapperTest {
|
||||
|
||||
private final ByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
|
||||
|
||||
@Test
|
||||
public void closeEmptiesBuffers() {
|
||||
BufUnwrapper unwrapper = new BufUnwrapper();
|
||||
ByteBuf buf = alloc.buffer();
|
||||
try {
|
||||
ByteBuffer[] readableBufs = unwrapper.readableNioBuffers(buf);
|
||||
ByteBuffer[] writableBufs = unwrapper.writableNioBuffers(buf);
|
||||
Truth.assertThat(readableBufs).hasLength(1);
|
||||
Truth.assertThat(readableBufs[0]).isNotNull();
|
||||
Truth.assertThat(writableBufs).hasLength(1);
|
||||
Truth.assertThat(writableBufs[0]).isNotNull();
|
||||
|
||||
unwrapper.close();
|
||||
|
||||
Truth.assertThat(readableBufs[0]).isNull();
|
||||
Truth.assertThat(writableBufs[0]).isNull();
|
||||
} finally {
|
||||
buf.release();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readableNioBuffers_worksWithNormal() {
|
||||
ByteBuf buf = alloc.buffer(1).writeByte('a');
|
||||
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
|
||||
ByteBuffer[] internalBufs = unwrapper.readableNioBuffers(buf);
|
||||
Truth.assertThat(internalBufs).hasLength(1);
|
||||
|
||||
assertEquals('a', internalBufs[0].get(0));
|
||||
} finally {
|
||||
buf.release();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readableNioBuffers_worksWithComposite() {
|
||||
CompositeByteBuf buf = alloc.compositeBuffer();
|
||||
buf.addComponent(true, alloc.buffer(1).writeByte('a'));
|
||||
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
|
||||
ByteBuffer[] internalBufs = unwrapper.readableNioBuffers(buf);
|
||||
Truth.assertThat(internalBufs).hasLength(1);
|
||||
|
||||
assertEquals('a', internalBufs[0].get(0));
|
||||
} finally {
|
||||
buf.release();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writableNioBuffers_indexesPreserved() {
|
||||
ByteBuf buf = alloc.buffer(1);
|
||||
int ridx = buf.readerIndex();
|
||||
int widx = buf.writerIndex();
|
||||
int cap = buf.capacity();
|
||||
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
|
||||
ByteBuffer[] internalBufs = unwrapper.writableNioBuffers(buf);
|
||||
Truth.assertThat(internalBufs).hasLength(1);
|
||||
|
||||
internalBufs[0].put((byte) 'a');
|
||||
|
||||
assertEquals(ridx, buf.readerIndex());
|
||||
assertEquals(widx, buf.writerIndex());
|
||||
assertEquals(cap, buf.capacity());
|
||||
} finally {
|
||||
buf.release();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writableNioBuffers_worksWithNormal() {
|
||||
ByteBuf buf = alloc.buffer(1);
|
||||
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
|
||||
ByteBuffer[] internalBufs = unwrapper.writableNioBuffers(buf);
|
||||
Truth.assertThat(internalBufs).hasLength(1);
|
||||
|
||||
internalBufs[0].put((byte) 'a');
|
||||
|
||||
buf.writerIndex(1);
|
||||
assertEquals('a', buf.readByte());
|
||||
} finally {
|
||||
buf.release();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void writableNioBuffers_worksWithComposite() {
|
||||
CompositeByteBuf buf = alloc.compositeBuffer();
|
||||
buf.addComponent(alloc.buffer(1));
|
||||
buf.capacity(1);
|
||||
try (BufUnwrapper unwrapper = new BufUnwrapper()) {
|
||||
ByteBuffer[] internalBufs = unwrapper.writableNioBuffers(buf);
|
||||
Truth.assertThat(internalBufs).hasLength(1);
|
||||
|
||||
internalBufs[0].put((byte) 'a');
|
||||
|
||||
buf.writerIndex(1);
|
||||
assertEquals('a', buf.readByte());
|
||||
} finally {
|
||||
buf.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.StringReader;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public final class CheckGcpEnvironmentTest {
|
||||
|
||||
@Test
|
||||
public void checkGcpLinuxPlatformData() throws Exception {
|
||||
BufferedReader reader;
|
||||
reader = new BufferedReader(new StringReader("HP Z440 Workstation"));
|
||||
assertFalse(CheckGcpEnvironment.checkProductNameOnLinux(reader));
|
||||
reader = new BufferedReader(new StringReader("Google"));
|
||||
assertTrue(CheckGcpEnvironment.checkProductNameOnLinux(reader));
|
||||
reader = new BufferedReader(new StringReader("Google Compute Engine"));
|
||||
assertTrue(CheckGcpEnvironment.checkProductNameOnLinux(reader));
|
||||
reader = new BufferedReader(new StringReader("Google Compute Engine "));
|
||||
assertTrue(CheckGcpEnvironment.checkProductNameOnLinux(reader));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void checkGcpWindowsPlatformData() throws Exception {
|
||||
BufferedReader reader;
|
||||
reader = new BufferedReader(new StringReader("Product : Google"));
|
||||
assertFalse(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
|
||||
reader = new BufferedReader(new StringReader("Manufacturer : LENOVO"));
|
||||
assertFalse(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
|
||||
reader = new BufferedReader(new StringReader("Manufacturer : Google Compute Engine"));
|
||||
assertFalse(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
|
||||
reader = new BufferedReader(new StringReader("Manufacturer : Google"));
|
||||
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
|
||||
reader = new BufferedReader(new StringReader("Manufacturer:Google"));
|
||||
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
|
||||
reader = new BufferedReader(new StringReader("Manufacturer : Google "));
|
||||
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
|
||||
reader = new BufferedReader(new StringReader("BIOSVersion : 1.0\nManufacturer : Google\n"));
|
||||
assertTrue(CheckGcpEnvironment.checkBiosDataOnWindows(reader));
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,193 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import io.grpc.alts.transportsecurity.FakeTsiHandshaker;
|
||||
import io.grpc.alts.transportsecurity.TsiHandshaker;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.UnpooledByteBufAllocator;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import java.lang.reflect.Method;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import org.junit.After;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
@RunWith(JUnit4.class)
|
||||
public class InternalNettyTsiHandshakerTest {
|
||||
private final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
|
||||
private final List<ReferenceCounted> references = new ArrayList<>();
|
||||
|
||||
private final InternalNettyTsiHandshaker clientHandshaker =
|
||||
new InternalNettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerClient());
|
||||
private final InternalNettyTsiHandshaker serverHandshaker =
|
||||
new InternalNettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer());
|
||||
|
||||
@After
|
||||
public void teardown() {
|
||||
for (ReferenceCounted reference : references) {
|
||||
reference.release(reference.refCnt());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failsOnNullHandshaker() {
|
||||
try {
|
||||
new InternalNettyTsiHandshaker(null);
|
||||
fail("Exception expected");
|
||||
} catch (NullPointerException ex) {
|
||||
// Do nothing.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processPeerHandshakeShouldAcceptPartialFrames() throws GeneralSecurityException {
|
||||
for (int i = 0; i < 1024; i++) {
|
||||
ByteBuf clientData = ref(alloc.buffer(1));
|
||||
clientHandshaker.getBytesToSendToPeer(clientData);
|
||||
if (clientData.isReadable()) {
|
||||
if (serverHandshaker.processBytesFromPeer(clientData)) {
|
||||
// Done.
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
fail("Failed to process the handshake frame.");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handshakeShouldSucceed() throws GeneralSecurityException {
|
||||
doHandshake();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void isInProgress() throws GeneralSecurityException {
|
||||
assertTrue(clientHandshaker.isInProgress());
|
||||
assertTrue(serverHandshaker.isInProgress());
|
||||
|
||||
doHandshake();
|
||||
|
||||
assertFalse(clientHandshaker.isInProgress());
|
||||
assertFalse(serverHandshaker.isInProgress());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractPeer_notNull() throws GeneralSecurityException {
|
||||
doHandshake();
|
||||
|
||||
assertNotNull(serverHandshaker.extractPeer());
|
||||
assertNotNull(clientHandshaker.extractPeer());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractPeer_failsBeforeHandshake() throws GeneralSecurityException {
|
||||
try {
|
||||
clientHandshaker.extractPeer();
|
||||
fail("Exception expected");
|
||||
} catch (IllegalStateException ex) {
|
||||
// Do nothing.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractPeerObject_notNull() throws GeneralSecurityException {
|
||||
doHandshake();
|
||||
|
||||
assertNotNull(serverHandshaker.extractPeerObject());
|
||||
assertNotNull(clientHandshaker.extractPeerObject());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractPeerObject_failsBeforeHandshake() throws GeneralSecurityException {
|
||||
try {
|
||||
clientHandshaker.extractPeerObject();
|
||||
fail("Exception expected");
|
||||
} catch (IllegalStateException ex) {
|
||||
// Do nothing.
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* InternalNettyTsiHandshaker just converts {@link ByteBuffer} to {@link ByteBuf}, so check that
|
||||
* the other methods are otherwise the same.
|
||||
*/
|
||||
@Test
|
||||
public void handshakerMethodsMatch() {
|
||||
List<String> expectedMethods = new ArrayList<>();
|
||||
for (Method m : TsiHandshaker.class.getDeclaredMethods()) {
|
||||
expectedMethods.add(m.getName());
|
||||
}
|
||||
|
||||
List<String> actualMethods = new ArrayList<>();
|
||||
for (Method m : InternalNettyTsiHandshaker.class.getDeclaredMethods()) {
|
||||
actualMethods.add(m.getName());
|
||||
}
|
||||
|
||||
assertThat(actualMethods).containsAllIn(expectedMethods);
|
||||
}
|
||||
|
||||
static void doHandshake(
|
||||
InternalNettyTsiHandshaker clientHandshaker,
|
||||
InternalNettyTsiHandshaker serverHandshaker,
|
||||
ByteBufAllocator alloc,
|
||||
Function<ByteBuf, ByteBuf> ref)
|
||||
throws GeneralSecurityException {
|
||||
// Get the server response handshake frames.
|
||||
for (int i = 0; i < 10; i++) {
|
||||
if (!(clientHandshaker.isInProgress() || serverHandshaker.isInProgress())) {
|
||||
return;
|
||||
}
|
||||
|
||||
ByteBuf clientData = ref.apply(alloc.buffer());
|
||||
clientHandshaker.getBytesToSendToPeer(clientData);
|
||||
if (clientData.isReadable()) {
|
||||
serverHandshaker.processBytesFromPeer(clientData);
|
||||
}
|
||||
|
||||
ByteBuf serverData = ref.apply(alloc.buffer());
|
||||
serverHandshaker.getBytesToSendToPeer(serverData);
|
||||
if (serverData.isReadable()) {
|
||||
clientHandshaker.processBytesFromPeer(serverData);
|
||||
}
|
||||
}
|
||||
|
||||
throw new AssertionError("Failed to complete the handshake.");
|
||||
}
|
||||
|
||||
private void doHandshake() throws GeneralSecurityException {
|
||||
doHandshake(clientHandshaker, serverHandshaker, alloc, buf -> ref(buf));
|
||||
}
|
||||
|
||||
private ByteBuf ref(ByteBuf buf) {
|
||||
if (buf != null) {
|
||||
references.add(buf);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,248 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import io.grpc.alts.RpcProtocolVersionsUtil.RpcVersionsCheckResult;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions.Version;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link RpcProtocolVersionsUtil}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class RpcProtocolVersionsUtilTest {
|
||||
|
||||
@Test
|
||||
public void compareVersions() throws Exception {
|
||||
assertTrue(
|
||||
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
|
||||
Version.newBuilder().setMajor(3).setMinor(2).build(),
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build()));
|
||||
assertTrue(
|
||||
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
|
||||
Version.newBuilder().setMajor(3).setMinor(2).build(),
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build()));
|
||||
assertTrue(
|
||||
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
|
||||
Version.newBuilder().setMajor(3).setMinor(2).build(),
|
||||
Version.newBuilder().setMajor(3).setMinor(2).build()));
|
||||
assertFalse(
|
||||
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
|
||||
Version.newBuilder().setMajor(2).setMinor(3).build(),
|
||||
Version.newBuilder().setMajor(3).setMinor(2).build()));
|
||||
assertFalse(
|
||||
RpcProtocolVersionsUtil.isGreaterThanOrEqualTo(
|
||||
Version.newBuilder().setMajor(3).setMinor(1).build(),
|
||||
Version.newBuilder().setMajor(3).setMinor(2).build()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void checkRpcVersions() throws Exception {
|
||||
// local.max > peer.max and local.min > peer.min
|
||||
RpcVersionsCheckResult checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// local.max > peer.max and local.min < peer.min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// local.max > peer.max and local.min = peer.min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// local.max < peer.max and local.min > peer.min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// local.max = peer.max and local.min > peer.min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// local.max < peer.max and local.min < peer.min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// local.max < peer.max and local.min = peer.min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(3).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// local.max = peer.max and local.min < peer.min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// all equal
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build());
|
||||
assertTrue(checkResult.getResult());
|
||||
assertEquals(
|
||||
Version.newBuilder().setMajor(2).setMinor(1).build(),
|
||||
checkResult.getHighestCommonVersion());
|
||||
// max is smaller than min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(1).setMinor(2).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build());
|
||||
assertFalse(checkResult.getResult());
|
||||
assertEquals(null, checkResult.getHighestCommonVersion());
|
||||
// no overlap, local > peer
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(6).setMinor(5).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(4).setMinor(3).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(0).setMinor(0).build())
|
||||
.build());
|
||||
assertFalse(checkResult.getResult());
|
||||
assertEquals(null, checkResult.getHighestCommonVersion());
|
||||
// no overlap, local < peer
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(1).setMinor(0).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(6).setMinor(5).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(4).setMinor(3).build())
|
||||
.build());
|
||||
assertFalse(checkResult.getResult());
|
||||
assertEquals(null, checkResult.getHighestCommonVersion());
|
||||
// no overlap, max < min
|
||||
checkResult =
|
||||
RpcProtocolVersionsUtil.checkRpcProtocolVersions(
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(4).setMinor(3).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(6).setMinor(5).build())
|
||||
.build(),
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(Version.newBuilder().setMajor(1).setMinor(0).build())
|
||||
.setMinRpcVersion(Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build());
|
||||
assertFalse(checkResult.getResult());
|
||||
assertEquals(null, checkResult.getHighestCommonVersion());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,494 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertWithMessage;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.Arrays;
|
||||
import javax.xml.bind.DatatypeConverter;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AesGcmHkdfAeadCrypter}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class AesGcmHkdfAeadCrypterTest {
|
||||
|
||||
private static class TestVector {
|
||||
final String comment;
|
||||
final byte[] key;
|
||||
final byte[] nonce;
|
||||
final byte[] aad;
|
||||
final byte[] plaintext;
|
||||
final byte[] ciphertext;
|
||||
|
||||
TestVector(TestVectorBuilder builder) {
|
||||
comment = builder.comment;
|
||||
key = builder.key;
|
||||
nonce = builder.nonce;
|
||||
aad = builder.aad;
|
||||
plaintext = builder.plaintext;
|
||||
ciphertext = builder.ciphertext;
|
||||
}
|
||||
|
||||
static TestVectorBuilder builder() {
|
||||
return new TestVectorBuilder();
|
||||
}
|
||||
}
|
||||
|
||||
private static class TestVectorBuilder {
|
||||
String comment;
|
||||
byte[] key;
|
||||
byte[] nonce;
|
||||
byte[] aad;
|
||||
byte[] plaintext;
|
||||
byte[] ciphertext;
|
||||
|
||||
TestVector build() {
|
||||
if (comment == null
|
||||
&& key == null
|
||||
&& nonce == null
|
||||
&& aad == null
|
||||
&& plaintext == null
|
||||
&& ciphertext == null) {
|
||||
throw new IllegalStateException("All fields must be set before calling build().");
|
||||
}
|
||||
return new TestVector(this);
|
||||
}
|
||||
|
||||
TestVectorBuilder withComment(String comment) {
|
||||
this.comment = comment;
|
||||
return this;
|
||||
}
|
||||
|
||||
TestVectorBuilder withKey(String key) {
|
||||
this.key = DatatypeConverter.parseHexBinary(key);
|
||||
return this;
|
||||
}
|
||||
|
||||
TestVectorBuilder withNonce(String nonce) {
|
||||
this.nonce = DatatypeConverter.parseHexBinary(nonce);
|
||||
return this;
|
||||
}
|
||||
|
||||
TestVectorBuilder withAad(String aad) {
|
||||
this.aad = DatatypeConverter.parseHexBinary(aad);
|
||||
return this;
|
||||
}
|
||||
|
||||
TestVectorBuilder withPlaintext(String plaintext) {
|
||||
this.plaintext = DatatypeConverter.parseHexBinary(plaintext);
|
||||
return this;
|
||||
}
|
||||
|
||||
TestVectorBuilder withCiphertext(String ciphertext) {
|
||||
this.ciphertext = DatatypeConverter.parseHexBinary(ciphertext);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorEncrypt() throws GeneralSecurityException {
|
||||
int i = 0;
|
||||
for (TestVector testVector : testVectors) {
|
||||
int bufferSize = testVector.ciphertext.length;
|
||||
byte[] ciphertext = new byte[bufferSize];
|
||||
ByteBuffer ciphertextBuffer = ByteBuffer.wrap(ciphertext);
|
||||
|
||||
AesGcmHkdfAeadCrypter aeadCrypter = new AesGcmHkdfAeadCrypter(testVector.key);
|
||||
aeadCrypter.encrypt(
|
||||
ciphertextBuffer,
|
||||
ByteBuffer.wrap(testVector.plaintext),
|
||||
ByteBuffer.wrap(testVector.aad),
|
||||
testVector.nonce);
|
||||
String msg = "Failure for test vector " + i;
|
||||
assertWithMessage(msg)
|
||||
.that(ciphertextBuffer.remaining())
|
||||
.isEqualTo(bufferSize - testVector.ciphertext.length);
|
||||
byte[] exactCiphertext = Arrays.copyOf(ciphertext, testVector.ciphertext.length);
|
||||
assertWithMessage(msg).that(exactCiphertext).isEqualTo(testVector.ciphertext);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorDecrypt() throws GeneralSecurityException {
|
||||
int i = 0;
|
||||
for (TestVector testVector : testVectors) {
|
||||
// The plaintext buffer might require space for the tag to decrypt (e.g., for conscrypt).
|
||||
int bufferSize = testVector.ciphertext.length;
|
||||
byte[] plaintext = new byte[bufferSize];
|
||||
ByteBuffer plaintextBuffer = ByteBuffer.wrap(plaintext);
|
||||
|
||||
AesGcmHkdfAeadCrypter aeadCrypter = new AesGcmHkdfAeadCrypter(testVector.key);
|
||||
aeadCrypter.decrypt(
|
||||
plaintextBuffer,
|
||||
ByteBuffer.wrap(testVector.ciphertext),
|
||||
ByteBuffer.wrap(testVector.aad),
|
||||
testVector.nonce);
|
||||
String msg = "Failure for test vector " + i;
|
||||
assertWithMessage(msg)
|
||||
.that(plaintextBuffer.remaining())
|
||||
.isEqualTo(bufferSize - testVector.plaintext.length);
|
||||
byte[] exactPlaintext = Arrays.copyOf(plaintext, testVector.plaintext.length);
|
||||
assertWithMessage(msg).that(exactPlaintext).isEqualTo(testVector.plaintext);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* NIST vectors from:
|
||||
* http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf
|
||||
*
|
||||
* IEEE vectors from:
|
||||
* http://www.ieee802.org/1/files/public/docs2011/bn-randall-test-vectors-0511-v1.pdf
|
||||
* Key expanded by setting
|
||||
* expandedKey = (key||(key ^ {0x01, .., 0x01})||key ^ {0x02,..,0x02}))[0:44].
|
||||
*/
|
||||
private static final TestVector[] testVectors =
|
||||
new TestVector[] {
|
||||
TestVector.builder()
|
||||
.withComment("Derived from NIST test vector 1")
|
||||
.withKey(
|
||||
"0000000000000000000000000000000001010101010101010101010101010101020202020202020202"
|
||||
+ "020202")
|
||||
.withNonce("000000000000000000000000")
|
||||
.withAad("")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("85e873e002f6ebdc4060954eb8675508")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from NIST test vector 2")
|
||||
.withKey(
|
||||
"0000000000000000000000000000000001010101010101010101010101010101020202020202020202"
|
||||
+ "020202")
|
||||
.withNonce("000000000000000000000000")
|
||||
.withAad("")
|
||||
.withPlaintext("00000000000000000000000000000000")
|
||||
.withCiphertext("51e9a8cb23ca2512c8256afff8e72d681aca19a1148ac115e83df4888cc00d11")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from NIST test vector 3")
|
||||
.withKey(
|
||||
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
|
||||
+ "688d96")
|
||||
.withNonce("cafebabefacedbaddecaf888")
|
||||
.withAad("")
|
||||
.withPlaintext(
|
||||
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
|
||||
+ "cf0e2449a6b525b16aedf5aa0de657ba637b391aafd255")
|
||||
.withCiphertext(
|
||||
"1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4a"
|
||||
+ "c8cf09afb1663daa7b4017e6fc2c177c0c087c0df1162129952213cee1bc6e9c8495dd705e1f"
|
||||
+ "3d")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from NIST test vector 4")
|
||||
.withKey(
|
||||
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
|
||||
+ "688d96")
|
||||
.withNonce("cafebabefacedbaddecaf888")
|
||||
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
|
||||
.withPlaintext(
|
||||
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
|
||||
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
|
||||
.withCiphertext(
|
||||
"1018ed5a1402a86516d6576d70b2ffccca261b94df88b58f53b64dfba435d18b2f6e3b7869f9353d4a"
|
||||
+ "c8cf09afb1663daa7b4017e6fc2c177c0c087c4764565d077e9124001ddb27fc0848c5")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment(
|
||||
"Derived from adapted NIST test vector 4"
|
||||
+ " for KDF counter boundary (flip nonce bit 15)")
|
||||
.withKey(
|
||||
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
|
||||
+ "688d96")
|
||||
.withNonce("ca7ebabefacedbaddecaf888")
|
||||
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
|
||||
.withPlaintext(
|
||||
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
|
||||
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
|
||||
.withCiphertext(
|
||||
"e650d3c0fb879327f2d03287fa93cd07342b136215adbca00c3bd5099ec41832b1d18e0423ed26bb12"
|
||||
+ "c6cd09debb29230a94c0cee15903656f85edb6fc509b1b28216382172ecbcc31e1e9b1")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment(
|
||||
"Derived from adapted NIST test vector 4"
|
||||
+ " for KDF counter boundary (flip nonce bit 16)")
|
||||
.withKey(
|
||||
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
|
||||
+ "688d96")
|
||||
.withNonce("cafebbbefacedbaddecaf888")
|
||||
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
|
||||
.withPlaintext(
|
||||
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
|
||||
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
|
||||
.withCiphertext(
|
||||
"c0121e6c954d0767f96630c33450999791b2da2ad05c4190169ccad9ac86ff1c721e3d82f2ad22ab46"
|
||||
+ "3bab4a0754b7dd68ca4de7ea2531b625eda01f89312b2ab957d5c7f8568dd95fcdcd1f")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment(
|
||||
"Derived from adapted NIST test vector 4"
|
||||
+ " for KDF counter boundary (flip nonce bit 63)")
|
||||
.withKey(
|
||||
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
|
||||
+ "688d96")
|
||||
.withNonce("cafebabefacedb2ddecaf888")
|
||||
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
|
||||
.withPlaintext(
|
||||
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
|
||||
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
|
||||
.withCiphertext(
|
||||
"8af37ea5684a4d81d4fd817261fd9743099e7e6a025eaacf8e54b124fb5743149e05cb89f4a49467fe"
|
||||
+ "2e5e5965f29a19f99416b0016b54585d12553783ba59e9f782e82e097c336bf7989f08")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment(
|
||||
"Derived from adapted NIST test vector 4"
|
||||
+ " for KDF counter boundary (flip nonce bit 64)")
|
||||
.withKey(
|
||||
"feffe9928665731c6d6a8f9467308308fffee8938764721d6c6b8e9566318209fcfdeb908467711e6f"
|
||||
+ "688d96")
|
||||
.withNonce("cafebabefacedbaddfcaf888")
|
||||
.withAad("feedfacedeadbeeffeedfacedeadbeefabaddad2")
|
||||
.withPlaintext(
|
||||
"d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532f"
|
||||
+ "cf0e2449a6b525b16aedf5aa0de657ba637b39")
|
||||
.withCiphertext(
|
||||
"fbd528448d0346bfa878634864d407a35a039de9db2f1feb8e965b3ae9356ce6289441d77f8f0df294"
|
||||
+ "891f37ea438b223e3bf2bdc53d4c5a74fb680bb312a8dec6f7252cbcd7f5799750ad78")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.1.1 54-byte auth")
|
||||
.withKey(
|
||||
"ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d"
|
||||
+ "600dde")
|
||||
.withNonce("12153524c0895e81b2c28465")
|
||||
.withAad(
|
||||
"d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f10111213141516171819"
|
||||
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("3ea0b584f3c85e93f9320ea591699efb")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.1.2 54-byte auth")
|
||||
.withKey(
|
||||
"e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97"
|
||||
+ "a50755")
|
||||
.withNonce("12153524c0895e81b2c28465")
|
||||
.withAad(
|
||||
"d609b1f056637a0d46df998d88e5222ab2c2846512153524c0895e8108000f10111213141516171819"
|
||||
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30313233340001")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("294e028bf1fe6f14c4e8f7305c933eb5")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.2.1 60-byte crypt")
|
||||
.withKey(
|
||||
"ad7a2bd03eac835a6f620fdcb506b345ac7b2ad13fad825b6e630eddb407b244af7829d23cae81586d"
|
||||
+ "600dde")
|
||||
.withNonce("12153524c0895e81b2c28465")
|
||||
.withAad("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
|
||||
+ "363738393a0002")
|
||||
.withCiphertext(
|
||||
"db3d25719c6b0a3ca6145c159d5c6ed9aff9c6e0b79f17019ea923b8665ddf52137ad611f0d1bf417a"
|
||||
+ "7ca85e45afe106ff9c7569d335d086ae6c03f00987ccd6")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.2.2 60-byte crypt")
|
||||
.withKey(
|
||||
"e3c08a8f06c6e3ad95a70557b23f75483ce33021a9c72b7025666204c69c0b72e1c2888d04c4e1af97"
|
||||
+ "a50755")
|
||||
.withNonce("12153524c0895e81b2c28465")
|
||||
.withAad("d609b1f056637a0d46df998d88e52e00b2c2846512153524c0895e81")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
|
||||
+ "363738393a0002")
|
||||
.withCiphertext(
|
||||
"1641f28ec13afcc8f7903389787201051644914933e9202bb9d06aa020c2a67ef51dfe7bc00a856c55"
|
||||
+ "b8f8133e77f659132502bad63f5713d57d0c11e0f871ed")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.3.1 60-byte auth")
|
||||
.withKey(
|
||||
"071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fcce"
|
||||
+ "cd3f07")
|
||||
.withNonce("f0761e8dcd3d000176d457ed")
|
||||
.withAad(
|
||||
"e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f2021"
|
||||
+ "22232425262728292a2b2c2d2e2f303132333435363738393a0003")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("58837a10562b0f1f8edbe58ca55811d3")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.3.2 60-byte auth")
|
||||
.withKey(
|
||||
"691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365"
|
||||
+ "ff1ea2")
|
||||
.withNonce("f0761e8dcd3d000176d457ed")
|
||||
.withAad(
|
||||
"e20106d7cd0df0761e8dcd3d88e5400076d457ed08000f101112131415161718191a1b1c1d1e1f2021"
|
||||
+ "22232425262728292a2b2c2d2e2f303132333435363738393a0003")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("c2722ff6ca29a257718a529d1f0c6a3b")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.4.1 54-byte crypt")
|
||||
.withKey(
|
||||
"071b113b0ca743fecccf3d051f737382061a103a0da642ffcdce3c041e727283051913390ea541fcce"
|
||||
+ "cd3f07")
|
||||
.withNonce("f0761e8dcd3d000176d457ed")
|
||||
.withAad("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333400"
|
||||
+ "04")
|
||||
.withCiphertext(
|
||||
"fd96b715b93a13346af51e8acdf792cdc7b2686f8574c70e6b0cbf16291ded427ad73fec48cd298e05"
|
||||
+ "28a1f4c644a949fc31dc9279706ddba33f")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.4.2 54-byte crypt")
|
||||
.withKey(
|
||||
"691d3ee909d7f54167fd1ca0b5d769081f2bde1aee655fdbab80bd5295ae6be76b1f3ceb0bd5f74365"
|
||||
+ "ff1ea2")
|
||||
.withNonce("f0761e8dcd3d000176d457ed")
|
||||
.withAad("e20106d7cd0df0761e8dcd3d88e54c2a76d457ed")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333400"
|
||||
+ "04")
|
||||
.withCiphertext(
|
||||
"b68f6300c2e9ae833bdc070e24021a3477118e78ccf84e11a485d861476c300f175353d5cdf92008a4"
|
||||
+ "f878e6cc3577768085c50a0e98fda6cbb8")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.5.1 65-byte auth")
|
||||
.withKey(
|
||||
"013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d84"
|
||||
+ "6f0eb9")
|
||||
.withNonce("7cfde9f9e33724c68932d612")
|
||||
.withAad(
|
||||
"84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f10111213141516171819"
|
||||
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
|
||||
+ "0005")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("cca20eecda6283f09bb3543dd99edb9b")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.5.2 65-byte auth")
|
||||
.withKey(
|
||||
"83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2"
|
||||
+ "d89068")
|
||||
.withNonce("7cfde9f9e33724c68932d612")
|
||||
.withAad(
|
||||
"84c5d513d2aaf6e5bbd2727788e523008932d6127cfde9f9e33724c608000f10111213141516171819"
|
||||
+ "1a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
|
||||
+ "0005")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("b232cc1da5117bf15003734fa599d271")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.6.1 61-byte crypt")
|
||||
.withKey(
|
||||
"013fe00b5f11be7f866d0cbbc55a7a90003ee10a5e10bf7e876c0dbac45b7b91033de2095d13bc7d84"
|
||||
+ "6f0eb9")
|
||||
.withNonce("7cfde9f9e33724c68932d612")
|
||||
.withAad("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
|
||||
+ "363738393a3b0006")
|
||||
.withCiphertext(
|
||||
"ff1910d35ad7e5657890c7c560146fd038707f204b66edbc3d161f8ace244b985921023c436e3a1c35"
|
||||
+ "32ecd5d09a056d70be583f0d10829d9387d07d33d872e490")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.6.2 61-byte crypt")
|
||||
.withKey(
|
||||
"83c093b58de7ffe1c0da926ac43fb3609ac1c80fee1b624497ef942e2f79a82381c291b78fe5fde3c2"
|
||||
+ "d89068")
|
||||
.withNonce("7cfde9f9e33724c68932d612")
|
||||
.withAad("84c5d513d2aaf6e5bbd2727788e52f008932d6127cfde9f9e33724c6")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
|
||||
+ "363738393a3b0006")
|
||||
.withCiphertext(
|
||||
"0db4cf956b5f97eca4eab82a6955307f9ae02a32dd7d93f83d66ad04e1cfdc5182ad12abdea5bbb619"
|
||||
+ "a1bd5fb9a573590fba908e9c7a46c1f7ba0905d1b55ffda4")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.7.1 79-byte crypt")
|
||||
.withKey(
|
||||
"88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f4"
|
||||
+ "7058ab")
|
||||
.withNonce("7ae8e2ca4ec500012e58495c")
|
||||
.withAad(
|
||||
"68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f2021"
|
||||
+ "22232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647"
|
||||
+ "48494a4b4c4d0007")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("813f0e630f96fb2d030f58d83f5cdfd0")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.7.2 79-byte crypt")
|
||||
.withKey(
|
||||
"4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476"
|
||||
+ "fab7ba")
|
||||
.withNonce("7ae8e2ca4ec500012e58495c")
|
||||
.withAad(
|
||||
"68f2e77696ce7ae8e2ca4ec588e541002e58495c08000f101112131415161718191a1b1c1d1e1f2021"
|
||||
+ "22232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647"
|
||||
+ "48494a4b4c4d0007")
|
||||
.withPlaintext("")
|
||||
.withCiphertext("77e5a44c21eb07188aacbd74d1980e97")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.8.1 61-byte crypt")
|
||||
.withKey(
|
||||
"88ee087fd95da9fbf6725aa9d757b0cd89ef097ed85ca8faf7735ba8d656b1cc8aec0a7ddb5fabf9f4"
|
||||
+ "7058ab")
|
||||
.withNonce("7ae8e2ca4ec500012e58495c")
|
||||
.withAad("68f2e77696ce7ae8e2ca4ec588e54d002e58495c")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
|
||||
+ "363738393a3b3c3d3e3f404142434445464748490008")
|
||||
.withCiphertext(
|
||||
"958ec3f6d60afeda99efd888f175e5fcd4c87b9bcc5c2f5426253a8b506296c8c43309ab2adb593946"
|
||||
+ "2541d95e80811e04e706b1498f2c407c7fb234f8cc01a647550ee6b557b35a7e3945381821"
|
||||
+ "f4")
|
||||
.build(),
|
||||
TestVector.builder()
|
||||
.withComment("Derived from IEEE 2.8.2 61-byte crypt")
|
||||
.withKey(
|
||||
"4c973dbc7364621674f8b5b89e5c15511fced9216490fb1c1a2caa0ffe0407e54e953fbe7166601476"
|
||||
+ "fab7ba")
|
||||
.withNonce("7ae8e2ca4ec500012e58495c")
|
||||
.withAad("68f2e77696ce7ae8e2ca4ec588e54d002e58495c")
|
||||
.withPlaintext(
|
||||
"08000f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435"
|
||||
+ "363738393a3b3c3d3e3f404142434445464748490008")
|
||||
.withCiphertext(
|
||||
"b44d072011cd36d272a9b7a98db9aa90cbc5c67b93ddce67c854503214e2e896ec7e9db649ed4bcf6f"
|
||||
+ "850aac0223d0cf92c83db80795c3a17ecc1248bb00591712b1ae71e268164196252162810b"
|
||||
+ "00")
|
||||
.build()
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import io.grpc.alts.Handshaker.HandshakerResult;
|
||||
import io.grpc.alts.Handshaker.Identity;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import io.grpc.alts.TransportSecurityCommon.SecurityLevel;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsAuthContext}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class AltsAuthContextTest {
|
||||
private static final int TEST_MAX_RPC_VERSION_MAJOR = 3;
|
||||
private static final int TEST_MAX_RPC_VERSION_MINOR = 5;
|
||||
private static final int TEST_MIN_RPC_VERSION_MAJOR = 2;
|
||||
private static final int TEST_MIN_RPC_VERSION_MINOR = 1;
|
||||
private static final SecurityLevel TEST_SECURITY_LEVEL = SecurityLevel.INTEGRITY_AND_PRIVACY;
|
||||
private static final String TEST_APPLICATION_PROTOCOL = "grpc";
|
||||
private static final String TEST_LOCAL_SERVICE_ACCOUNT = "local@gserviceaccount.com";
|
||||
private static final String TEST_PEER_SERVICE_ACCOUNT = "peer@gserviceaccount.com";
|
||||
private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
|
||||
|
||||
private HandshakerResult handshakerResult;
|
||||
private RpcProtocolVersions rpcVersions;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
rpcVersions =
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder()
|
||||
.setMajor(TEST_MAX_RPC_VERSION_MAJOR)
|
||||
.setMinor(TEST_MAX_RPC_VERSION_MINOR)
|
||||
.build())
|
||||
.setMinRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder()
|
||||
.setMajor(TEST_MIN_RPC_VERSION_MAJOR)
|
||||
.setMinor(TEST_MIN_RPC_VERSION_MINOR)
|
||||
.build())
|
||||
.build();
|
||||
handshakerResult =
|
||||
HandshakerResult.newBuilder()
|
||||
.setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
|
||||
.setRecordProtocol(TEST_RECORD_PROTOCOL)
|
||||
.setPeerIdentity(Identity.newBuilder().setServiceAccount(TEST_PEER_SERVICE_ACCOUNT))
|
||||
.setLocalIdentity(Identity.newBuilder().setServiceAccount(TEST_LOCAL_SERVICE_ACCOUNT))
|
||||
.setPeerRpcVersions(rpcVersions)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAltsAuthContext() {
|
||||
AltsAuthContext authContext = new AltsAuthContext(handshakerResult);
|
||||
assertEquals(TEST_APPLICATION_PROTOCOL, authContext.getApplicationProtocol());
|
||||
assertEquals(TEST_RECORD_PROTOCOL, authContext.getRecordProtocol());
|
||||
assertEquals(TEST_SECURITY_LEVEL, authContext.getSecurityLevel());
|
||||
assertEquals(TEST_PEER_SERVICE_ACCOUNT, authContext.getPeerServiceAccount());
|
||||
assertEquals(TEST_LOCAL_SERVICE_ACCOUNT, authContext.getLocalServiceAccount());
|
||||
assertEquals(rpcVersions, authContext.getPeerRpcVersions());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static io.grpc.alts.transportsecurity.AltsChannelCrypter.incrementCounter;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import com.google.common.testing.GcFinalization;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import io.netty.util.ResourceLeakDetector;
|
||||
import io.netty.util.ResourceLeakDetector.Level;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.Arrays;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsChannelCrypter}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class AltsChannelCrypterTest extends ChannelCrypterNettyTestBase {
|
||||
|
||||
@Before
|
||||
public void setUp() throws GeneralSecurityException {
|
||||
ResourceLeakDetector.setLevel(Level.PARANOID);
|
||||
client = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true);
|
||||
server = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws GeneralSecurityException {
|
||||
for (ReferenceCounted reference : references) {
|
||||
reference.release();
|
||||
}
|
||||
references.clear();
|
||||
client.destroy();
|
||||
server.destroy();
|
||||
// Increase our chances to detect ByteBuf leaks.
|
||||
GcFinalization.awaitFullGc();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void encryptDecryptKdfCounterIncr() throws GeneralSecurityException {
|
||||
AltsChannelCrypter client =
|
||||
new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true);
|
||||
AltsChannelCrypter server =
|
||||
new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false);
|
||||
|
||||
String message = "Hello world";
|
||||
FrameEncrypt frameEncrypt1 = createFrameEncrypt(message);
|
||||
|
||||
client.encrypt(frameEncrypt1.out, frameEncrypt1.plain);
|
||||
FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt1);
|
||||
|
||||
server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext);
|
||||
assertThat(frameEncrypt1.plain.get(0).slice(0, frameDecrypt1.out.readableBytes()))
|
||||
.isEqualTo(frameDecrypt1.out);
|
||||
|
||||
// Increase counters to get a new KDF counter value (first two bytes are skipped).
|
||||
client.incrementOutCounterForTesting(1 << 17);
|
||||
server.incrementInCounterForTesting(1 << 17);
|
||||
|
||||
FrameEncrypt frameEncrypt2 = createFrameEncrypt(message);
|
||||
|
||||
client.encrypt(frameEncrypt2.out, frameEncrypt2.plain);
|
||||
FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt2);
|
||||
|
||||
server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext);
|
||||
assertThat(frameEncrypt2.plain.get(0).slice(0, frameDecrypt2.out.readableBytes()))
|
||||
.isEqualTo(frameDecrypt2.out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void overflowsClient() throws GeneralSecurityException {
|
||||
byte[] maxFirst =
|
||||
new byte[] {
|
||||
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
|
||||
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
|
||||
(byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00
|
||||
};
|
||||
|
||||
byte[] maxFirstPred = Arrays.copyOf(maxFirst, maxFirst.length);
|
||||
maxFirstPred[0]--;
|
||||
|
||||
byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()];
|
||||
byte[] counter = Arrays.copyOf(maxFirstPred, maxFirstPred.length);
|
||||
|
||||
incrementCounter(counter, oldCounter);
|
||||
|
||||
assertThat(oldCounter).isEqualTo(maxFirstPred);
|
||||
assertThat(counter).isEqualTo(maxFirst);
|
||||
|
||||
try {
|
||||
incrementCounter(counter, oldCounter);
|
||||
fail("Exception expected");
|
||||
} catch (GeneralSecurityException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Counter has overflowed");
|
||||
}
|
||||
|
||||
assertThat(oldCounter).isEqualTo(maxFirst);
|
||||
assertThat(counter).isEqualTo(maxFirst);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void overflowsServer() throws GeneralSecurityException {
|
||||
byte[] maxSecond =
|
||||
new byte[] {
|
||||
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
|
||||
(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
|
||||
(byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x80
|
||||
};
|
||||
|
||||
byte[] maxSecondPred = Arrays.copyOf(maxSecond, maxSecond.length);
|
||||
maxSecondPred[0]--;
|
||||
|
||||
byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()];
|
||||
byte[] counter = Arrays.copyOf(maxSecondPred, maxSecondPred.length);
|
||||
|
||||
incrementCounter(counter, oldCounter);
|
||||
|
||||
assertThat(oldCounter).isEqualTo(maxSecondPred);
|
||||
assertThat(counter).isEqualTo(maxSecond);
|
||||
|
||||
try {
|
||||
incrementCounter(counter, oldCounter);
|
||||
fail("Exception expected");
|
||||
} catch (GeneralSecurityException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Counter has overflowed");
|
||||
}
|
||||
|
||||
assertThat(oldCounter).isEqualTo(maxSecond);
|
||||
assertThat(counter).isEqualTo(maxSecond);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsClientOptions}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class AltsClientOptionsTest {
|
||||
|
||||
@Test
|
||||
public void setAndGet() throws Exception {
|
||||
String targetName = "foo";
|
||||
String serviceAccount1 = "bar1";
|
||||
String serviceAccount2 = "bar2";
|
||||
RpcProtocolVersions rpcVersions =
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build();
|
||||
|
||||
AltsClientOptions options =
|
||||
new AltsClientOptions.Builder()
|
||||
.setTargetName(targetName)
|
||||
.addTargetServiceAccount(serviceAccount1)
|
||||
.addTargetServiceAccount(serviceAccount2)
|
||||
.setRpcProtocolVersions(rpcVersions)
|
||||
.build();
|
||||
|
||||
assertThat(options.getTargetName()).isEqualTo(targetName);
|
||||
assertThat(options.getTargetServiceAccounts()).containsAllOf(serviceAccount1, serviceAccount2);
|
||||
assertThat(options.getRpcProtocolVersions()).isEqualTo(rpcVersions);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.security.GeneralSecurityException;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsFraming}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class AltsFramingTest {
|
||||
|
||||
@Test
|
||||
public void parserFrameLengthNegativeFails() throws GeneralSecurityException {
|
||||
AltsFraming.Parser parser = new AltsFraming.Parser();
|
||||
// frame length + one remaining byte (required)
|
||||
ByteBuffer buffer = ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + 1);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(-1); // write invalid length
|
||||
buffer.put((byte) 0); // write some byte
|
||||
buffer.flip();
|
||||
|
||||
try {
|
||||
parser.readBytes(buffer);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid frame length");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserFrameLengthSmallerMessageTypeFails() throws GeneralSecurityException {
|
||||
AltsFraming.Parser parser = new AltsFraming.Parser();
|
||||
// frame length + one remaining byte (required)
|
||||
ByteBuffer buffer = ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + 1);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(AltsFraming.getFrameMessageTypeHeaderSize() - 1); // write invalid length
|
||||
buffer.put((byte) 0); // write some byte
|
||||
buffer.flip();
|
||||
|
||||
try {
|
||||
parser.readBytes(buffer);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid frame length");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserFrameLengthTooLargeFails() throws GeneralSecurityException {
|
||||
AltsFraming.Parser parser = new AltsFraming.Parser();
|
||||
// frame length + one remaining byte (required)
|
||||
ByteBuffer buffer = ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + 1);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(AltsFraming.getMaxDataLength() + 1); // write invalid length
|
||||
buffer.put((byte) 0); // write some byte
|
||||
buffer.flip();
|
||||
|
||||
try {
|
||||
parser.readBytes(buffer);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid frame length");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserFrameLengthMaxOk() throws GeneralSecurityException {
|
||||
AltsFraming.Parser parser = new AltsFraming.Parser();
|
||||
// length of type header + data
|
||||
int dataLength = AltsFraming.getMaxDataLength();
|
||||
// complete frame + 1 byte
|
||||
ByteBuffer buffer =
|
||||
ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + dataLength + 1);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(dataLength); // write invalid length
|
||||
buffer.putInt(6); // default message type
|
||||
buffer.put(new byte[dataLength - AltsFraming.getFrameMessageTypeHeaderSize()]); // write data
|
||||
buffer.put((byte) 0);
|
||||
buffer.flip();
|
||||
|
||||
parser.readBytes(buffer);
|
||||
|
||||
assertThat(parser.isComplete()).isTrue();
|
||||
assertThat(buffer.remaining()).isEqualTo(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserFrameLengthZeroOk() throws GeneralSecurityException {
|
||||
AltsFraming.Parser parser = new AltsFraming.Parser();
|
||||
int dataLength = AltsFraming.getFrameMessageTypeHeaderSize();
|
||||
// complete frame + 1 byte
|
||||
ByteBuffer buffer =
|
||||
ByteBuffer.allocate(AltsFraming.getFrameLengthHeaderSize() + dataLength + 1);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(dataLength); // write invalid length
|
||||
buffer.putInt(6); // default message type
|
||||
buffer.put((byte) 0);
|
||||
buffer.flip();
|
||||
|
||||
parser.readBytes(buffer);
|
||||
|
||||
assertThat(parser.isComplete()).isTrue();
|
||||
assertThat(buffer.remaining()).isEqualTo(1);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.alts.Handshaker.HandshakeProtocol;
|
||||
import io.grpc.alts.Handshaker.HandshakerReq;
|
||||
import io.grpc.alts.Handshaker.Identity;
|
||||
import io.grpc.alts.Handshaker.StartClientHandshakeReq;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Matchers;
|
||||
|
||||
/** Unit tests for {@link AltsHandshakerClient}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class AltsHandshakerClientTest {
|
||||
private static final int IN_BYTES_SIZE = 100;
|
||||
private static final int BYTES_CONSUMED = 30;
|
||||
private static final int PREFIX_POSITION = 20;
|
||||
private static final String TEST_TARGET_NAME = "target name";
|
||||
private static final String TEST_TARGET_SERVICE_ACCOUNT = "peer service account";
|
||||
|
||||
private AltsHandshakerStub mockStub;
|
||||
private AltsHandshakerClient handshaker;
|
||||
private AltsClientOptions clientOptions;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
mockStub = mock(AltsHandshakerStub.class);
|
||||
clientOptions =
|
||||
new AltsClientOptions.Builder()
|
||||
.setTargetName(TEST_TARGET_NAME)
|
||||
.addTargetServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)
|
||||
.build();
|
||||
handshaker = new AltsHandshakerClient(mockStub, clientOptions);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startClientHandshakeFailure() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getErrorResponse());
|
||||
|
||||
try {
|
||||
handshaker.startClientHandshake();
|
||||
fail("Exception expected");
|
||||
} catch (GeneralSecurityException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startClientHandshakeSuccess() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getOkResponse(0));
|
||||
|
||||
ByteBuffer outFrame = handshaker.startClientHandshake();
|
||||
|
||||
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
|
||||
assertFalse(handshaker.isFinished());
|
||||
assertNull(handshaker.getResult());
|
||||
assertNull(handshaker.getKey());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startClientHandshakeWithOptions() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getOkResponse(0));
|
||||
|
||||
ByteBuffer outFrame = handshaker.startClientHandshake();
|
||||
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
|
||||
|
||||
HandshakerReq req =
|
||||
HandshakerReq.newBuilder()
|
||||
.setClientStart(
|
||||
StartClientHandshakeReq.newBuilder()
|
||||
.setHandshakeSecurityProtocol(HandshakeProtocol.ALTS)
|
||||
.addApplicationProtocols(AltsHandshakerClient.getApplicationProtocol())
|
||||
.addRecordProtocols(AltsHandshakerClient.getRecordProtocol())
|
||||
.setTargetName(TEST_TARGET_NAME)
|
||||
.addTargetIdentities(
|
||||
Identity.newBuilder().setServiceAccount(TEST_TARGET_SERVICE_ACCOUNT))
|
||||
.build())
|
||||
.build();
|
||||
verify(mockStub).send(req);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startServerHandshakeFailure() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getErrorResponse());
|
||||
|
||||
try {
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
handshaker.startServerHandshake(inBytes);
|
||||
fail("Exception expected");
|
||||
} catch (GeneralSecurityException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startServerHandshakeSuccess() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
|
||||
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
|
||||
|
||||
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
|
||||
assertFalse(handshaker.isFinished());
|
||||
assertNull(handshaker.getResult());
|
||||
assertNull(handshaker.getKey());
|
||||
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startServerHandshakeEmptyOutFrame() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED));
|
||||
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
|
||||
|
||||
assertEquals(0, outFrame.remaining());
|
||||
assertFalse(handshaker.isFinished());
|
||||
assertNull(handshaker.getResult());
|
||||
assertNull(handshaker.getKey());
|
||||
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void startServerHandshakeWithPrefixBuffer() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
|
||||
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
inBytes.position(PREFIX_POSITION);
|
||||
ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
|
||||
|
||||
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
|
||||
assertFalse(handshaker.isFinished());
|
||||
assertNull(handshaker.getResult());
|
||||
assertNull(handshaker.getKey());
|
||||
assertEquals(PREFIX_POSITION + BYTES_CONSUMED, inBytes.position());
|
||||
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED - PREFIX_POSITION, inBytes.remaining());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void nextFailure() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getErrorResponse());
|
||||
|
||||
try {
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
handshaker.next(inBytes);
|
||||
fail("Exception expected");
|
||||
} catch (GeneralSecurityException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void nextSuccess() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
|
||||
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
ByteBuffer outFrame = handshaker.next(inBytes);
|
||||
|
||||
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
|
||||
assertFalse(handshaker.isFinished());
|
||||
assertNull(handshaker.getResult());
|
||||
assertNull(handshaker.getKey());
|
||||
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void nextEmptyOutFrame() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED));
|
||||
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
ByteBuffer outFrame = handshaker.next(inBytes);
|
||||
|
||||
assertEquals(0, outFrame.remaining());
|
||||
assertFalse(handshaker.isFinished());
|
||||
assertNull(handshaker.getResult());
|
||||
assertNull(handshaker.getKey());
|
||||
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void nextFinished() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getFinishedResponse(BYTES_CONSUMED));
|
||||
|
||||
ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
|
||||
ByteBuffer outFrame = handshaker.next(inBytes);
|
||||
|
||||
assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
|
||||
assertTrue(handshaker.isFinished());
|
||||
assertArrayEquals(handshaker.getKey(), MockAltsHandshakerResp.getTestKeyData());
|
||||
assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setRpcVersions() throws Exception {
|
||||
when(mockStub.send(Matchers.<HandshakerReq>any()))
|
||||
.thenReturn(MockAltsHandshakerResp.getOkResponse(0));
|
||||
|
||||
RpcProtocolVersions rpcVersions =
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMinRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(3).setMinor(4).build())
|
||||
.setMaxRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(5).setMinor(6).build())
|
||||
.build();
|
||||
clientOptions =
|
||||
new AltsClientOptions.Builder()
|
||||
.setTargetName(TEST_TARGET_NAME)
|
||||
.addTargetServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)
|
||||
.setRpcProtocolVersions(rpcVersions)
|
||||
.build();
|
||||
handshaker = new AltsHandshakerClient(mockStub, clientOptions);
|
||||
|
||||
handshaker.startClientHandshake();
|
||||
|
||||
ArgumentCaptor<HandshakerReq> reqCaptor = ArgumentCaptor.forClass(HandshakerReq.class);
|
||||
verify(mockStub).send(reqCaptor.capture());
|
||||
assertEquals(rpcVersions, reqCaptor.getValue().getClientStart().getRpcVersions());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsHandshakerOptions}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class AltsHandshakerOptionsTest {
|
||||
|
||||
@Test
|
||||
public void setAndGet() throws Exception {
|
||||
RpcProtocolVersions rpcVersions =
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.setMinRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder().setMajor(2).setMinor(1).build())
|
||||
.build();
|
||||
|
||||
AltsHandshakerOptions options = new AltsHandshakerOptions(rpcVersions);
|
||||
assertThat(options.getRpcProtocolVersions()).isEqualTo(rpcVersions);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.alts.Handshaker.HandshakerReq;
|
||||
import io.grpc.alts.Handshaker.HandshakerResp;
|
||||
import io.grpc.alts.Handshaker.NextHandshakeMessageReq;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsHandshakerStub}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class AltsHandshakerStubTest {
|
||||
/** Mock status of handshaker service. */
|
||||
private static enum Status {
|
||||
OK,
|
||||
ERROR,
|
||||
COMPLETE
|
||||
}
|
||||
|
||||
private AltsHandshakerStub stub;
|
||||
private MockWriter writer;
|
||||
private ExecutorService executor;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
executor = Executors.newSingleThreadExecutor();
|
||||
writer = new MockWriter();
|
||||
stub = new AltsHandshakerStub(writer);
|
||||
writer.setReader(stub.getReaderForTest());
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
executor.shutdown();
|
||||
}
|
||||
|
||||
/** Send a message as in_bytes and expect same message as out_frames echo back. */
|
||||
private void sendSuccessfulMessage() throws Exception {
|
||||
String message = "hello world";
|
||||
HandshakerReq.Builder req =
|
||||
HandshakerReq.newBuilder()
|
||||
.setNext(
|
||||
NextHandshakeMessageReq.newBuilder()
|
||||
.setInBytes(ByteString.copyFromUtf8(message))
|
||||
.build());
|
||||
HandshakerResp resp = stub.send(req.build());
|
||||
assertEquals(resp.getOutFrames().toStringUtf8(), message);
|
||||
}
|
||||
|
||||
/** Send a message and expect an IOException on error. */
|
||||
private void sendAndExpectError() throws InterruptedException {
|
||||
try {
|
||||
stub.send(HandshakerReq.newBuilder().build());
|
||||
fail("Exception expected");
|
||||
} catch (IOException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Received a terminating error");
|
||||
}
|
||||
}
|
||||
|
||||
/** Send a message and expect an IOException on closing. */
|
||||
private void sendAndExpectComplete() throws InterruptedException {
|
||||
try {
|
||||
stub.send(HandshakerReq.newBuilder().build());
|
||||
fail("Exception expected");
|
||||
} catch (IOException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Response stream closed");
|
||||
}
|
||||
}
|
||||
|
||||
/** Send a message and expect an IOException on unexpected message. */
|
||||
private void sendAndExpectUnexpectedMessage() throws InterruptedException {
|
||||
try {
|
||||
stub.send(HandshakerReq.newBuilder().build());
|
||||
fail("Exception expected");
|
||||
} catch (IOException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Received an unexpected response");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sendSuccessfulMessageTest() throws Exception {
|
||||
writer.setServiceStatus(Status.OK);
|
||||
sendSuccessfulMessage();
|
||||
stub.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getServiceErrorTest() throws InterruptedException {
|
||||
writer.setServiceStatus(Status.ERROR);
|
||||
sendAndExpectError();
|
||||
stub.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getServiceCompleteTest() throws Exception {
|
||||
writer.setServiceStatus(Status.COMPLETE);
|
||||
sendAndExpectComplete();
|
||||
stub.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getUnexpectedMessageTest() throws Exception {
|
||||
writer.setServiceStatus(Status.OK);
|
||||
writer.sendUnexpectedResponse();
|
||||
sendAndExpectUnexpectedMessage();
|
||||
stub.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void closeEarlyTest() throws InterruptedException {
|
||||
stub.close();
|
||||
sendAndExpectComplete();
|
||||
}
|
||||
|
||||
private class MockWriter implements StreamObserver<HandshakerReq> {
|
||||
private StreamObserver<HandshakerResp> reader;
|
||||
private Status status = Status.OK;
|
||||
|
||||
private void setReader(StreamObserver<HandshakerResp> reader) {
|
||||
this.reader = reader;
|
||||
}
|
||||
|
||||
private void setServiceStatus(Status status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
/** Send a handshaker response to reader. */
|
||||
private void sendUnexpectedResponse() {
|
||||
reader.onNext(HandshakerResp.newBuilder().build());
|
||||
}
|
||||
|
||||
/** Mock writer onNext. Will respond based on the server status. */
|
||||
@Override
|
||||
public void onNext(final HandshakerReq req) {
|
||||
executor.execute(
|
||||
new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
switch (status) {
|
||||
case OK:
|
||||
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
|
||||
reader.onNext(resp.setOutFrames(req.getNext().getInBytes()).build());
|
||||
break;
|
||||
case ERROR:
|
||||
reader.onError(new RuntimeException());
|
||||
break;
|
||||
case COMPLETE:
|
||||
reader.onCompleted();
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable t) {}
|
||||
|
||||
/** Mock writer onComplete. */
|
||||
@Override
|
||||
public void onCompleted() {
|
||||
executor.execute(
|
||||
new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
reader.onCompleted();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,480 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getDirectBuffer;
|
||||
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom;
|
||||
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.writeSlice;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import com.google.common.testing.GcFinalization;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import io.netty.util.ResourceLeakDetector;
|
||||
import io.netty.util.ResourceLeakDetector.Level;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsTsiFrameProtector}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class AltsTsiFrameProtectorTest {
|
||||
private static final int FRAME_MIN_SIZE =
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes() + FakeChannelCrypter.getTagBytes();
|
||||
|
||||
private final List<ReferenceCounted> references = new ArrayList<ReferenceCounted>();
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
ResourceLeakDetector.setLevel(Level.PARANOID);
|
||||
}
|
||||
|
||||
@After
|
||||
public void teardown() {
|
||||
for (ReferenceCounted reference : references) {
|
||||
reference.release();
|
||||
}
|
||||
references.clear();
|
||||
// Increase our chances to detect ByteBuf leaks.
|
||||
GcFinalization.awaitFullGc();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_frameLengthNegativeFails() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in = getDirectBuffer(AltsTsiFrameProtector.getHeaderBytes(), this::ref);
|
||||
in.writeIntLE(-1);
|
||||
in.writeIntLE(6);
|
||||
try {
|
||||
unprotector.unprotect(in, out, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
|
||||
}
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_frameTooSmall() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
||||
in.writeIntLE(FRAME_MIN_SIZE - 1);
|
||||
in.writeIntLE(6);
|
||||
try {
|
||||
unprotector.unprotect(in, out, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
|
||||
}
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_frameTooLarge() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
||||
in.writeIntLE(
|
||||
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
|
||||
- AltsTsiFrameProtector.getHeaderLenFieldBytes()
|
||||
+ 1);
|
||||
in.writeIntLE(6);
|
||||
try {
|
||||
unprotector.unprotect(in, out, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too large");
|
||||
}
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_frameTypeInvalid() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
||||
in.writeIntLE(FRAME_MIN_SIZE);
|
||||
in.writeIntLE(5);
|
||||
try {
|
||||
unprotector.unprotect(in, out, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid header field: frame type");
|
||||
}
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_frameZeroOk() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
||||
in.writeIntLE(FRAME_MIN_SIZE);
|
||||
in.writeIntLE(6);
|
||||
|
||||
unprotector.unprotect(in, out, alloc);
|
||||
assertThat(in.readableBytes()).isEqualTo(0);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_EmptyUnprotectNoRetain() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf emptyBuf = getDirectBuffer(0, this::ref);
|
||||
unprotector.unprotect(emptyBuf, out, alloc);
|
||||
|
||||
assertThat(emptyBuf.refCnt()).isEqualTo(1);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_frameMaxOk() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
||||
in.writeIntLE(
|
||||
AltsTsiFrameProtector.getLimitMaxAllowedFrameBytes()
|
||||
- AltsTsiFrameProtector.getHeaderLenFieldBytes());
|
||||
in.writeIntLE(6);
|
||||
|
||||
unprotector.unprotect(in, out, alloc);
|
||||
assertThat(in.readableBytes()).isEqualTo(0);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parserHeader_frameOkFragment() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
||||
in.writeIntLE(FRAME_MIN_SIZE);
|
||||
in.writeIntLE(6);
|
||||
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
|
||||
ByteBuf in2 = in.readSlice(1);
|
||||
|
||||
unprotector.unprotect(in1, out, alloc);
|
||||
assertThat(in1.readableBytes()).isEqualTo(0);
|
||||
|
||||
unprotector.unprotect(in2, out, alloc);
|
||||
assertThat(in2.readableBytes()).isEqualTo(0);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseHeader_frameFailFragment() throws GeneralSecurityException {
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf in =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes(), this::ref);
|
||||
in.writeIntLE(FRAME_MIN_SIZE - 1);
|
||||
in.writeIntLE(6);
|
||||
ByteBuf in1 = in.readSlice(AltsTsiFrameProtector.getHeaderBytes() - 1);
|
||||
ByteBuf in2 = in.readSlice(1);
|
||||
|
||||
unprotector.unprotect(in1, out, alloc);
|
||||
assertThat(in1.readableBytes()).isEqualTo(0);
|
||||
|
||||
try {
|
||||
unprotector.unprotect(in2, out, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
|
||||
}
|
||||
|
||||
assertThat(in2.readableBytes()).isEqualTo(0);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseFrame_oneFrameNoFragment() throws GeneralSecurityException {
|
||||
int payloadBytes = 1024;
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
||||
ByteBuf outFrame =
|
||||
getDirectBuffer(
|
||||
AltsTsiFrameProtector.getHeaderBytes()
|
||||
+ payloadBytes
|
||||
+ FakeChannelCrypter.getTagBytes(),
|
||||
this::ref);
|
||||
|
||||
outFrame.writeIntLE(
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||
+ payloadBytes
|
||||
+ FakeChannelCrypter.getTagBytes());
|
||||
outFrame.writeIntLE(6);
|
||||
List<ByteBuf> framePlain = Collections.singletonList(plain);
|
||||
ByteBuf frameOut = writeSlice(outFrame, payloadBytes + FakeChannelCrypter.getTagBytes());
|
||||
crypter.encrypt(frameOut, framePlain);
|
||||
plain.readerIndex(0);
|
||||
|
||||
unprotector.unprotect(outFrame, out, alloc);
|
||||
assertThat(outFrame.readableBytes()).isEqualTo(0);
|
||||
assertThat(out.size()).isEqualTo(1);
|
||||
ByteBuf out1 = ref((ByteBuf) out.get(0));
|
||||
assertThat(out1).isEqualTo(plain);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseFrame_twoFramesNoFragment() throws GeneralSecurityException {
|
||||
int payloadBytes = 1536;
|
||||
int payloadBytes1 = 1024;
|
||||
int payloadBytes2 = payloadBytes - payloadBytes1;
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
|
||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
||||
ByteBuf outFrame =
|
||||
getDirectBuffer(
|
||||
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
||||
+ payloadBytes,
|
||||
this::ref);
|
||||
|
||||
outFrame.writeIntLE(
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||
+ payloadBytes1
|
||||
+ FakeChannelCrypter.getTagBytes());
|
||||
outFrame.writeIntLE(6);
|
||||
List<ByteBuf> framePlain1 = Collections.singletonList(plain.readSlice(payloadBytes1));
|
||||
ByteBuf frameOut1 = writeSlice(outFrame, payloadBytes1 + FakeChannelCrypter.getTagBytes());
|
||||
|
||||
outFrame.writeIntLE(
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||
+ payloadBytes2
|
||||
+ FakeChannelCrypter.getTagBytes());
|
||||
outFrame.writeIntLE(6);
|
||||
List<ByteBuf> framePlain2 = Collections.singletonList(plain);
|
||||
ByteBuf frameOut2 = writeSlice(outFrame, payloadBytes2 + FakeChannelCrypter.getTagBytes());
|
||||
|
||||
crypter.encrypt(frameOut1, framePlain1);
|
||||
crypter.encrypt(frameOut2, framePlain2);
|
||||
plain.readerIndex(0);
|
||||
|
||||
unprotector.unprotect(outFrame, out, alloc);
|
||||
assertThat(out.size()).isEqualTo(1);
|
||||
ByteBuf out1 = ref((ByteBuf) out.get(0));
|
||||
assertThat(out1).isEqualTo(plain);
|
||||
assertThat(outFrame.refCnt()).isEqualTo(1);
|
||||
assertThat(outFrame.readableBytes()).isEqualTo(0);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseFrame_twoFramesNoFragment_Leftover() throws GeneralSecurityException {
|
||||
int payloadBytes = 1536;
|
||||
int payloadBytes1 = 1024;
|
||||
int payloadBytes2 = payloadBytes - payloadBytes1;
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
|
||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
||||
ByteBuf protectedBuf =
|
||||
getDirectBuffer(
|
||||
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
||||
+ payloadBytes
|
||||
+ AltsTsiFrameProtector.getHeaderBytes(),
|
||||
this::ref);
|
||||
|
||||
protectedBuf.writeIntLE(
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||
+ payloadBytes1
|
||||
+ FakeChannelCrypter.getTagBytes());
|
||||
protectedBuf.writeIntLE(6);
|
||||
List<ByteBuf> framePlain1 = Collections.singletonList(plain.readSlice(payloadBytes1));
|
||||
ByteBuf frameOut1 = writeSlice(protectedBuf, payloadBytes1 + FakeChannelCrypter.getTagBytes());
|
||||
|
||||
protectedBuf.writeIntLE(
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||
+ payloadBytes2
|
||||
+ FakeChannelCrypter.getTagBytes());
|
||||
protectedBuf.writeIntLE(6);
|
||||
List<ByteBuf> framePlain2 = Collections.singletonList(plain);
|
||||
ByteBuf frameOut2 = writeSlice(protectedBuf, payloadBytes2 + FakeChannelCrypter.getTagBytes());
|
||||
// This is an invalid header length field, make sure it triggers an error
|
||||
// when the remainder of the header is given.
|
||||
protectedBuf.writeIntLE((byte) -1);
|
||||
|
||||
crypter.encrypt(frameOut1, framePlain1);
|
||||
crypter.encrypt(frameOut2, framePlain2);
|
||||
plain.readerIndex(0);
|
||||
|
||||
unprotector.unprotect(protectedBuf, out, alloc);
|
||||
assertThat(out.size()).isEqualTo(1);
|
||||
ByteBuf out1 = ref((ByteBuf) out.get(0));
|
||||
assertThat(out1).isEqualTo(plain);
|
||||
|
||||
// The protectedBuf is buffered inside the unprotector.
|
||||
assertThat(protectedBuf.readableBytes()).isEqualTo(0);
|
||||
assertThat(protectedBuf.refCnt()).isEqualTo(2);
|
||||
|
||||
protectedBuf.writeIntLE(6);
|
||||
try {
|
||||
unprotector.unprotect(protectedBuf, out, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (IllegalArgumentException ex) {
|
||||
assertThat(ex).hasMessageThat().contains("Invalid header field: frame size too small");
|
||||
}
|
||||
|
||||
unprotector.destroy();
|
||||
|
||||
// Make sure that unprotector does not hold onto buffered ByteBuf instance after destroy.
|
||||
assertThat(protectedBuf.refCnt()).isEqualTo(1);
|
||||
|
||||
// Make sure that destroying twice does not throw.
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parseFrame_twoFramesFragmentSecond() throws GeneralSecurityException {
|
||||
int payloadBytes = 1536;
|
||||
int payloadBytes1 = 1024;
|
||||
int payloadBytes2 = payloadBytes - payloadBytes1;
|
||||
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
|
||||
List<Object> out = new ArrayList<>();
|
||||
FakeChannelCrypter crypter = new FakeChannelCrypter();
|
||||
AltsTsiFrameProtector.Unprotector unprotector =
|
||||
new AltsTsiFrameProtector.Unprotector(crypter, alloc);
|
||||
|
||||
ByteBuf plain = getRandom(payloadBytes, this::ref);
|
||||
ByteBuf protectedBuf =
|
||||
getDirectBuffer(
|
||||
2 * (AltsTsiFrameProtector.getHeaderBytes() + FakeChannelCrypter.getTagBytes())
|
||||
+ payloadBytes
|
||||
+ AltsTsiFrameProtector.getHeaderBytes(),
|
||||
this::ref);
|
||||
|
||||
protectedBuf.writeIntLE(
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||
+ payloadBytes1
|
||||
+ FakeChannelCrypter.getTagBytes());
|
||||
protectedBuf.writeIntLE(6);
|
||||
List<ByteBuf> framePlain1 = Collections.singletonList(plain.readSlice(payloadBytes1));
|
||||
ByteBuf frameOut1 = writeSlice(protectedBuf, payloadBytes1 + FakeChannelCrypter.getTagBytes());
|
||||
|
||||
protectedBuf.writeIntLE(
|
||||
AltsTsiFrameProtector.getHeaderTypeFieldBytes()
|
||||
+ payloadBytes2
|
||||
+ FakeChannelCrypter.getTagBytes());
|
||||
protectedBuf.writeIntLE(6);
|
||||
List<ByteBuf> framePlain2 = Collections.singletonList(plain);
|
||||
ByteBuf frameOut2 = writeSlice(protectedBuf, payloadBytes2 + FakeChannelCrypter.getTagBytes());
|
||||
|
||||
crypter.encrypt(frameOut1, framePlain1);
|
||||
crypter.encrypt(frameOut2, framePlain2);
|
||||
plain.readerIndex(0);
|
||||
|
||||
unprotector.unprotect(
|
||||
protectedBuf.readSlice(
|
||||
payloadBytes
|
||||
+ AltsTsiFrameProtector.getHeaderBytes()
|
||||
+ FakeChannelCrypter.getTagBytes()
|
||||
+ AltsTsiFrameProtector.getHeaderBytes()),
|
||||
out,
|
||||
alloc);
|
||||
assertThat(out.size()).isEqualTo(1);
|
||||
ByteBuf out1 = ref((ByteBuf) out.get(0));
|
||||
assertThat(out1).isEqualTo(plain.readSlice(payloadBytes1));
|
||||
assertThat(protectedBuf.refCnt()).isEqualTo(2);
|
||||
|
||||
unprotector.unprotect(protectedBuf, out, alloc);
|
||||
assertThat(out.size()).isEqualTo(2);
|
||||
ByteBuf out2 = ref((ByteBuf) out.get(1));
|
||||
assertThat(out2).isEqualTo(plain);
|
||||
assertThat(protectedBuf.refCnt()).isEqualTo(1);
|
||||
|
||||
unprotector.destroy();
|
||||
}
|
||||
|
||||
private ByteBuf ref(ByteBuf buf) {
|
||||
if (buf != null) {
|
||||
references.add(buf);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,269 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.alts.Handshaker.HandshakerResult;
|
||||
import io.grpc.alts.Handshaker.Identity;
|
||||
import io.grpc.alts.TransportSecurityCommon.RpcProtocolVersions;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.mockito.Matchers;
|
||||
|
||||
/** Unit tests for {@link AltsTsiHandshaker}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class AltsTsiHandshakerTest {
|
||||
private static final String TEST_KEY_DATA = "super secret 123";
|
||||
private static final String TEST_APPLICATION_PROTOCOL = "grpc";
|
||||
private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
|
||||
private static final String TEST_CLIENT_SERVICE_ACCOUNT = "client@developer.gserviceaccount.com";
|
||||
private static final String TEST_SERVER_SERVICE_ACCOUNT = "server@developer.gserviceaccount.com";
|
||||
private static final int OUT_FRAME_SIZE = 100;
|
||||
private static final int TRANSPORT_BUFFER_SIZE = 200;
|
||||
private static final int TEST_MAX_RPC_VERSION_MAJOR = 3;
|
||||
private static final int TEST_MAX_RPC_VERSION_MINOR = 2;
|
||||
private static final int TEST_MIN_RPC_VERSION_MAJOR = 2;
|
||||
private static final int TEST_MIN_RPC_VERSION_MINOR = 1;
|
||||
private static final RpcProtocolVersions TEST_RPC_PROTOCOL_VERSIONS =
|
||||
RpcProtocolVersions.newBuilder()
|
||||
.setMaxRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder()
|
||||
.setMajor(TEST_MAX_RPC_VERSION_MAJOR)
|
||||
.setMinor(TEST_MAX_RPC_VERSION_MINOR)
|
||||
.build())
|
||||
.setMinRpcVersion(
|
||||
RpcProtocolVersions.Version.newBuilder()
|
||||
.setMajor(TEST_MIN_RPC_VERSION_MAJOR)
|
||||
.setMinor(TEST_MIN_RPC_VERSION_MINOR)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private AltsHandshakerClient mockClient;
|
||||
private AltsHandshakerClient mockServer;
|
||||
private AltsTsiHandshaker handshakerClient;
|
||||
private AltsTsiHandshaker handshakerServer;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
mockClient = mock(AltsHandshakerClient.class);
|
||||
mockServer = mock(AltsHandshakerClient.class);
|
||||
handshakerClient = new AltsTsiHandshaker(true, mockClient);
|
||||
handshakerServer = new AltsTsiHandshaker(false, mockServer);
|
||||
}
|
||||
|
||||
private HandshakerResult getHandshakerResult(boolean isClient) {
|
||||
HandshakerResult.Builder builder =
|
||||
HandshakerResult.newBuilder()
|
||||
.setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
|
||||
.setRecordProtocol(TEST_RECORD_PROTOCOL)
|
||||
.setKeyData(ByteString.copyFromUtf8(TEST_KEY_DATA))
|
||||
.setPeerRpcVersions(TEST_RPC_PROTOCOL_VERSIONS);
|
||||
if (isClient) {
|
||||
builder.setPeerIdentity(
|
||||
Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
|
||||
builder.setLocalIdentity(
|
||||
Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
|
||||
} else {
|
||||
builder.setPeerIdentity(
|
||||
Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
|
||||
builder.setLocalIdentity(
|
||||
Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processBytesFromPeerFalseStart() throws Exception {
|
||||
verify(mockClient, never()).startClientHandshake();
|
||||
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
|
||||
verify(mockClient, never()).next(Matchers.<ByteBuffer>any());
|
||||
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processBytesFromPeerStartServer() throws Exception {
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
|
||||
verify(mockServer, never()).startClientHandshake();
|
||||
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
|
||||
// Mock transport buffer all consumed by processBytesFromPeer and there is an output frame.
|
||||
transportBuffer.position(transportBuffer.limit());
|
||||
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
|
||||
when(mockServer.isFinished()).thenReturn(false);
|
||||
|
||||
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processBytesFromPeerStartServerEmptyOutput() throws Exception {
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
|
||||
verify(mockServer, never()).startClientHandshake();
|
||||
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
|
||||
// Mock transport buffer all consumed by processBytesFromPeer and output frame is empty.
|
||||
// Expect processBytesFromPeer return False, because more data are needed from the peer.
|
||||
transportBuffer.position(transportBuffer.limit());
|
||||
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
|
||||
when(mockServer.isFinished()).thenReturn(false);
|
||||
|
||||
assertFalse(handshakerServer.processBytesFromPeer(transportBuffer));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processBytesFromPeerStartServerFinished() throws Exception {
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
|
||||
verify(mockServer, never()).startClientHandshake();
|
||||
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
|
||||
// Mock handshake complete after processBytesFromPeer.
|
||||
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
|
||||
when(mockServer.isFinished()).thenReturn(true);
|
||||
|
||||
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processBytesFromPeerNoBytesConsumed() throws Exception {
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
|
||||
verify(mockServer, never()).startClientHandshake();
|
||||
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
|
||||
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
|
||||
when(mockServer.isFinished()).thenReturn(false);
|
||||
|
||||
try {
|
||||
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
|
||||
fail("Expected IllegalStateException");
|
||||
} catch (IllegalStateException expected) {
|
||||
assertEquals("Handshaker did not consume any bytes.", expected.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processBytesFromPeerClientNext() throws Exception {
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
|
||||
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
|
||||
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
|
||||
when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
|
||||
when(mockClient.isFinished()).thenReturn(false);
|
||||
|
||||
handshakerClient.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.position(transportBuffer.limit());
|
||||
assertFalse(handshakerClient.processBytesFromPeer(transportBuffer));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void processBytesFromPeerClientNextFinished() throws Exception {
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
|
||||
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
|
||||
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
|
||||
when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
|
||||
when(mockClient.isFinished()).thenReturn(true);
|
||||
|
||||
handshakerClient.getBytesToSendToPeer(transportBuffer);
|
||||
assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractPeerFailure() throws Exception {
|
||||
when(mockClient.isFinished()).thenReturn(false);
|
||||
|
||||
try {
|
||||
handshakerClient.extractPeer();
|
||||
fail("Expected IllegalStateException");
|
||||
} catch (IllegalStateException expected) {
|
||||
assertEquals("Handshake is not complete.", expected.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractPeerObjectFailure() throws Exception {
|
||||
when(mockClient.isFinished()).thenReturn(false);
|
||||
|
||||
try {
|
||||
handshakerClient.extractPeerObject();
|
||||
fail("Expected IllegalStateException");
|
||||
} catch (IllegalStateException expected) {
|
||||
assertEquals("Handshake is not complete.", expected.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractClientPeerSuccess() throws Exception {
|
||||
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
|
||||
when(mockClient.isFinished()).thenReturn(true);
|
||||
when(mockClient.getResult()).thenReturn(getHandshakerResult(/* isClient = */ true));
|
||||
|
||||
handshakerClient.getBytesToSendToPeer(transportBuffer);
|
||||
TsiPeer clientPeer = handshakerClient.extractPeer();
|
||||
|
||||
assertEquals(1, clientPeer.getProperties().size());
|
||||
assertEquals(
|
||||
TEST_SERVER_SERVICE_ACCOUNT,
|
||||
clientPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
|
||||
|
||||
AltsAuthContext clientContext = (AltsAuthContext) handshakerClient.extractPeerObject();
|
||||
assertEquals(TEST_APPLICATION_PROTOCOL, clientContext.getApplicationProtocol());
|
||||
assertEquals(TEST_RECORD_PROTOCOL, clientContext.getRecordProtocol());
|
||||
assertEquals(TEST_SERVER_SERVICE_ACCOUNT, clientContext.getPeerServiceAccount());
|
||||
assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, clientContext.getLocalServiceAccount());
|
||||
assertEquals(TEST_RPC_PROTOCOL_VERSIONS, clientContext.getPeerRpcVersions());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void extractServerPeerSuccess() throws Exception {
|
||||
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
|
||||
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
|
||||
when(mockServer.startServerHandshake(Matchers.<ByteBuffer>any())).thenReturn(outputFrame);
|
||||
when(mockServer.isFinished()).thenReturn(true);
|
||||
when(mockServer.getResult()).thenReturn(getHandshakerResult(/* isClient = */ false));
|
||||
|
||||
handshakerServer.processBytesFromPeer(transportBuffer);
|
||||
handshakerServer.getBytesToSendToPeer(transportBuffer);
|
||||
TsiPeer serverPeer = handshakerServer.extractPeer();
|
||||
|
||||
assertEquals(1, serverPeer.getProperties().size());
|
||||
assertEquals(
|
||||
TEST_CLIENT_SERVICE_ACCOUNT,
|
||||
serverPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
|
||||
|
||||
AltsAuthContext serverContext = (AltsAuthContext) handshakerServer.extractPeerObject();
|
||||
assertEquals(TEST_APPLICATION_PROTOCOL, serverContext.getApplicationProtocol());
|
||||
assertEquals(TEST_RECORD_PROTOCOL, serverContext.getRecordProtocol());
|
||||
assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, serverContext.getPeerServiceAccount());
|
||||
assertEquals(TEST_SERVER_SERVICE_ACCOUNT, serverContext.getLocalServiceAccount());
|
||||
assertEquals(TEST_RPC_PROTOCOL_VERSIONS, serverContext.getPeerRpcVersions());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import com.google.common.testing.GcFinalization;
|
||||
import io.grpc.alts.Handshaker.HandshakeProtocol;
|
||||
import io.grpc.alts.Handshaker.HandshakerReq;
|
||||
import io.grpc.alts.Handshaker.HandshakerResp;
|
||||
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import io.netty.util.ResourceLeakDetector;
|
||||
import io.netty.util.ResourceLeakDetector.Level;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link AltsTsiHandshaker}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class AltsTsiTest {
|
||||
private static final int OVERHEAD =
|
||||
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
|
||||
|
||||
private final List<ReferenceCounted> references = new ArrayList<>();
|
||||
private AltsHandshakerClient client;
|
||||
private AltsHandshakerClient server;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
ResourceLeakDetector.setLevel(Level.PARANOID);
|
||||
// Use MockAltsHandshakerStub for all the tests.
|
||||
AltsHandshakerOptions handshakerOptions = new AltsHandshakerOptions(null);
|
||||
MockAltsHandshakerStub clientStub = new MockAltsHandshakerStub();
|
||||
MockAltsHandshakerStub serverStub = new MockAltsHandshakerStub();
|
||||
client = new AltsHandshakerClient(clientStub, handshakerOptions);
|
||||
server = new AltsHandshakerClient(serverStub, handshakerOptions);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
for (ReferenceCounted reference : references) {
|
||||
reference.release();
|
||||
}
|
||||
references.clear();
|
||||
// Increase our chances to detect ByteBuf leaks.
|
||||
GcFinalization.awaitFullGc();
|
||||
}
|
||||
|
||||
private Handshakers newHandshakers() {
|
||||
TsiHandshaker clientHandshaker = new AltsTsiHandshaker(true, client);
|
||||
TsiHandshaker serverHandshaker = new AltsTsiHandshaker(false, server);
|
||||
return new Handshakers(clientHandshaker, serverHandshaker);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void verifyHandshakePeer() throws Exception {
|
||||
Handshakers handshakers = newHandshakers();
|
||||
TsiTest.performHandshake(TsiTest.getDefaultTransportBufferSize(), handshakers);
|
||||
TsiPeer clientPeer = handshakers.getClient().extractPeer();
|
||||
assertEquals(1, clientPeer.getProperties().size());
|
||||
assertEquals(
|
||||
MockAltsHandshakerResp.getTestPeerAccount(),
|
||||
clientPeer.getProperty("service_account").getValue());
|
||||
TsiPeer serverPeer = handshakers.getServer().extractPeer();
|
||||
assertEquals(1, serverPeer.getProperties().size());
|
||||
assertEquals(
|
||||
MockAltsHandshakerResp.getTestPeerAccount(),
|
||||
serverPeer.getProperty("service_account").getValue());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handshake() throws GeneralSecurityException {
|
||||
TsiTest.handshakeTest(newHandshakers());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handshakeSmallBuffer() throws GeneralSecurityException {
|
||||
TsiTest.handshakeSmallBufferTest(newHandshakers());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPong() throws GeneralSecurityException {
|
||||
TsiTest.pingPongTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongExactFrameSize() throws GeneralSecurityException {
|
||||
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongSmallBuffer() throws GeneralSecurityException {
|
||||
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongSmallFrame() throws GeneralSecurityException {
|
||||
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
|
||||
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void corruptedCounter() throws GeneralSecurityException {
|
||||
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void corruptedCiphertext() throws GeneralSecurityException {
|
||||
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void corruptedTag() throws GeneralSecurityException {
|
||||
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void reflectedCiphertext() throws GeneralSecurityException {
|
||||
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
private static class MockAltsHandshakerStub extends AltsHandshakerStub {
|
||||
private boolean started = false;
|
||||
|
||||
@Override
|
||||
public HandshakerResp send(HandshakerReq req) {
|
||||
if (started) {
|
||||
// Expect handshake next message.
|
||||
if (req.getReqOneofCase().getNumber() != 3) {
|
||||
return MockAltsHandshakerResp.getErrorResponse();
|
||||
}
|
||||
return MockAltsHandshakerResp.getFinishedResponse(req.getNext().getInBytes().size());
|
||||
} else {
|
||||
List<String> recordProtocols;
|
||||
int bytesConsumed = 0;
|
||||
switch (req.getReqOneofCase().getNumber()) {
|
||||
case 1:
|
||||
recordProtocols = req.getClientStart().getRecordProtocolsList();
|
||||
break;
|
||||
case 2:
|
||||
recordProtocols =
|
||||
req.getServerStart()
|
||||
.getHandshakeParametersMap()
|
||||
.get(HandshakeProtocol.ALTS.getNumber())
|
||||
.getRecordProtocolsList();
|
||||
bytesConsumed = req.getServerStart().getInBytes().size();
|
||||
break;
|
||||
default:
|
||||
return MockAltsHandshakerResp.getErrorResponse();
|
||||
}
|
||||
if (recordProtocols.isEmpty()) {
|
||||
return MockAltsHandshakerResp.getErrorResponse();
|
||||
}
|
||||
started = true;
|
||||
return MockAltsHandshakerResp.getOkResponse(bytesConsumed);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {}
|
||||
}
|
||||
|
||||
private ByteBuf ref(ByteBuf buf) {
|
||||
if (buf != null) {
|
||||
references.add(buf);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
public final class ByteBufTestUtils {
|
||||
public interface RegisterRef {
|
||||
ByteBuf register(ByteBuf buf);
|
||||
}
|
||||
|
||||
private static final Random random = new SecureRandom();
|
||||
|
||||
// The {@code ref} argument can be used to register the buffer for {@code release}.
|
||||
// TODO: allow the allocator to be passed in.
|
||||
public static ByteBuf getDirectBuffer(int len, RegisterRef ref) {
|
||||
return ref.register(Unpooled.directBuffer(len));
|
||||
}
|
||||
|
||||
/** Get random bytes. */
|
||||
public static ByteBuf getRandom(int len, RegisterRef ref) {
|
||||
ByteBuf buf = getDirectBuffer(len, ref);
|
||||
byte[] bytes = new byte[len];
|
||||
random.nextBytes(bytes);
|
||||
buf.writeBytes(bytes);
|
||||
return buf;
|
||||
}
|
||||
|
||||
/** Fragment byte buffer into multiple pieces. */
|
||||
public static List<ByteBuf> fragmentByteBuf(ByteBuf in, int num, RegisterRef ref) {
|
||||
ByteBuf buf = in.slice();
|
||||
Preconditions.checkArgument(num > 0);
|
||||
List<ByteBuf> fragmentedBufs = new ArrayList<>(num);
|
||||
int fragmentSize = buf.readableBytes() / num;
|
||||
while (buf.isReadable()) {
|
||||
int readBytes = num == 0 ? buf.readableBytes() : fragmentSize;
|
||||
ByteBuf tmpBuf = getDirectBuffer(readBytes, ref);
|
||||
tmpBuf.writeBytes(buf, readBytes);
|
||||
fragmentedBufs.add(tmpBuf);
|
||||
num--;
|
||||
}
|
||||
return fragmentedBufs;
|
||||
}
|
||||
|
||||
static ByteBuf writeSlice(ByteBuf in, int len) {
|
||||
Preconditions.checkArgument(len <= in.writableBytes());
|
||||
ByteBuf out = in.slice(in.writerIndex(), len);
|
||||
in.writerIndex(in.writerIndex() + len);
|
||||
return out.writerIndex(0);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getDirectBuffer;
|
||||
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getRandom;
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import javax.crypto.AEADBadTagException;
|
||||
import org.junit.Test;
|
||||
|
||||
/** Abstract class for unit tests of {@link ChannelCrypterNetty}. */
|
||||
public abstract class ChannelCrypterNettyTestBase {
|
||||
private static final String DECRYPTION_FAILURE_MESSAGE = "Tag mismatch";
|
||||
|
||||
protected final List<ReferenceCounted> references = new ArrayList<>();
|
||||
public ChannelCrypterNetty client;
|
||||
public ChannelCrypterNetty server;
|
||||
|
||||
static final class FrameEncrypt {
|
||||
List<ByteBuf> plain;
|
||||
ByteBuf out;
|
||||
}
|
||||
|
||||
static final class FrameDecrypt {
|
||||
List<ByteBuf> ciphertext;
|
||||
ByteBuf out;
|
||||
ByteBuf tag;
|
||||
}
|
||||
|
||||
FrameEncrypt createFrameEncrypt(String message) {
|
||||
byte[] messageBytes = message.getBytes(UTF_8);
|
||||
FrameEncrypt frame = new FrameEncrypt();
|
||||
ByteBuf plain = getDirectBuffer(messageBytes.length, this::ref);
|
||||
plain.writeBytes(messageBytes);
|
||||
frame.plain = Collections.singletonList(plain);
|
||||
frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
|
||||
return frame;
|
||||
}
|
||||
|
||||
FrameDecrypt frameDecryptOfEncrypt(FrameEncrypt frameEncrypt) {
|
||||
int tagLen = client.getSuffixLength();
|
||||
FrameDecrypt frameDecrypt = new FrameDecrypt();
|
||||
ByteBuf out = frameEncrypt.out;
|
||||
frameDecrypt.ciphertext =
|
||||
Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen));
|
||||
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
|
||||
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
|
||||
return frameDecrypt;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void encryptDecrypt() throws GeneralSecurityException {
|
||||
String message = "Hello world";
|
||||
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
|
||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
|
||||
|
||||
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
||||
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
|
||||
.isEqualTo(frameDecrypt.out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void encryptDecryptLarge() throws GeneralSecurityException {
|
||||
FrameEncrypt frameEncrypt = new FrameEncrypt();
|
||||
ByteBuf plain = getRandom(17 * 1024, this::ref);
|
||||
frameEncrypt.plain = Collections.singletonList(plain);
|
||||
frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), this::ref);
|
||||
|
||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
|
||||
|
||||
// Call decrypt overload that takes ciphertext and tag.
|
||||
server.decrypt(frameDecrypt.out, frameEncrypt.out);
|
||||
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
|
||||
.isEqualTo(frameDecrypt.out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void encryptDecryptMultiple() throws GeneralSecurityException {
|
||||
String message = "Hello world";
|
||||
for (int i = 0; i < 512; ++i) {
|
||||
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
|
||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
|
||||
|
||||
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
||||
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
|
||||
.isEqualTo(frameDecrypt.out);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void encryptDecryptComposite() throws GeneralSecurityException {
|
||||
String message = "Hello world";
|
||||
int lastLen = 2;
|
||||
byte[] messageBytes = message.getBytes(UTF_8);
|
||||
FrameEncrypt frameEncrypt = new FrameEncrypt();
|
||||
ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, this::ref);
|
||||
ByteBuf plain2 = getDirectBuffer(lastLen, this::ref);
|
||||
plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen);
|
||||
plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen);
|
||||
ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2);
|
||||
frameEncrypt.plain = Collections.singletonList(plain);
|
||||
frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), this::ref);
|
||||
|
||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||
|
||||
int tagLen = client.getSuffixLength();
|
||||
FrameDecrypt frameDecrypt = new FrameDecrypt();
|
||||
ByteBuf out = frameEncrypt.out;
|
||||
int outLen = out.readableBytes();
|
||||
ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, this::ref);
|
||||
ByteBuf cipher2 = getDirectBuffer(lastLen, this::ref);
|
||||
cipher1.writeBytes(out, 0, outLen - lastLen - tagLen);
|
||||
cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen);
|
||||
ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2);
|
||||
frameDecrypt.ciphertext = Collections.singletonList(cipher);
|
||||
frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
|
||||
frameDecrypt.out = getDirectBuffer(out.readableBytes(), this::ref);
|
||||
|
||||
server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
||||
assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
|
||||
.isEqualTo(frameDecrypt.out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void reflection() throws GeneralSecurityException {
|
||||
String message = "Hello world";
|
||||
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
|
||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
|
||||
try {
|
||||
client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void skipMessage() throws GeneralSecurityException {
|
||||
String message = "Hello world";
|
||||
FrameEncrypt frameEncrypt1 = createFrameEncrypt(message);
|
||||
client.encrypt(frameEncrypt1.out, frameEncrypt1.plain);
|
||||
FrameEncrypt frameEncrypt2 = createFrameEncrypt(message);
|
||||
client.encrypt(frameEncrypt2.out, frameEncrypt2.plain);
|
||||
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt2);
|
||||
|
||||
try {
|
||||
client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void corruptMessage() throws GeneralSecurityException {
|
||||
String message = "Hello world";
|
||||
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
|
||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||
FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
|
||||
frameEncrypt.out.setByte(3, frameEncrypt.out.getByte(3) + 1);
|
||||
|
||||
try {
|
||||
client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void replayMessage() throws GeneralSecurityException {
|
||||
String message = "Hello world";
|
||||
FrameEncrypt frameEncrypt = createFrameEncrypt(message);
|
||||
client.encrypt(frameEncrypt.out, frameEncrypt.plain);
|
||||
FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt);
|
||||
FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt);
|
||||
|
||||
server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext);
|
||||
|
||||
try {
|
||||
server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
private ByteBuf ref(ByteBuf buf) {
|
||||
if (buf != null) {
|
||||
references.add(buf);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.base.Preconditions.checkState;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import javax.crypto.AEADBadTagException;
|
||||
|
||||
public final class FakeChannelCrypter implements ChannelCrypterNetty {
|
||||
private static final int TAG_BYTES = 16;
|
||||
private static final byte TAG_BYTE = (byte) 0xa1;
|
||||
|
||||
private boolean destroyCalled = false;
|
||||
|
||||
public static int getTagBytes() {
|
||||
return TAG_BYTES;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void encrypt(ByteBuf out, List<ByteBuf> plain) throws GeneralSecurityException {
|
||||
checkState(!destroyCalled);
|
||||
for (ByteBuf buf : plain) {
|
||||
out.writeBytes(buf);
|
||||
for (int i = 0; i < TAG_BYTES; ++i) {
|
||||
out.writeByte(TAG_BYTE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertext)
|
||||
throws GeneralSecurityException {
|
||||
checkState(!destroyCalled);
|
||||
for (ByteBuf buf : ciphertext) {
|
||||
out.writeBytes(buf);
|
||||
}
|
||||
boolean tagValid = tag.forEachByte((byte value) -> value == TAG_BYTE) == -1;
|
||||
if (!tagValid) {
|
||||
throw new AEADBadTagException("Tag mismatch!");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
|
||||
checkState(!destroyCalled);
|
||||
ByteBuf ciphertext = ciphertextAndTag.readSlice(ciphertextAndTag.readableBytes() - TAG_BYTES);
|
||||
decrypt(out, /*tag=*/ ciphertextAndTag, Collections.singletonList(ciphertext));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getSuffixLength() {
|
||||
return TAG_BYTES;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
destroyCalled = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.Collections;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* A fake handshaker compatible with security/transport_security/fake_transport_security.h See
|
||||
* {@link TsiHandshaker} for documentation.
|
||||
*/
|
||||
public class FakeTsiHandshaker implements TsiHandshaker {
|
||||
private static final Logger logger = Logger.getLogger(FakeTsiHandshaker.class.getName());
|
||||
|
||||
private static final TsiHandshakerFactory clientHandshakerFactory =
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker() {
|
||||
return new FakeTsiHandshaker(true);
|
||||
}
|
||||
};
|
||||
|
||||
private static final TsiHandshakerFactory serverHandshakerFactory =
|
||||
new TsiHandshakerFactory() {
|
||||
@Override
|
||||
public TsiHandshaker newHandshaker() {
|
||||
return new FakeTsiHandshaker(false);
|
||||
}
|
||||
};
|
||||
|
||||
private boolean isClient;
|
||||
private ByteBuffer sendBuffer = null;
|
||||
private AltsFraming.Parser frameParser = new AltsFraming.Parser();
|
||||
|
||||
private State sendState;
|
||||
private State receiveState;
|
||||
|
||||
enum State {
|
||||
CLIENT_NONE,
|
||||
SERVER_NONE,
|
||||
CLIENT_INIT,
|
||||
SERVER_INIT,
|
||||
CLIENT_FINISHED,
|
||||
SERVER_FINISHED;
|
||||
|
||||
// Returns the next State. In order to advance to sendState=N, receiveState must be N-1.
|
||||
public State next() {
|
||||
if (ordinal() + 1 < values().length) {
|
||||
return values()[ordinal() + 1];
|
||||
}
|
||||
throw new UnsupportedOperationException("Can't call next() on last element: " + this);
|
||||
}
|
||||
}
|
||||
|
||||
public static TsiHandshakerFactory clientHandshakerFactory() {
|
||||
return clientHandshakerFactory;
|
||||
}
|
||||
|
||||
public static TsiHandshakerFactory serverHandshakerFactory() {
|
||||
return serverHandshakerFactory;
|
||||
}
|
||||
|
||||
public static TsiHandshaker newFakeHandshakerClient() {
|
||||
return clientHandshakerFactory.newHandshaker();
|
||||
}
|
||||
|
||||
public static TsiHandshaker newFakeHandshakerServer() {
|
||||
return serverHandshakerFactory.newHandshaker();
|
||||
}
|
||||
|
||||
protected FakeTsiHandshaker(boolean isClient) {
|
||||
this.isClient = isClient;
|
||||
if (isClient) {
|
||||
sendState = State.CLIENT_NONE;
|
||||
receiveState = State.SERVER_NONE;
|
||||
} else {
|
||||
sendState = State.SERVER_NONE;
|
||||
receiveState = State.CLIENT_NONE;
|
||||
}
|
||||
}
|
||||
|
||||
private State getNextState(State state) {
|
||||
switch (state) {
|
||||
case CLIENT_NONE:
|
||||
return State.CLIENT_INIT;
|
||||
case SERVER_NONE:
|
||||
return State.SERVER_INIT;
|
||||
case CLIENT_INIT:
|
||||
return State.CLIENT_FINISHED;
|
||||
case SERVER_INIT:
|
||||
return State.SERVER_FINISHED;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private String getNextMessage() {
|
||||
State result = getNextState(sendState);
|
||||
return result == null ? "BAD STATE" : result.toString();
|
||||
}
|
||||
|
||||
private String getExpectedMessage() {
|
||||
State result = getNextState(receiveState);
|
||||
return result == null ? "BAD STATE" : result.toString();
|
||||
}
|
||||
|
||||
private void incrementSendState() {
|
||||
sendState = getNextState(sendState);
|
||||
}
|
||||
|
||||
private void incrementReceiveState() {
|
||||
receiveState = getNextState(receiveState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
|
||||
Preconditions.checkNotNull(bytes);
|
||||
|
||||
// If we're done, return nothing.
|
||||
if (sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prepare the next message, if neeeded.
|
||||
if (sendBuffer == null) {
|
||||
if (sendState.next() != receiveState) {
|
||||
// We're still waiting for bytes from the peer, so bail.
|
||||
return;
|
||||
}
|
||||
ByteBuffer payload = ByteBuffer.wrap(getNextMessage().getBytes(UTF_8));
|
||||
sendBuffer = AltsFraming.toFrame(payload, payload.remaining());
|
||||
logger.log(Level.FINE, "Buffered message: {0}", getNextMessage());
|
||||
}
|
||||
while (bytes.hasRemaining() && sendBuffer.hasRemaining()) {
|
||||
bytes.put(sendBuffer.get());
|
||||
}
|
||||
if (!sendBuffer.hasRemaining()) {
|
||||
// Get ready to send the next message.
|
||||
sendBuffer = null;
|
||||
incrementSendState();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
|
||||
Preconditions.checkNotNull(bytes);
|
||||
|
||||
frameParser.readBytes(bytes);
|
||||
if (frameParser.isComplete()) {
|
||||
ByteBuffer messageBytes = frameParser.getRawFrame();
|
||||
int offset = AltsFraming.getFramingOverhead();
|
||||
int length = messageBytes.limit() - offset;
|
||||
String message = new String(messageBytes.array(), offset, length, UTF_8);
|
||||
logger.log(Level.FINE, "Read message: {0}", message);
|
||||
|
||||
if (!message.equals(getExpectedMessage())) {
|
||||
throw new IllegalArgumentException(
|
||||
"Bad handshake message. Got "
|
||||
+ message
|
||||
+ " (length = "
|
||||
+ message.length()
|
||||
+ ") expected "
|
||||
+ getExpectedMessage()
|
||||
+ " (length = "
|
||||
+ getExpectedMessage().length()
|
||||
+ ")");
|
||||
}
|
||||
incrementReceiveState();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isInProgress() {
|
||||
boolean finishedReceiving =
|
||||
receiveState == State.CLIENT_FINISHED || receiveState == State.SERVER_FINISHED;
|
||||
boolean finishedSending =
|
||||
sendState == State.CLIENT_FINISHED || sendState == State.SERVER_FINISHED;
|
||||
return !finishedSending || !finishedReceiving;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TsiPeer extractPeer() {
|
||||
return new TsiPeer(Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object extractPeerObject() {
|
||||
return AltsAuthContext.getDefaultInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
|
||||
Preconditions.checkState(!isInProgress(), "Handshake is not complete.");
|
||||
|
||||
// We use an all-zero key, since this is the fake handshaker.
|
||||
byte[] key = new byte[AltsChannelCrypter.getKeyLength()];
|
||||
return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
|
||||
return createFrameProtector(AltsTsiFrameProtector.getMaxAllowedFrameBytes(), alloc);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
|
||||
import com.google.common.testing.GcFinalization;
|
||||
import io.grpc.alts.transportsecurity.TsiTest.Handshakers;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.util.ReferenceCounted;
|
||||
import io.netty.util.ResourceLeakDetector;
|
||||
import io.netty.util.ResourceLeakDetector.Level;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link TsiHandshaker}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public class FakeTsiTest {
|
||||
|
||||
private static final int OVERHEAD =
|
||||
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
|
||||
|
||||
private final List<ReferenceCounted> references = new ArrayList<>();
|
||||
|
||||
private static Handshakers newHandshakers() {
|
||||
TsiHandshaker clientHandshaker = FakeTsiHandshaker.newFakeHandshakerClient();
|
||||
TsiHandshaker serverHandshaker = FakeTsiHandshaker.newFakeHandshakerServer();
|
||||
return new Handshakers(clientHandshaker, serverHandshaker);
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
ResourceLeakDetector.setLevel(Level.PARANOID);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
for (ReferenceCounted reference : references) {
|
||||
reference.release();
|
||||
}
|
||||
references.clear();
|
||||
// Increase our chances to detect ByteBuf leaks.
|
||||
GcFinalization.awaitFullGc();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handshakeStateOrderTest() {
|
||||
try {
|
||||
Handshakers handshakers = newHandshakers();
|
||||
TsiHandshaker clientHandshaker = handshakers.getClient();
|
||||
TsiHandshaker serverHandshaker = handshakers.getServer();
|
||||
|
||||
byte[] transportBufferBytes = new byte[TsiTest.getDefaultTransportBufferSize()];
|
||||
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
|
||||
transportBuffer.limit(0); // Start off with an empty buffer
|
||||
|
||||
transportBuffer.clear();
|
||||
clientHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertEquals(
|
||||
FakeTsiHandshaker.State.CLIENT_INIT.toString().trim(),
|
||||
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
|
||||
|
||||
serverHandshaker.processBytesFromPeer(transportBuffer);
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
|
||||
// client shouldn't offer any more bytes
|
||||
transportBuffer.clear();
|
||||
clientHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
|
||||
transportBuffer.clear();
|
||||
serverHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertEquals(
|
||||
FakeTsiHandshaker.State.SERVER_INIT.toString().trim(),
|
||||
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
|
||||
|
||||
clientHandshaker.processBytesFromPeer(transportBuffer);
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
|
||||
// server shouldn't offer any more bytes
|
||||
transportBuffer.clear();
|
||||
serverHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
|
||||
transportBuffer.clear();
|
||||
clientHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertEquals(
|
||||
FakeTsiHandshaker.State.CLIENT_FINISHED.toString().trim(),
|
||||
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
|
||||
|
||||
serverHandshaker.processBytesFromPeer(transportBuffer);
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
|
||||
// client shouldn't offer any more bytes
|
||||
transportBuffer.clear();
|
||||
clientHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
|
||||
transportBuffer.clear();
|
||||
serverHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertEquals(
|
||||
FakeTsiHandshaker.State.SERVER_FINISHED.toString().trim(),
|
||||
new String(transportBufferBytes, 4, transportBuffer.remaining(), UTF_8).trim());
|
||||
|
||||
clientHandshaker.processBytesFromPeer(transportBuffer);
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
|
||||
// server shouldn't offer any more bytes
|
||||
transportBuffer.clear();
|
||||
serverHandshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
assertFalse(transportBuffer.hasRemaining());
|
||||
} catch (GeneralSecurityException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handshake() throws GeneralSecurityException {
|
||||
TsiTest.handshakeTest(newHandshakers());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void handshakeSmallBuffer() throws GeneralSecurityException {
|
||||
TsiTest.handshakeSmallBufferTest(newHandshakers());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPong() throws GeneralSecurityException {
|
||||
TsiTest.pingPongTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongExactFrameSize() throws GeneralSecurityException {
|
||||
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongSmallBuffer() throws GeneralSecurityException {
|
||||
TsiTest.pingPongSmallBufferTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongSmallFrame() throws GeneralSecurityException {
|
||||
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
|
||||
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void corruptedCounter() throws GeneralSecurityException {
|
||||
TsiTest.corruptedCounterTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void corruptedCiphertext() throws GeneralSecurityException {
|
||||
TsiTest.corruptedCiphertextTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void corruptedTag() throws GeneralSecurityException {
|
||||
TsiTest.corruptedTagTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void reflectedCiphertext() throws GeneralSecurityException {
|
||||
TsiTest.reflectedCiphertextTest(newHandshakers(), this::ref);
|
||||
}
|
||||
|
||||
private ByteBuf ref(ByteBuf buf) {
|
||||
if (buf != null) {
|
||||
references.add(buf);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import io.grpc.Status;
|
||||
import io.grpc.alts.Handshaker.HandshakerResp;
|
||||
import io.grpc.alts.Handshaker.HandshakerResult;
|
||||
import io.grpc.alts.Handshaker.HandshakerStatus;
|
||||
import io.grpc.alts.Handshaker.Identity;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.Random;
|
||||
|
||||
/** A class for mocking ALTS Handshaker Responses. */
|
||||
class MockAltsHandshakerResp {
|
||||
private static final String TEST_ERROR_DETAILS = "handshake error";
|
||||
private static final String TEST_APPLICATION_PROTOCOL = "grpc";
|
||||
private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
|
||||
private static final String TEST_OUT_FRAME = "output frame";
|
||||
private static final String TEST_LOCAL_ACCOUNT = "local@developer.gserviceaccount.com";
|
||||
private static final String TEST_PEER_ACCOUNT = "peer@developer.gserviceaccount.com";
|
||||
private static final byte[] TEST_KEY_DATA = initializeTestKeyData();
|
||||
private static final int FRAME_HEADER_SIZE = 4;
|
||||
|
||||
static String getTestErrorDetails() {
|
||||
return TEST_ERROR_DETAILS;
|
||||
}
|
||||
|
||||
static String getTestPeerAccount() {
|
||||
return TEST_PEER_ACCOUNT;
|
||||
}
|
||||
|
||||
private static byte[] initializeTestKeyData() {
|
||||
Random random = new SecureRandom();
|
||||
byte[] randombytes = new byte[AltsChannelCrypter.getKeyLength()];
|
||||
random.nextBytes(randombytes);
|
||||
return randombytes;
|
||||
}
|
||||
|
||||
static byte[] getTestKeyData() {
|
||||
return TEST_KEY_DATA;
|
||||
}
|
||||
|
||||
/** Returns a mock output frame. */
|
||||
static ByteString getOutFrame() {
|
||||
int frameSize = TEST_OUT_FRAME.length();
|
||||
ByteBuffer buffer = ByteBuffer.allocate(FRAME_HEADER_SIZE + frameSize);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.putInt(frameSize);
|
||||
buffer.put(TEST_OUT_FRAME.getBytes(UTF_8));
|
||||
buffer.flip();
|
||||
return ByteString.copyFrom(buffer);
|
||||
}
|
||||
|
||||
/** Returns a mock error handshaker response. */
|
||||
static HandshakerResp getErrorResponse() {
|
||||
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
|
||||
resp.setStatus(
|
||||
HandshakerStatus.newBuilder()
|
||||
.setCode(Status.Code.UNKNOWN.value())
|
||||
.setDetails(TEST_ERROR_DETAILS)
|
||||
.build());
|
||||
return resp.build();
|
||||
}
|
||||
|
||||
/** Returns a mock normal handshaker response. */
|
||||
static HandshakerResp getOkResponse(int bytesConsumed) {
|
||||
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
|
||||
resp.setOutFrames(getOutFrame());
|
||||
resp.setBytesConsumed(bytesConsumed);
|
||||
resp.setStatus(HandshakerStatus.newBuilder().setCode(Status.Code.OK.value()).build());
|
||||
return resp.build();
|
||||
}
|
||||
|
||||
/** Returns a mock normal handshaker response. */
|
||||
static HandshakerResp getEmptyOutFrameResponse(int bytesConsumed) {
|
||||
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
|
||||
resp.setBytesConsumed(bytesConsumed);
|
||||
resp.setStatus(HandshakerStatus.newBuilder().setCode(Status.Code.OK.value()).build());
|
||||
return resp.build();
|
||||
}
|
||||
|
||||
/** Returns a mock final handshaker response with handshake result. */
|
||||
static HandshakerResp getFinishedResponse(int bytesConsumed) {
|
||||
HandshakerResp.Builder resp = HandshakerResp.newBuilder();
|
||||
HandshakerResult.Builder result =
|
||||
HandshakerResult.newBuilder()
|
||||
.setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
|
||||
.setRecordProtocol(TEST_RECORD_PROTOCOL)
|
||||
.setPeerIdentity(Identity.newBuilder().setServiceAccount(TEST_PEER_ACCOUNT).build())
|
||||
.setLocalIdentity(Identity.newBuilder().setServiceAccount(TEST_LOCAL_ACCOUNT).build())
|
||||
.setKeyData(ByteString.copyFrom(TEST_KEY_DATA));
|
||||
resp.setOutFrames(getOutFrame());
|
||||
resp.setBytesConsumed(bytesConsumed);
|
||||
resp.setStatus(HandshakerStatus.newBuilder().setCode(Status.Code.OK.value()).build());
|
||||
resp.setResult(result.build());
|
||||
return resp.build();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,364 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.alts.transportsecurity;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static io.grpc.alts.transportsecurity.ByteBufTestUtils.getDirectBuffer;
|
||||
import static java.nio.charset.StandardCharsets.UTF_8;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import io.grpc.alts.transportsecurity.ByteBufTestUtils.RegisterRef;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.buffer.UnpooledByteBufAllocator;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.GeneralSecurityException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import javax.crypto.AEADBadTagException;
|
||||
|
||||
/** Utility class that provides tests for implementations of @{link TsiHandshaker}. */
|
||||
public final class TsiTest {
|
||||
private static final String DECRYPTION_FAILURE_RE = "Tag mismatch!";
|
||||
|
||||
private TsiTest() {}
|
||||
|
||||
/** A @{code TsiHandshaker} pair for running tests. */
|
||||
public static class Handshakers {
|
||||
private final TsiHandshaker client;
|
||||
private final TsiHandshaker server;
|
||||
|
||||
public Handshakers(TsiHandshaker client, TsiHandshaker server) {
|
||||
this.client = client;
|
||||
this.server = server;
|
||||
}
|
||||
|
||||
public TsiHandshaker getClient() {
|
||||
return client;
|
||||
}
|
||||
|
||||
public TsiHandshaker getServer() {
|
||||
return server;
|
||||
}
|
||||
}
|
||||
|
||||
private static final int DEFAULT_TRANSPORT_BUFFER_SIZE = 2048;
|
||||
|
||||
private static final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
|
||||
|
||||
private static final String EXAMPLE_MESSAGE1 = "hello world";
|
||||
private static final String EXAMPLE_MESSAGE2 = "oysteroystersoysterseateateat";
|
||||
|
||||
private static final int EXAMPLE_MESSAGE1_LEN = EXAMPLE_MESSAGE1.getBytes(UTF_8).length;
|
||||
private static final int EXAMPLE_MESSAGE2_LEN = EXAMPLE_MESSAGE2.getBytes(UTF_8).length;
|
||||
|
||||
static int getDefaultTransportBufferSize() {
|
||||
return DEFAULT_TRANSPORT_BUFFER_SIZE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a handshake between the client handshaker and server handshaker using a transport of
|
||||
* length transportBufferSize.
|
||||
*/
|
||||
static void performHandshake(int transportBufferSize, Handshakers handshakers)
|
||||
throws GeneralSecurityException {
|
||||
TsiHandshaker clientHandshaker = handshakers.getClient();
|
||||
TsiHandshaker serverHandshaker = handshakers.getServer();
|
||||
|
||||
byte[] transportBufferBytes = new byte[transportBufferSize];
|
||||
ByteBuffer transportBuffer = ByteBuffer.wrap(transportBufferBytes);
|
||||
transportBuffer.limit(0); // Start off with an empty buffer
|
||||
|
||||
while (clientHandshaker.isInProgress() || serverHandshaker.isInProgress()) {
|
||||
for (TsiHandshaker handshaker : new TsiHandshaker[] {clientHandshaker, serverHandshaker}) {
|
||||
if (handshaker.isInProgress()) {
|
||||
// Process any bytes on the wire.
|
||||
if (transportBuffer.hasRemaining()) {
|
||||
handshaker.processBytesFromPeer(transportBuffer);
|
||||
}
|
||||
// Put new bytes on the wire, if needed.
|
||||
if (handshaker.isInProgress()) {
|
||||
transportBuffer.clear();
|
||||
handshaker.getBytesToSendToPeer(transportBuffer);
|
||||
transportBuffer.flip();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
clientHandshaker.extractPeer();
|
||||
serverHandshaker.extractPeer();
|
||||
}
|
||||
|
||||
public static void handshakeTest(Handshakers handshakers) throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
}
|
||||
|
||||
public static void handshakeSmallBufferTest(Handshakers handshakers)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(9, handshakers);
|
||||
}
|
||||
|
||||
/** Sends a message between the sender and receiver. */
|
||||
private static void sendMessage(
|
||||
TsiFrameProtector sender,
|
||||
TsiFrameProtector receiver,
|
||||
int recvFragmentSize,
|
||||
String message,
|
||||
RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
|
||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||
List<ByteBuf> protectOut = new ArrayList<>();
|
||||
List<Object> unprotectOut = new ArrayList<>();
|
||||
|
||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
||||
assertThat(protectOut.size()).isEqualTo(1);
|
||||
|
||||
ByteBuf protect = ref.register(protectOut.get(0));
|
||||
while (protect.isReadable()) {
|
||||
ByteBuf buf = protect;
|
||||
if (recvFragmentSize > 0) {
|
||||
int size = Math.min(protect.readableBytes(), recvFragmentSize);
|
||||
buf = protect.readSlice(size);
|
||||
}
|
||||
receiver.unprotect(buf, unprotectOut, alloc);
|
||||
}
|
||||
ByteBuf plaintextRecvd = getDirectBuffer(message.getBytes(UTF_8).length, ref);
|
||||
for (Object unprotect : unprotectOut) {
|
||||
ByteBuf unprotectBuf = ref.register((ByteBuf) unprotect);
|
||||
plaintextRecvd.writeBytes(unprotectBuf);
|
||||
}
|
||||
assertThat(plaintextRecvd).isEqualTo(Unpooled.wrappedBuffer(message.getBytes(UTF_8)));
|
||||
}
|
||||
|
||||
/** Ping pong test. */
|
||||
public static void pingPongTest(Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc);
|
||||
TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc);
|
||||
|
||||
sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref);
|
||||
sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE2, ref);
|
||||
|
||||
clientProtector.destroy();
|
||||
serverProtector.destroy();
|
||||
}
|
||||
|
||||
/** Ping pong test with exact frame size. */
|
||||
public static void pingPongExactFrameSizeTest(Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
int frameSize =
|
||||
EXAMPLE_MESSAGE1.getBytes(UTF_8).length
|
||||
+ AltsTsiFrameProtector.getHeaderBytes()
|
||||
+ FakeChannelCrypter.getTagBytes();
|
||||
|
||||
TsiFrameProtector clientProtector =
|
||||
handshakers.getClient().createFrameProtector(frameSize, alloc);
|
||||
TsiFrameProtector serverProtector =
|
||||
handshakers.getServer().createFrameProtector(frameSize, alloc);
|
||||
|
||||
sendMessage(clientProtector, serverProtector, -1, EXAMPLE_MESSAGE1, ref);
|
||||
sendMessage(serverProtector, clientProtector, -1, EXAMPLE_MESSAGE1, ref);
|
||||
|
||||
clientProtector.destroy();
|
||||
serverProtector.destroy();
|
||||
}
|
||||
|
||||
/** Ping pong test with small buffer size. */
|
||||
public static void pingPongSmallBufferTest(Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
TsiFrameProtector clientProtector = handshakers.getClient().createFrameProtector(alloc);
|
||||
TsiFrameProtector serverProtector = handshakers.getServer().createFrameProtector(alloc);
|
||||
|
||||
sendMessage(clientProtector, serverProtector, 1, EXAMPLE_MESSAGE1, ref);
|
||||
sendMessage(serverProtector, clientProtector, 1, EXAMPLE_MESSAGE2, ref);
|
||||
|
||||
clientProtector.destroy();
|
||||
serverProtector.destroy();
|
||||
}
|
||||
|
||||
/** Ping pong test with small frame size. */
|
||||
public static void pingPongSmallFrameTest(
|
||||
int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
// We send messages using small non-aligned buffers. We use 3 and 5, small primes.
|
||||
TsiFrameProtector clientProtector =
|
||||
handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc);
|
||||
TsiFrameProtector serverProtector =
|
||||
handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc);
|
||||
|
||||
sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
|
||||
sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
|
||||
|
||||
clientProtector.destroy();
|
||||
serverProtector.destroy();
|
||||
}
|
||||
|
||||
/** Ping pong test with small frame and small buffer. */
|
||||
public static void pingPongSmallFrameSmallBufferTest(
|
||||
int frameProtectorOverhead, Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
// We send messages using small non-aligned buffers. We use 3 and 5, small primes.
|
||||
TsiFrameProtector clientProtector =
|
||||
handshakers.getClient().createFrameProtector(frameProtectorOverhead + 3, alloc);
|
||||
TsiFrameProtector serverProtector =
|
||||
handshakers.getServer().createFrameProtector(frameProtectorOverhead + 5, alloc);
|
||||
|
||||
sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
|
||||
sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
|
||||
|
||||
sendMessage(clientProtector, serverProtector, EXAMPLE_MESSAGE1_LEN, EXAMPLE_MESSAGE1, ref);
|
||||
sendMessage(serverProtector, clientProtector, EXAMPLE_MESSAGE2_LEN, EXAMPLE_MESSAGE2, ref);
|
||||
|
||||
clientProtector.destroy();
|
||||
serverProtector.destroy();
|
||||
}
|
||||
|
||||
/** Test corrupted counter. */
|
||||
public static void corruptedCounterTest(Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
|
||||
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
|
||||
|
||||
String message = "hello world";
|
||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||
List<ByteBuf> protectOut = new ArrayList<>();
|
||||
List<Object> unprotectOut = new ArrayList<>();
|
||||
|
||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
||||
assertThat(protectOut.size()).isEqualTo(1);
|
||||
|
||||
ByteBuf protect = ref.register(protectOut.get(0));
|
||||
// Unprotect once to increase receiver counter.
|
||||
receiver.unprotect(protect.slice(), unprotectOut, alloc);
|
||||
assertThat(unprotectOut.size()).isEqualTo(1);
|
||||
ref.register((ByteBuf) unprotectOut.get(0));
|
||||
|
||||
try {
|
||||
receiver.unprotect(protect, unprotectOut, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
|
||||
}
|
||||
|
||||
sender.destroy();
|
||||
receiver.destroy();
|
||||
}
|
||||
|
||||
/** Test corrupted ciphertext. */
|
||||
public static void corruptedCiphertextTest(Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
|
||||
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
|
||||
|
||||
String message = "hello world";
|
||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||
List<ByteBuf> protectOut = new ArrayList<>();
|
||||
List<Object> unprotectOut = new ArrayList<>();
|
||||
|
||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
||||
assertThat(protectOut.size()).isEqualTo(1);
|
||||
|
||||
ByteBuf protect = ref.register(protectOut.get(0));
|
||||
int ciphertextIdx = protect.writerIndex() - FakeChannelCrypter.getTagBytes() - 2;
|
||||
protect.setByte(ciphertextIdx, protect.getByte(ciphertextIdx) + 1);
|
||||
|
||||
try {
|
||||
receiver.unprotect(protect, unprotectOut, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
|
||||
}
|
||||
|
||||
sender.destroy();
|
||||
receiver.destroy();
|
||||
}
|
||||
|
||||
/** Test corrupted tag. */
|
||||
public static void corruptedTagTest(Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
|
||||
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
|
||||
|
||||
String message = "hello world";
|
||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||
List<ByteBuf> protectOut = new ArrayList<>();
|
||||
List<Object> unprotectOut = new ArrayList<>();
|
||||
|
||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
||||
assertThat(protectOut.size()).isEqualTo(1);
|
||||
|
||||
ByteBuf protect = ref.register(protectOut.get(0));
|
||||
int tagIdx = protect.writerIndex() - 1;
|
||||
protect.setByte(tagIdx, protect.getByte(tagIdx) + 1);
|
||||
|
||||
try {
|
||||
receiver.unprotect(protect, unprotectOut, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
|
||||
}
|
||||
|
||||
sender.destroy();
|
||||
receiver.destroy();
|
||||
}
|
||||
|
||||
/** Test reflected ciphertext. */
|
||||
public static void reflectedCiphertextTest(Handshakers handshakers, RegisterRef ref)
|
||||
throws GeneralSecurityException {
|
||||
performHandshake(DEFAULT_TRANSPORT_BUFFER_SIZE, handshakers);
|
||||
|
||||
TsiFrameProtector sender = handshakers.getClient().createFrameProtector(alloc);
|
||||
TsiFrameProtector receiver = handshakers.getServer().createFrameProtector(alloc);
|
||||
|
||||
String message = "hello world";
|
||||
ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(message.getBytes(UTF_8));
|
||||
List<ByteBuf> protectOut = new ArrayList<>();
|
||||
List<Object> unprotectOut = new ArrayList<>();
|
||||
|
||||
sender.protectFlush(Collections.singletonList(plaintextBuffer), protectOut::add, alloc);
|
||||
assertThat(protectOut.size()).isEqualTo(1);
|
||||
|
||||
ByteBuf protect = ref.register(protectOut.get(0));
|
||||
try {
|
||||
sender.unprotect(protect.slice(), unprotectOut, alloc);
|
||||
fail("Exception expected");
|
||||
} catch (AEADBadTagException ex) {
|
||||
assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_RE);
|
||||
}
|
||||
|
||||
sender.destroy();
|
||||
receiver.destroy();
|
||||
}
|
||||
}
|
||||
|
|
@ -208,6 +208,7 @@ subprojects {
|
|||
protobuf_nano: "com.google.protobuf.nano:protobuf-javanano:${protobufNanoVersion}",
|
||||
protobuf_plugin: 'com.google.protobuf:protobuf-gradle-plugin:0.8.3',
|
||||
protobuf_util: "com.google.protobuf:protobuf-java-util:${protobufVersion}",
|
||||
lang: "org.apache.commons:commons-lang3:3.5",
|
||||
|
||||
netty: "io.netty:netty-codec-http2:[${nettyVersion}]",
|
||||
netty_epoll: "io.netty:netty-transport-native-epoll:${nettyVersion}" + epoll_suffix,
|
||||
|
|
@ -218,6 +219,7 @@ subprojects {
|
|||
junit: 'junit:junit:4.12',
|
||||
mockito: 'org.mockito:mockito-core:1.9.5',
|
||||
truth: 'com.google.truth:truth:0.36',
|
||||
guava_testlib: 'com.google.guava:guava-testlib:19.0',
|
||||
|
||||
// Benchmark dependencies
|
||||
hdrhistogram: 'org.hdrhistogram:HdrHistogram:2.1.10',
|
||||
|
|
@ -391,6 +393,8 @@ subprojects {
|
|||
// Run with: ./gradlew japicmp --continue
|
||||
def baselineGrpcVersion = '1.6.1'
|
||||
def publicApiSubprojects = [
|
||||
// TODO: uncomment after grpc-alts artifact is published.
|
||||
// ':grpc-alts',
|
||||
':grpc-auth',
|
||||
':grpc-context',
|
||||
':grpc-core',
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ java_library(
|
|||
"@com_google_guava_guava//jar",
|
||||
"@com_google_protobuf//:protobuf_java",
|
||||
"@com_google_protobuf//:protobuf_java_util",
|
||||
"@grpc_java//alts",
|
||||
"@grpc_java//core",
|
||||
"@grpc_java//netty",
|
||||
"@grpc_java//protobuf",
|
||||
|
|
@ -97,6 +98,29 @@ java_binary(
|
|||
],
|
||||
)
|
||||
|
||||
java_binary(
|
||||
name = "hello-world-alts-client",
|
||||
testonly = 1,
|
||||
main_class = "io.grpc.examples.alts.HelloWorldAltsClient",
|
||||
runtime_deps = [
|
||||
":examples",
|
||||
"@grpc_java//alts",
|
||||
"@grpc_java//netty",
|
||||
],
|
||||
)
|
||||
|
||||
java_binary(
|
||||
name = "hello-world-alts-server",
|
||||
testonly = 1,
|
||||
main_class = "io.grpc.examples.alts.HelloWorldAltsServer",
|
||||
runtime_deps = [
|
||||
":examples",
|
||||
"@grpc_java//alts",
|
||||
"@grpc_java//netty",
|
||||
],
|
||||
|
||||
)
|
||||
|
||||
java_binary(
|
||||
name = "route-guide-client",
|
||||
testonly = 1,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ def nettyTcNativeVersion = '2.0.7.Final'
|
|||
|
||||
dependencies {
|
||||
compile "com.google.api.grpc:proto-google-common-protos:1.0.0"
|
||||
compile "io.grpc:grpc-alts:${grpcVersion}"
|
||||
compile "io.grpc:grpc-netty:${grpcVersion}"
|
||||
compile "io.grpc:grpc-protobuf:${grpcVersion}"
|
||||
compile "io.grpc:grpc-stub:${grpcVersion}"
|
||||
|
|
@ -101,6 +102,20 @@ task helloWorldClient(type: CreateStartScripts) {
|
|||
classpath = jar.outputs.files + project.configurations.runtime
|
||||
}
|
||||
|
||||
task helloWorldAltsServer(type: CreateStartScripts) {
|
||||
mainClassName = 'io.grpc.examples.alts.HelloWorldAltsServer'
|
||||
applicationName = 'hello-world-alts-server'
|
||||
outputDir = new File(project.buildDir, 'tmp')
|
||||
classpath = jar.outputs.files + project.configurations.runtime
|
||||
}
|
||||
|
||||
task helloWorldAltsClient(type: CreateStartScripts) {
|
||||
mainClassName = 'io.grpc.examples.alts.HelloWorldAltsClient'
|
||||
applicationName = 'hello-world-alts-client'
|
||||
outputDir = new File(project.buildDir, 'tmp')
|
||||
classpath = jar.outputs.files + project.configurations.runtime
|
||||
}
|
||||
|
||||
task helloWorldTlsServer(type: CreateStartScripts) {
|
||||
mainClassName = 'io.grpc.examples.helloworldtls.HelloWorldServerTls'
|
||||
applicationName = 'hello-world-tls-server'
|
||||
|
|
@ -127,6 +142,8 @@ applicationDistribution.into('bin') {
|
|||
from(routeGuideClient)
|
||||
from(helloWorldServer)
|
||||
from(helloWorldClient)
|
||||
from(helloWorldAltsServer)
|
||||
from(helloWorldAltsClient)
|
||||
from(helloWorldTlsServer)
|
||||
from(helloWorldTlsClient)
|
||||
from(compressingHelloWorldClient)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,11 @@
|
|||
<artifactId>grpc-stub</artifactId>
|
||||
<version>${grpc.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.grpc</groupId>
|
||||
<artifactId>grpc-alts</artifactId>
|
||||
<version>${grpc.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.grpc</groupId>
|
||||
<artifactId>grpc-testing</artifactId>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.examples.alts;
|
||||
|
||||
import io.grpc.alts.AltsChannelBuilder;
|
||||
import io.grpc.ManagedChannel;
|
||||
import io.grpc.examples.helloworld.GreeterGrpc;
|
||||
import io.grpc.examples.helloworld.HelloReply;
|
||||
import io.grpc.examples.helloworld.HelloRequest;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* An example gRPC client that uses ALTS. Shows how to do a Unary RPC. This example can only be run
|
||||
* on Google Cloud Platform.
|
||||
*/
|
||||
public final class HelloWorldAltsClient {
|
||||
private static final Logger logger = Logger.getLogger(HelloWorldAltsClient.class.getName());
|
||||
private String serverAddress = "localhost:10001";
|
||||
|
||||
public static void main(String[] args) throws InterruptedException {
|
||||
new HelloWorldAltsClient().run(args);
|
||||
}
|
||||
|
||||
private void parseArgs(String[] args) {
|
||||
boolean usage = false;
|
||||
for (String arg : args) {
|
||||
if (!arg.startsWith("--")) {
|
||||
System.err.println("All arguments must start with '--': " + arg);
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
String[] parts = arg.substring(2).split("=", 2);
|
||||
String key = parts[0];
|
||||
if ("help".equals(key)) {
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
if (parts.length != 2) {
|
||||
System.err.println("All arguments must be of the form --arg=value");
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
String value = parts[1];
|
||||
if ("server".equals(key)) {
|
||||
serverAddress = value;
|
||||
} else {
|
||||
System.err.println("Unknown argument: " + key);
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (usage) {
|
||||
HelloWorldAltsClient c = new HelloWorldAltsClient();
|
||||
System.out.println(
|
||||
"Usage: [ARGS...]"
|
||||
+ "\n"
|
||||
+ "\n --server=SERVER_ADDRESS Server address to connect to. Default "
|
||||
+ c.serverAddress);
|
||||
System.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
private void run(String[] args) throws InterruptedException {
|
||||
parseArgs(args);
|
||||
ExecutorService executor = Executors.newFixedThreadPool(1);
|
||||
ManagedChannel channel = AltsChannelBuilder.forTarget(serverAddress).executor(executor).build();
|
||||
try {
|
||||
GreeterGrpc.GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel);
|
||||
HelloReply resp = stub.sayHello(HelloRequest.newBuilder().setName("Waldo").build());
|
||||
|
||||
logger.log(Level.INFO, "Got {0}", resp);
|
||||
} finally {
|
||||
channel.shutdown();
|
||||
channel.awaitTermination(1, TimeUnit.SECONDS);
|
||||
// Wait until the channel has terminated, since tasks can be queued after the channel is
|
||||
// shutdown.
|
||||
executor.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
* Copyright 2018, gRPC Authors All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package io.grpc.examples.alts;
|
||||
|
||||
import io.grpc.alts.AltsServerBuilder;
|
||||
import io.grpc.Server;
|
||||
import io.grpc.examples.helloworld.GreeterGrpc.GreeterImplBase;
|
||||
import io.grpc.examples.helloworld.HelloReply;
|
||||
import io.grpc.examples.helloworld.HelloRequest;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.logging.Level;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* An example gRPC server that uses ALTS. Shows how to do a Unary RPC. This example can only be run
|
||||
* on Google Cloud Platform.
|
||||
*/
|
||||
public final class HelloWorldAltsServer extends GreeterImplBase {
|
||||
private static final Logger logger = Logger.getLogger(HelloWorldAltsServer.class.getName());
|
||||
private Server server;
|
||||
private int port = 10001;
|
||||
|
||||
public static void main(String[] args) throws IOException, InterruptedException {
|
||||
new HelloWorldAltsServer().start(args);
|
||||
}
|
||||
|
||||
private void parseArgs(String[] args) {
|
||||
boolean usage = false;
|
||||
for (String arg : args) {
|
||||
if (!arg.startsWith("--")) {
|
||||
System.err.println("All arguments must start with '--': " + arg);
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
String[] parts = arg.substring(2).split("=", 2);
|
||||
String key = parts[0];
|
||||
if ("help".equals(key)) {
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
if (parts.length != 2) {
|
||||
System.err.println("All arguments must be of the form --arg=value");
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
String value = parts[1];
|
||||
if ("port".equals(key)) {
|
||||
port = Integer.parseInt(value);
|
||||
} else {
|
||||
System.err.println("Unknown argument: " + key);
|
||||
usage = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (usage) {
|
||||
HelloWorldAltsServer s = new HelloWorldAltsServer();
|
||||
System.out.println(
|
||||
"Usage: [ARGS...]"
|
||||
+ "\n"
|
||||
+ "\n --port=PORT Server port to bind to. Default "
|
||||
+ s.port);
|
||||
System.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
private void start(String[] args) throws IOException, InterruptedException {
|
||||
parseArgs(args);
|
||||
server =
|
||||
AltsServerBuilder.forPort(port)
|
||||
.addService(this)
|
||||
.executor(Executors.newFixedThreadPool(1))
|
||||
.build();
|
||||
server.start();
|
||||
logger.log(Level.INFO, "Started on {0}", port);
|
||||
server.awaitTermination();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void sayHello(HelloRequest request, StreamObserver<HelloReply> observer) {
|
||||
observer.onNext(HelloReply.newBuilder().setMessage("Hello, " + request.getName()).build());
|
||||
observer.onCompleted();
|
||||
}
|
||||
}
|
||||
|
|
@ -26,7 +26,8 @@ def grpc_java_repositories(
|
|||
omit_io_netty_tcnative_boringssl_static=False,
|
||||
omit_io_opencensus_api=False,
|
||||
omit_io_opencensus_grpc_metrics=False,
|
||||
omit_junit_junit=False):
|
||||
omit_junit_junit=False,
|
||||
omit_org_apache_commons_lang3=False):
|
||||
"""Imports dependencies for grpc-java."""
|
||||
if not omit_com_google_api_grpc_google_common_protos:
|
||||
com_google_api_grpc_google_common_protos()
|
||||
|
|
@ -80,6 +81,8 @@ def grpc_java_repositories(
|
|||
io_opencensus_grpc_metrics()
|
||||
if not omit_junit_junit:
|
||||
junit_junit()
|
||||
if not omit_org_apache_commons_lang3:
|
||||
org_apache_commons_lang3()
|
||||
|
||||
native.bind(
|
||||
name = "guava",
|
||||
|
|
@ -268,3 +271,10 @@ def junit_junit():
|
|||
artifact = "junit:junit:4.12",
|
||||
sha1 = "2973d150c0dc1fefe998f834810d68f278ea58ec",
|
||||
)
|
||||
|
||||
def org_apache_commons_lang3():
|
||||
native.maven_jar(
|
||||
name = "org_apache_commons_commons_lang3",
|
||||
artifact = "org.apache.commons:commons-lang3:3.5",
|
||||
sha1 = "6c6c702c89bfff3cd9e80b04d668c5e190d588c6"
|
||||
)
|
||||
|
|
@ -16,6 +16,7 @@ include ":grpc-interop-testing"
|
|||
include ":grpc-gae-interop-testing-jdk7"
|
||||
include ":grpc-gae-interop-testing-jdk8"
|
||||
include ":grpc-all"
|
||||
include ":grpc-alts"
|
||||
include ":grpc-benchmarks"
|
||||
include ":grpc-services"
|
||||
|
||||
|
|
@ -36,6 +37,7 @@ project(':grpc-interop-testing').projectDir = "$rootDir/interop-testing" as File
|
|||
project(':grpc-gae-interop-testing-jdk7').projectDir = "$rootDir/gae-interop-testing/gae-jdk7" as File
|
||||
project(':grpc-gae-interop-testing-jdk8').projectDir = "$rootDir/gae-interop-testing/gae-jdk8" as File
|
||||
project(':grpc-all').projectDir = "$rootDir/all" as File
|
||||
project(':grpc-alts').projectDir = "$rootDir/alts" as File
|
||||
project(':grpc-benchmarks').projectDir = "$rootDir/benchmarks" as File
|
||||
project(':grpc-services').projectDir = "$rootDir/services" as File
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue