alts: add channel logs in handshake

The logs are to help with debugging issues for an internal customer.
This commit is contained in:
Zhouyihai Ding 2021-09-22 21:40:41 -07:00 committed by GitHub
parent e76efbb5da
commit cf41181c48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 81 additions and 42 deletions

View File

@ -20,19 +20,17 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.protobuf.ByteString;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.Status;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.logging.Level;
import java.util.logging.Logger;
/** An API for conducting handshakes via ALTS handshaker service. */
class AltsHandshakerClient {
private static final Logger logger = Logger.getLogger(AltsHandshakerClient.class.getName());
private static final String APPLICATION_PROTOCOL = "grpc";
private static final String RECORD_PROTOCOL = "ALTSRP_GCM_AES128_REKEY";
private static final int KEY_LENGTH = AltsChannelCrypter.getKeyLength();
@ -41,17 +39,22 @@ class AltsHandshakerClient {
private final AltsHandshakerOptions handshakerOptions;
private HandshakerResult result;
private HandshakerStatus status;
private final ChannelLogger logger;
/** Starts a new handshake interacting with the handshaker service. */
AltsHandshakerClient(HandshakerServiceStub stub, AltsHandshakerOptions options) {
AltsHandshakerClient(
HandshakerServiceStub stub, AltsHandshakerOptions options, ChannelLogger logger) {
handshakerStub = new AltsHandshakerStub(stub);
handshakerOptions = options;
this.logger = logger;
}
@VisibleForTesting
AltsHandshakerClient(AltsHandshakerStub handshakerStub, AltsHandshakerOptions options) {
AltsHandshakerClient(
AltsHandshakerStub handshakerStub, AltsHandshakerOptions options, ChannelLogger logger) {
this.handshakerStub = handshakerStub;
handshakerOptions = options;
this.logger = logger;
}
static String getApplicationProtocol() {
@ -154,7 +157,7 @@ class AltsHandshakerClient {
}
if (status.getCode() != Status.Code.OK.value()) {
String error = "Handshaker service error: " + status.getDetails();
logger.log(Level.INFO, error);
logger.log(ChannelLogLevel.DEBUG, error);
close();
throw new GeneralSecurityException(error);
}
@ -173,7 +176,9 @@ class AltsHandshakerClient {
setStartClientFields(req);
HandshakerResp resp;
try {
logger.log(ChannelLogLevel.DEBUG, "Send ALTS handshake request to upstream");
resp = handshakerStub.send(req.build());
logger.log(ChannelLogLevel.DEBUG, "Receive ALTS handshake response from upstream");
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}
@ -223,7 +228,9 @@ class AltsHandshakerClient {
.build());
HandshakerResp resp;
try {
logger.log(ChannelLogLevel.DEBUG, "Send ALTS handshake request to upstream");
resp = handshakerStub.send(req.build());
logger.log(ChannelLogLevel.DEBUG, "Receive ALTS handshake response from upstream");
} catch (IOException | InterruptedException e) {
throw new GeneralSecurityException(e);
}

View File

@ -115,8 +115,9 @@ public final class AltsProtocolNegotiator {
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
TsiHandshaker handshaker =
handshakerFactory.newHandshaker(grpcHandler.getAuthority(), negotiationLogger);
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
ChannelHandler thh = new TsiHandshakeHandler(
@ -142,11 +143,13 @@ public final class AltsProtocolNegotiator {
final class ServerTsiHandshakerFactory implements TsiHandshakerFactory {
@Override
public TsiHandshaker newHandshaker(@Nullable String authority) {
public TsiHandshaker newHandshaker(
@Nullable String authority, ChannelLogger negotiationLogger) {
assert authority == null;
return AltsTsiHandshaker.newServer(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()));
new AltsHandshakerOptions(RpcProtocolVersionsUtil.getRpcProtocolVersions()),
negotiationLogger);
}
}
@ -174,7 +177,8 @@ public final class AltsProtocolNegotiator {
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
TsiHandshaker handshaker = handshakerFactory.newHandshaker(/* authority= */ null);
TsiHandshaker handshaker =
handshakerFactory.newHandshaker(/* authority= */ null, negotiationLogger);
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
ChannelHandler gnh = InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler);
ChannelHandler thh = new TsiHandshakeHandler(
@ -292,7 +296,8 @@ public final class AltsProtocolNegotiator {
if (grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_ADDR_AUTHORITY) != null
|| grpcHandler.getEagAttributes().get(GrpclbConstants.ATTR_LB_PROVIDED_BACKEND) != null
|| isXdsDirectPath) {
TsiHandshaker handshaker = handshakerFactory.newHandshaker(grpcHandler.getAuthority());
TsiHandshaker handshaker =
handshakerFactory.newHandshaker(grpcHandler.getAuthority(), negotiationLogger);
NettyTsiHandshaker nettyHandshaker = new NettyTsiHandshaker(handshaker);
securityHandler = new TsiHandshakeHandler(
gnh, nettyHandshaker, new AltsHandshakeValidator(), handshakeSemaphore,
@ -325,7 +330,8 @@ public final class AltsProtocolNegotiator {
}
@Override
public TsiHandshaker newHandshaker(@Nullable String authority) {
public TsiHandshaker newHandshaker(
@Nullable String authority, ChannelLogger negotiationLogger) {
AltsClientOptions handshakerOptions =
new AltsClientOptions.Builder()
.setRpcProtocolVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
@ -333,7 +339,9 @@ public final class AltsProtocolNegotiator {
.setTargetName(authority)
.build();
return AltsTsiHandshaker.newClient(
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()), handshakerOptions);
HandshakerServiceGrpc.newStub(lazyHandshakerChannel.get()),
handshakerOptions,
negotiationLogger);
}
}

View File

@ -20,6 +20,8 @@ import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
import io.netty.buffer.ByteBufAllocator;
import java.nio.Buffer;
@ -27,14 +29,12 @@ import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Negotiates a grpc channel key to be used by the TsiFrameProtector, using ALTs handshaker service.
*/
public final class AltsTsiHandshaker implements TsiHandshaker {
private static final Logger logger = Logger.getLogger(AltsTsiHandshaker.class.getName());
private final ChannelLogger logger;
public static final String TSI_SERVICE_ACCOUNT_PEER_PROPERTY = "service_account";
@ -45,15 +45,20 @@ public final class AltsTsiHandshaker implements TsiHandshaker {
/** Starts a new TSI handshaker with client options. */
private AltsTsiHandshaker(
boolean isClient, HandshakerServiceStub stub, AltsHandshakerOptions options) {
boolean isClient,
HandshakerServiceStub stub,
AltsHandshakerOptions options,
ChannelLogger logger) {
this.isClient = isClient;
handshaker = new AltsHandshakerClient(stub, options);
this.logger = logger;
handshaker = new AltsHandshakerClient(stub, options, logger);
}
@VisibleForTesting
AltsTsiHandshaker(boolean isClient, AltsHandshakerClient handshaker) {
AltsTsiHandshaker(boolean isClient, AltsHandshakerClient handshaker, ChannelLogger logger) {
this.isClient = isClient;
this.handshaker = handshaker;
this.logger = logger;
}
/**
@ -80,6 +85,7 @@ public final class AltsTsiHandshaker implements TsiHandshaker {
checkState(!isClient, "Client handshaker should not process any frame at the beginning.");
outputFrame = handshaker.startServerHandshake(bytes);
} else {
logger.log(ChannelLogLevel.DEBUG, "Receive ALTS handshake from downstream");
outputFrame = handshaker.next(bytes);
}
// If handshake has finished or we already have bytes to write, just return true.
@ -124,13 +130,15 @@ public final class AltsTsiHandshaker implements TsiHandshaker {
}
/** Creates a new TsiHandshaker for use by the client. */
public static TsiHandshaker newClient(HandshakerServiceStub stub, AltsHandshakerOptions options) {
return new AltsTsiHandshaker(true, stub, options);
public static TsiHandshaker newClient(
HandshakerServiceStub stub, AltsHandshakerOptions options, ChannelLogger logger) {
return new AltsTsiHandshaker(true, stub, options, logger);
}
/** Creates a new TsiHandshaker for use by the server. */
public static TsiHandshaker newServer(HandshakerServiceStub stub, AltsHandshakerOptions options) {
return new AltsTsiHandshaker(false, stub, options);
public static TsiHandshaker newServer(
HandshakerServiceStub stub, AltsHandshakerOptions options, ChannelLogger logger) {
return new AltsTsiHandshaker(false, stub, options, logger);
}
/**
@ -142,12 +150,14 @@ public final class AltsTsiHandshaker implements TsiHandshaker {
public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
if (outputFrame == null) { // A null outputFrame indicates we haven't started the handshake.
if (isClient) {
logger.log(ChannelLogLevel.DEBUG, "Initial ALTS handshake to downstream");
outputFrame = handshaker.startClientHandshake();
} else {
// The server needs bytes to process before it can start the handshake.
return;
}
}
logger.log(ChannelLogLevel.DEBUG, "Send ALTS request to downstream");
// Write as many bytes as we are able.
ByteBuffer outputFrameAlias = outputFrame;
if (outputFrame.remaining() > bytes.remaining()) {
@ -190,7 +200,7 @@ public final class AltsTsiHandshaker implements TsiHandshaker {
maxFrameSize = Math.min(peerMaxFrameSize, AltsTsiFrameProtector.getMaxFrameSize());
maxFrameSize = Math.max(AltsTsiFrameProtector.getMinFrameSize(), maxFrameSize);
}
logger.log(Level.FINE, "Maximum frame size value is {0}.", maxFrameSize);
logger.log(ChannelLogLevel.INFO, "Maximum frame size value is {0}.", maxFrameSize);
return new AltsTsiFrameProtector(maxFrameSize, new AltsChannelCrypter(key, isClient), alloc);
}

View File

@ -16,11 +16,12 @@
package io.grpc.alts.internal;
import io.grpc.ChannelLogger;
import javax.annotation.Nullable;
/** Factory that manufactures instances of {@link TsiHandshaker}. */
public interface TsiHandshakerFactory {
/** Creates a new handshaker. */
TsiHandshaker newHandshaker(@Nullable String authority);
TsiHandshaker newHandshaker(@Nullable String authority, ChannelLogger logger);
}

View File

@ -29,6 +29,7 @@ import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import io.grpc.internal.TestUtils.NoopChannelLogger;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
@ -60,7 +61,8 @@ public class AltsHandshakerClientTest {
.setTargetName(TEST_TARGET_NAME)
.setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT))
.build();
handshaker = new AltsHandshakerClient(mockStub, clientOptions);
NoopChannelLogger channelLogger = new NoopChannelLogger();
handshaker = new AltsHandshakerClient(mockStub, clientOptions, channelLogger);
}
@Test
@ -266,7 +268,8 @@ public class AltsHandshakerClientTest {
.setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT))
.setRpcProtocolVersions(rpcVersions)
.build();
handshaker = new AltsHandshakerClient(mockStub, clientOptions);
NoopChannelLogger channelLogger = new NoopChannelLogger();
handshaker = new AltsHandshakerClient(mockStub, clientOptions, channelLogger);
handshaker.startClientHandshake();

View File

@ -26,6 +26,7 @@ import static org.junit.Assert.assertTrue;
import io.grpc.Attributes;
import io.grpc.Channel;
import io.grpc.ChannelLogger;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.ManagedChannel;
@ -131,8 +132,8 @@ public class AltsProtocolNegotiatorTest {
TsiHandshakerFactory handshakerFactory =
new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
@Override
public TsiHandshaker newHandshaker(String authority) {
return new DelegatingTsiHandshaker(super.newHandshaker(authority)) {
public TsiHandshaker newHandshaker(String authority, ChannelLogger logger) {
return new DelegatingTsiHandshaker(super.newHandshaker(authority, logger)) {
@Override
public TsiPeer extractPeer() throws GeneralSecurityException {
return mockedTsiPeer;
@ -427,8 +428,8 @@ public class AltsProtocolNegotiatorTest {
}
@Override
public TsiHandshaker newHandshaker(String authority) {
return delegate.newHandshaker(authority);
public TsiHandshaker newHandshaker(String authority, ChannelLogger logger) {
return delegate.newHandshaker(authority, logger);
}
}

View File

@ -26,6 +26,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import io.grpc.internal.TestUtils.NoopChannelLogger;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import org.junit.Before;
@ -71,8 +72,9 @@ public class AltsTsiHandshakerTest {
public void setUp() throws Exception {
mockClient = mock(AltsHandshakerClient.class);
mockServer = mock(AltsHandshakerClient.class);
handshakerClient = new AltsTsiHandshaker(true, mockClient);
handshakerServer = new AltsTsiHandshaker(false, mockServer);
NoopChannelLogger channelLogger = new NoopChannelLogger();
handshakerClient = new AltsTsiHandshaker(true, mockClient, channelLogger);
handshakerServer = new AltsTsiHandshaker(false, mockServer, channelLogger);
}
private HandshakerResult getHandshakerResult(boolean isClient) {

View File

@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals;
import com.google.common.testing.GcFinalization;
import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef;
import io.grpc.alts.internal.TsiTest.Handshakers;
import io.grpc.internal.TestUtils.NoopChannelLogger;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
@ -61,8 +62,9 @@ public class AltsTsiTest {
AltsHandshakerOptions handshakerOptions = new AltsHandshakerOptions(null);
MockAltsHandshakerStub clientStub = new MockAltsHandshakerStub();
MockAltsHandshakerStub serverStub = new MockAltsHandshakerStub();
client = new AltsHandshakerClient(clientStub, handshakerOptions);
server = new AltsHandshakerClient(serverStub, handshakerOptions);
NoopChannelLogger channelLogger = new NoopChannelLogger();
client = new AltsHandshakerClient(clientStub, handshakerOptions, channelLogger);
server = new AltsHandshakerClient(serverStub, handshakerOptions, channelLogger);
}
@After
@ -76,8 +78,9 @@ public class AltsTsiTest {
}
private Handshakers newHandshakers() {
TsiHandshaker clientHandshaker = new AltsTsiHandshaker(true, client);
TsiHandshaker serverHandshaker = new AltsTsiHandshaker(false, server);
NoopChannelLogger channelLogger = new NoopChannelLogger();
TsiHandshaker clientHandshaker = new AltsTsiHandshaker(true, client, channelLogger);
TsiHandshaker serverHandshaker = new AltsTsiHandshaker(false, server, channelLogger);
return new Handshakers(clientHandshaker, serverHandshaker);
}

View File

@ -19,7 +19,9 @@ package io.grpc.alts.internal;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.Preconditions;
import io.grpc.ChannelLogger;
import io.grpc.alts.internal.TsiPeer.Property;
import io.grpc.internal.TestUtils.NoopChannelLogger;
import io.netty.buffer.ByteBufAllocator;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
@ -37,7 +39,7 @@ public class FakeTsiHandshaker implements TsiHandshaker {
private static final TsiHandshakerFactory clientHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
public TsiHandshaker newHandshaker(String authority, ChannelLogger logger) {
return new FakeTsiHandshaker(true);
}
};
@ -45,7 +47,7 @@ public class FakeTsiHandshaker implements TsiHandshaker {
private static final TsiHandshakerFactory serverHandshakerFactory =
new TsiHandshakerFactory() {
@Override
public TsiHandshaker newHandshaker(String authority) {
public TsiHandshaker newHandshaker(String authority, ChannelLogger logger) {
return new FakeTsiHandshaker(false);
}
};
@ -83,11 +85,13 @@ public class FakeTsiHandshaker implements TsiHandshaker {
}
public static TsiHandshaker newFakeHandshakerClient() {
return clientHandshakerFactory.newHandshaker(null);
NoopChannelLogger channelLogger = new NoopChannelLogger();
return clientHandshakerFactory.newHandshaker(null, channelLogger);
}
public static TsiHandshaker newFakeHandshakerServer() {
return serverHandshakerFactory.newHandshaker(null);
NoopChannelLogger channelLogger = new NoopChannelLogger();
return serverHandshakerFactory.newHandshaker(null, channelLogger);
}
protected FakeTsiHandshaker(boolean isClient) {