diff --git a/grpc/client.go b/grpc/client.go index 5c245b57e..13314f281 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -28,7 +28,7 @@ import ( // It dials the remote service and returns a grpc.ClientConn if successful. func ClientSetup(c *cmd.GRPCClientConfig, tlsConfig *tls.Config, statsRegistry prometheus.Registerer, clk clock.Clock, interceptors ...grpc.UnaryClientInterceptor) (*grpc.ClientConn, error) { if c == nil { - return nil, errors.New("nil gRPC client config provided. JSON config is probably missing a fooService section.") + return nil, errors.New("nil gRPC client config provided: JSON config is probably missing a fooService section") } if tlsConfig == nil { return nil, errNilTLS @@ -39,17 +39,17 @@ func ClientSetup(c *cmd.GRPCClientConfig, tlsConfig *tls.Config, statsRegistry p return nil, err } - ci := clientInterceptor{c.Timeout.Duration, metrics, clk} + cmi := clientMetadataInterceptor{c.Timeout.Duration, metrics, clk} unaryInterceptors := append(interceptors, []grpc.UnaryClientInterceptor{ - ci.interceptUnary, - ci.metrics.grpcMetrics.UnaryClientInterceptor(), + cmi.Unary, + cmi.metrics.grpcMetrics.UnaryClientInterceptor(), hnygrpc.UnaryClientInterceptor(), }...) streamInterceptors := []grpc.StreamClientInterceptor{ - ci.interceptStream, - ci.metrics.grpcMetrics.StreamClientInterceptor(), + cmi.Stream, + cmi.metrics.grpcMetrics.StreamClientInterceptor(), // TODO(#6361): Get a tracing interceptor that works for gRPC streams. } diff --git a/grpc/errors_test.go b/grpc/errors_test.go index 15494bc22..2bc987472 100644 --- a/grpc/errors_test.go +++ b/grpc/errors_test.go @@ -31,11 +31,11 @@ func (s *errorServer) Chill(_ context.Context, _ *test_proto.Time) (*test_proto. 
func TestErrorWrapping(t *testing.T) { serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating server metrics") - si := newServerInterceptor(serverMetrics, clock.NewFake()) + smi := newServerMetadataInterceptor(serverMetrics, clock.NewFake()) clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := clientInterceptor{time.Second, clientMetrics, clock.NewFake()} - srv := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary)) + cmi := clientMetadataInterceptor{time.Second, clientMetrics, clock.NewFake()} + srv := grpc.NewServer(grpc.UnaryInterceptor(smi.Unary)) es := &errorServer{} test_proto.RegisterChillerServer(srv, es) lis, err := net.Listen("tcp", "127.0.0.1:") @@ -46,7 +46,7 @@ func TestErrorWrapping(t *testing.T) { conn, err := grpc.Dial( lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.interceptUnary), + grpc.WithUnaryInterceptor(cmi.Unary), ) test.AssertNotError(t, err, "Failed to dial grpc test server") client := test_proto.NewChillerClient(conn) @@ -74,11 +74,11 @@ func TestErrorWrapping(t *testing.T) { func TestSubErrorWrapping(t *testing.T) { serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating server metrics") - si := newServerInterceptor(serverMetrics, clock.NewFake()) + smi := newServerMetadataInterceptor(serverMetrics, clock.NewFake()) clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := clientInterceptor{time.Second, clientMetrics, clock.NewFake()} - srv := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary)) + cmi := clientMetadataInterceptor{time.Second, clientMetrics, clock.NewFake()} + srv := grpc.NewServer(grpc.UnaryInterceptor(smi.Unary)) es := &errorServer{} test_proto.RegisterChillerServer(srv, es) lis, err := net.Listen("tcp", "127.0.0.1:") @@ 
-89,7 +89,7 @@ func TestSubErrorWrapping(t *testing.T) { conn, err := grpc.Dial( lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.interceptUnary), + grpc.WithUnaryInterceptor(cmi.Unary), ) test.AssertNotError(t, err, "Failed to dial grpc test server") client := test_proto.NewChillerClient(conn) diff --git a/grpc/interceptors.go b/grpc/interceptors.go index 5835cd7a8..091536faf 100644 --- a/grpc/interceptors.go +++ b/grpc/interceptors.go @@ -11,9 +11,12 @@ import ( "github.com/prometheus/client_golang/prometheus" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + "github.com/letsencrypt/boulder/cmd" berrors "github.com/letsencrypt/boulder/errors" "github.com/letsencrypt/boulder/probs" ) @@ -24,6 +27,33 @@ const ( clientRequestTimeKey = "client-request-time" ) +type serverInterceptor interface { + Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) + Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error +} + +// noopServerInterceptor provides no-op interceptors. It can be substituted for +// an interceptor that has been disabled. +type noopServerInterceptor struct{} + +// Unary is a gRPC unary interceptor. +func (n *noopServerInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return handler(ctx, req) +} + +// Stream is a gRPC stream interceptor. +func (n *noopServerInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return handler(srv, ss) +} + +// Ensure noopServerInterceptor matches the serverInterceptor interface. 
+var _ serverInterceptor = &noopServerInterceptor{} + +type clientInterceptor interface { + Unary(ctx context.Context, method string, req interface{}, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error + Stream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) +} + // NoCancelInterceptor is a gRPC interceptor that creates a new context, // separate from the original context, that has the same deadline but does // not propagate cancellation. This is used by SA. @@ -42,23 +72,23 @@ func NoCancelInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryS return handler(ctx, req) } -// serverInterceptor is a gRPC interceptor that adds Prometheus +// serverMetadataInterceptor is a gRPC interceptor that adds Prometheus // metrics to requests handled by a gRPC server, and wraps Boulder-specific // errors for transmission in a grpc/metadata trailer (see bcodes.go). -type serverInterceptor struct { +type serverMetadataInterceptor struct { metrics serverMetrics clk clock.Clock } -func newServerInterceptor(metrics serverMetrics, clk clock.Clock) serverInterceptor { - return serverInterceptor{ +func newServerMetadataInterceptor(metrics serverMetrics, clk clock.Clock) serverMetadataInterceptor { + return serverMetadataInterceptor{ metrics: metrics, clk: clk, } } -// interceptUnary implements the grpc.UnaryServerInterceptor interface. -func (si *serverInterceptor) interceptUnary( +// Unary implements the grpc.UnaryServerInterceptor interface. +func (smi *serverMetadataInterceptor) Unary( ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, @@ -71,7 +101,7 @@ func (si *serverInterceptor) interceptUnary( // a `clientRequestTimeKey` field, and it has a value, then observe the RPC // latency with Prometheus. 
if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 { - err := si.observeLatency(md[clientRequestTimeKey][0]) + err := smi.observeLatency(md[clientRequestTimeKey][0]) if err != nil { return nil, err } @@ -118,8 +148,8 @@ func (iss interceptedServerStream) Context() context.Context { return iss.ctx } -// interceptStream implements the grpc.StreamServerInterceptor interface. -func (si *serverInterceptor) interceptStream( +// Stream implements the grpc.StreamServerInterceptor interface. +func (smi *serverMetadataInterceptor) Stream( srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, @@ -130,7 +160,7 @@ func (si *serverInterceptor) interceptStream( // a `clientRequestTimeKey` field, and it has a value, then observe the RPC // latency with Prometheus. if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 { - err := si.observeLatency(md[clientRequestTimeKey][0]) + err := smi.observeLatency(md[clientRequestTimeKey][0]) if err != nil { return err } @@ -184,7 +214,7 @@ func splitMethodName(fullMethodName string) (string, string) { // used to calculate the latency between send and receive time. The latency is // published to the server interceptor's rpcLag prometheus histogram. An error // is returned if the `clientReqTime` string is not a valid timestamp. 
-func (si *serverInterceptor) observeLatency(clientReqTime string) error { +func (smi *serverMetadataInterceptor) observeLatency(clientReqTime string) error { // Convert the metadata request time into an int64 reqTimeUnixNanos, err := strconv.ParseInt(clientReqTime, 10, 64) if err != nil { @@ -193,27 +223,30 @@ func (si *serverInterceptor) observeLatency(clientReqTime string) error { } // Calculate the elapsed time since the client sent the RPC reqTime := time.Unix(0, reqTimeUnixNanos) - elapsed := si.clk.Since(reqTime) + elapsed := smi.clk.Since(reqTime) // Publish an RPC latency observation to the histogram - si.metrics.rpcLag.Observe(elapsed.Seconds()) + smi.metrics.rpcLag.Observe(elapsed.Seconds()) return nil } -// clientInterceptor is a gRPC interceptor that adds Prometheus +// Ensure serverMetadataInterceptor matches the serverInterceptor interface. +var _ serverInterceptor = (*serverMetadataInterceptor)(nil) + +// clientMetadataInterceptor is a gRPC interceptor that adds Prometheus // metrics to sent requests, and disables FailFast. We disable FailFast because // non-FailFast mode is most similar to the old AMQP RPC layer: If a client // makes a request while all backends are briefly down (e.g. for a restart), the // request doesn't necessarily fail. A backend can service the request if it // comes back up within the timeout. Under gRPC the same effect is achieved by // retries up to the Context deadline. -type clientInterceptor struct { +type clientMetadataInterceptor struct { timeout time.Duration metrics clientMetrics clk clock.Clock } -// interceptUnary implements the grpc.UnaryClientInterceptor interface. -func (ci *clientInterceptor) interceptUnary( +// Unary implements the grpc.UnaryClientInterceptor interface. 
+func (cmi *clientMetadataInterceptor) Unary( ctx context.Context, fullMethod string, req, @@ -223,16 +256,16 @@ func (ci *clientInterceptor) interceptUnary( opts ...grpc.CallOption) error { // This should not occur but fail fast with a clear error if it does (e.g. // because of buggy unit test code) instead of a generic nil panic later! - if ci.metrics.inFlightRPCs == nil { + if cmi.metrics.inFlightRPCs == nil { return berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge") } // Ensure that the context has a deadline set. - localCtx, cancel := context.WithTimeout(ctx, ci.timeout) + localCtx, cancel := context.WithTimeout(ctx, cmi.timeout) defer cancel() // Convert the current unix nano timestamp to a string for embedding in the grpc metadata - nowTS := strconv.FormatInt(ci.clk.Now().UnixNano(), 10) + nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10) // Create a grpc/metadata.Metadata instance for the request metadata. // Initialize it with the request time. reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS}) @@ -259,12 +292,12 @@ func (ci *clientInterceptor) interceptUnary( "service": service, } // Increment the inFlightRPCs gauge for this method/service - ci.metrics.inFlightRPCs.With(labels).Inc() + cmi.metrics.inFlightRPCs.With(labels).Inc() // And defer decrementing it when we're done - defer ci.metrics.inFlightRPCs.With(labels).Dec() + defer cmi.metrics.inFlightRPCs.With(labels).Dec() // Handle the RPC - begin := ci.clk.Now() + begin := cmi.clk.Now() err := invoker(localCtx, fullMethod, req, reply, cc, opts...) if err != nil { err = unwrapError(err, respMD) @@ -272,7 +305,7 @@ func (ci *clientInterceptor) interceptUnary( return deadlineDetails{ service: service, method: method, - latency: ci.clk.Since(begin), + latency: cmi.clk.Since(begin), } } } @@ -322,8 +355,8 @@ func (ics interceptedClientStream) CloseSend() error { return err } -// interceptUnary implements the grpc.StreamClientInterceptor interface. 
-func (ci *clientInterceptor) interceptStream( +// Stream implements the grpc.StreamClientInterceptor interface. +func (cmi *clientMetadataInterceptor) Stream( ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, @@ -332,16 +365,16 @@ func (ci *clientInterceptor) interceptStream( opts ...grpc.CallOption) (grpc.ClientStream, error) { // This should not occur but fail fast with a clear error if it does (e.g. // because of buggy unit test code) instead of a generic nil panic later! - if ci.metrics.inFlightRPCs == nil { + if cmi.metrics.inFlightRPCs == nil { return nil, berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge") } // We don't defer cancel() here, because this function is going to return // immediately. Instead we store it in the interceptedClientStream. - localCtx, cancel := context.WithTimeout(ctx, ci.timeout) + localCtx, cancel := context.WithTimeout(ctx, cmi.timeout) // Convert the current unix nano timestamp to a string for embedding in the grpc metadata - nowTS := strconv.FormatInt(ci.clk.Now().UnixNano(), 10) + nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10) // Create a grpc/metadata.Metadata instance for the request metadata. // Initialize it with the request time. reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS}) @@ -368,21 +401,21 @@ func (ci *clientInterceptor) interceptStream( "service": service, } // Increment the inFlightRPCs gauge for this method/service - ci.metrics.inFlightRPCs.With(labels).Inc() - begin := ci.clk.Now() + cmi.metrics.inFlightRPCs.With(labels).Inc() + begin := cmi.clk.Now() // Cancel the local context and decrement the metric when we're done. Also // transform the error into a more usable form, if necessary. 
finish := func(err error) error { cancel() - ci.metrics.inFlightRPCs.With(labels).Dec() + cmi.metrics.inFlightRPCs.With(labels).Dec() if err != nil { err = unwrapError(err, respMD) if status.Code(err) == codes.DeadlineExceeded { return deadlineDetails{ service: service, method: method, - latency: ci.clk.Since(begin), + latency: cmi.clk.Since(begin), } } } @@ -395,6 +428,8 @@ func (ci *clientInterceptor) interceptStream( return ics, err } +var _ clientInterceptor = (*clientMetadataInterceptor)(nil) + // CancelTo408Interceptor calls the underlying invoker, checks to see if the // resulting error was a gRPC Canceled error (because this client cancelled // the request, likely because the ACME client itself canceled the HTTP @@ -421,3 +456,94 @@ func (dd deadlineDetails) Error() string { return fmt.Sprintf("%s.%s timed out after %d ms", dd.service, dd.method, int64(dd.latency/time.Millisecond)) } + +// authInterceptor provides two server interceptors (Unary and Stream) which can +// check that every request for a given gRPC service is being made over an mTLS +// connection from a client which is allow-listed for that particular service. +type authInterceptor struct { + // serviceClientNames is a map of gRPC service names (e.g. "ca.CertificateAuthority") + // to allowed client certificate SANs (e.g. "ra.boulder") which are allowed to + // make RPCs to that service. The set of client names is implemented as a map + // of names to empty structs for easy lookup. + serviceClientNames map[string]map[string]struct{} +} + +// newServiceAuthChecker takes a GRPCServerConfig and uses its Service stanzas +// to construct an authInterceptor which enforces the service/client mappings +// contained in the config. 
+func newServiceAuthChecker(c *cmd.GRPCServerConfig) *authInterceptor { + names := make(map[string]map[string]struct{}) + for serviceName, service := range c.Services { + names[serviceName] = make(map[string]struct{}) + for _, clientName := range service.ClientNames { + names[serviceName][clientName] = struct{}{} + } + } + return &authInterceptor{names} +} + +// Unary is a gRPC unary interceptor. +func (ac *authInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + err := ac.checkContextAuth(ctx, info.FullMethod) + if err != nil { + return nil, err + } + return handler(ctx, req) +} + +// Stream is a gRPC stream interceptor. +func (ac *authInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + err := ac.checkContextAuth(ss.Context(), info.FullMethod) + if err != nil { + return err + } + return handler(srv, ss) +} + +// checkContextAuth does most of the heavy lifting. It extracts TLS information +// from the incoming context, gets the set of DNS names contained in the client +// mTLS cert, and returns nil if at least one of those names appears in the set +// of allowed client names for given service (or if the set of allowed client +// names is empty). 
+func (ac *authInterceptor) checkContextAuth(ctx context.Context, fullMethod string) error { + serviceName, _ := splitMethodName(fullMethod) + + allowedClientNames, ok := ac.serviceClientNames[serviceName] + if !ok || len(allowedClientNames) == 0 { + return fmt.Errorf("service %q has no allowed client names", serviceName) + } + + p, ok := peer.FromContext(ctx) + if !ok { + return fmt.Errorf("unable to fetch peer info from grpc context") + } + + if p.AuthInfo == nil { + return fmt.Errorf("grpc connection appears to be plaintext") + } + + tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo) + if !ok { + return fmt.Errorf("connection is not TLS authed") + } + + if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 { + return fmt.Errorf("connection auth not verified") + } + + cert := tlsAuth.State.VerifiedChains[0][0] + + for _, clientName := range cert.DNSNames { + _, ok := allowedClientNames[clientName] + if ok { + return nil + } + } + + return fmt.Errorf( + "client names %v are not authorized for service %q (%v)", + cert.DNSNames, serviceName, allowedClientNames) +} + +// Ensure authInterceptor matches the serverInterceptor interface. 
+var _ serverInterceptor = (*authInterceptor)(nil) diff --git a/grpc/interceptors_test.go b/grpc/interceptors_test.go index 3fbc25c8d..4cf103356 100644 --- a/grpc/interceptors_test.go +++ b/grpc/interceptors_test.go @@ -2,6 +2,8 @@ package grpc import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "log" @@ -18,8 +20,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "github.com/letsencrypt/boulder/grpc/test_proto" @@ -52,37 +56,37 @@ func testInvoker(_ context.Context, method string, _, _ interface{}, _ *grpc.Cli func TestServerInterceptor(t *testing.T) { serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating server metrics") - si := newServerInterceptor(serverMetrics, clock.NewFake()) + si := newServerMetadataInterceptor(serverMetrics, clock.NewFake()) md := metadata.New(map[string]string{clientRequestTimeKey: "0"}) ctxWithMetadata := metadata.NewIncomingContext(context.Background(), md) - _, err = si.interceptUnary(context.Background(), nil, nil, testHandler) + _, err = si.Unary(context.Background(), nil, nil, testHandler) test.AssertError(t, err, "si.intercept didn't fail with a context missing metadata") - _, err = si.interceptUnary(ctxWithMetadata, nil, nil, testHandler) + _, err = si.Unary(ctxWithMetadata, nil, nil, testHandler) test.AssertError(t, err, "si.intercept didn't fail with a nil grpc.UnaryServerInfo") - _, err = si.interceptUnary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler) + _, err = si.Unary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler) test.AssertNotError(t, err, "si.intercept failed with a non-nil grpc.UnaryServerInfo") - _, err = 
si.interceptUnary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler) + _, err = si.Unary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler) test.AssertError(t, err, "si.intercept didn't fail when handler returned a error") } func TestClientInterceptor(t *testing.T) { clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := clientInterceptor{ + ci := clientMetadataInterceptor{ timeout: time.Second, metrics: clientMetrics, clk: clock.NewFake(), } - err = ci.interceptUnary(context.Background(), "-service-test", nil, nil, nil, testInvoker) + err = ci.Unary(context.Background(), "-service-test", nil, nil, nil, testInvoker) test.AssertNotError(t, err, "ci.intercept failed with a non-nil grpc.UnaryServerInfo") - err = ci.interceptUnary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker) + err = ci.Unary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker) test.AssertError(t, err, "ci.intercept didn't fail when handler returned a error") } @@ -106,7 +110,7 @@ func TestCancelTo408Interceptor(t *testing.T) { func TestFailFastFalse(t *testing.T) { clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := &clientInterceptor{ + ci := &clientMetadataInterceptor{ timeout: 100 * time.Millisecond, metrics: clientMetrics, clk: clock.NewFake(), @@ -114,7 +118,7 @@ func TestFailFastFalse(t *testing.T) { conn, err := grpc.Dial("localhost:19876", // random, probably unused port grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.interceptUnary)) + grpc.WithUnaryInterceptor(ci.Unary)) if err != nil { t.Fatalf("did not connect: %v", err) } @@ -161,8 +165,8 @@ func TestTimeouts(t *testing.T) { serverMetrics, err := 
newServerMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating server metrics") - si := newServerInterceptor(serverMetrics, clock.NewFake()) - s := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary)) + si := newServerMetadataInterceptor(serverMetrics, clock.NewFake()) + s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) test_proto.RegisterChillerServer(s, &testServer{}) go func() { start := time.Now() @@ -176,14 +180,14 @@ func TestTimeouts(t *testing.T) { // make client clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := &clientInterceptor{ + ci := &clientMetadataInterceptor{ timeout: 30 * time.Second, metrics: clientMetrics, clk: clock.NewFake(), } conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.interceptUnary)) + grpc.WithUnaryInterceptor(ci.Unary)) if err != nil { t.Fatalf("did not connect: %v", err) } @@ -225,8 +229,8 @@ func TestRequestTimeTagging(t *testing.T) { // Create a new ChillerServer serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating server metrics") - si := newServerInterceptor(serverMetrics, clk) - s := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary)) + si := newServerMetadataInterceptor(serverMetrics, clk) + s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) test_proto.RegisterChillerServer(s, &testServer{}) // Chill until ill go func() { @@ -241,14 +245,14 @@ func TestRequestTimeTagging(t *testing.T) { // Dial the ChillerServer clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := &clientInterceptor{ + ci := &clientMetadataInterceptor{ timeout: 30 * time.Second, metrics: clientMetrics, clk: clk, } conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), 
grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.interceptUnary)) + grpc.WithUnaryInterceptor(ci.Unary)) if err != nil { t.Fatalf("did not connect: %v", err) } @@ -314,8 +318,8 @@ func TestInFlightRPCStat(t *testing.T) { serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating server metrics") - si := newServerInterceptor(serverMetrics, clk) - s := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary)) + si := newServerMetadataInterceptor(serverMetrics, clk) + s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) test_proto.RegisterChillerServer(s, server) // Chill until ill go func() { @@ -330,14 +334,14 @@ func TestInFlightRPCStat(t *testing.T) { // Dial the ChillerServer clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := &clientInterceptor{ + ci := &clientMetadataInterceptor{ timeout: 30 * time.Second, metrics: clientMetrics, clk: clk, } conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.interceptUnary)) + grpc.WithUnaryInterceptor(ci.Unary)) if err != nil { t.Fatalf("did not connect: %v", err) } @@ -395,3 +399,76 @@ func TestNoCancelInterceptor(t *testing.T) { t.Error(err) } } + +func TestServiceAuthChecker(t *testing.T) { + ac := authInterceptor{ + map[string]map[string]struct{}{ + "package.ServiceName": { + "allowed.client": {}, + "also.allowed": {}, + }, + }, + } + + // No allowlist is a bad configuration. + ctx := context.Background() + err := ac.checkContextAuth(ctx, "/package.OtherService/Method/") + test.AssertError(t, err, "checking empty allowlist") + + // Context with no peering information is disallowed. + err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") + test.AssertError(t, err, "checking un-peered context") + + // Context with no auth info is disallowed. 
+ ctx = peer.NewContext(ctx, &peer.Peer{}) + err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") + test.AssertError(t, err, "checking peer with no auth") + + // Context with no verified chains is disallowed. + ctx = peer.NewContext(ctx, &peer.Peer{ + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{}, + }, + }) + err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") + test.AssertError(t, err, "checking TLS with no valid chains") + + // Context with cert with wrong name is disallowed. + ctx = peer.NewContext(ctx, &peer.Peer{ + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + VerifiedChains: [][]*x509.Certificate{ + { + &x509.Certificate{ + DNSNames: []string{ + "disallowed.client", + }, + }, + }, + }, + }, + }, + }) + err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") + test.AssertError(t, err, "checking disallowed cert") + + // Context with cert with good name is allowed. + ctx = peer.NewContext(ctx, &peer.Peer{ + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + VerifiedChains: [][]*x509.Certificate{ + { + &x509.Certificate{ + DNSNames: []string{ + "disallowed.client", + "also.allowed", + }, + }, + }, + }, + }, + }, + }) + err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") + test.AssertNotError(t, err, "checking allowed cert") +} diff --git a/grpc/server.go b/grpc/server.go index d9fec6f42..a8b639d84 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -118,17 +118,26 @@ func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.R return nil, nil, err } - si := newServerInterceptor(metrics, clk) + var ai serverInterceptor + if len(sb.cfg.Services) > 0 { + ai = newServiceAuthChecker(sb.cfg) + } else { + ai = &noopServerInterceptor{} + } + + mi := newServerMetadataInterceptor(metrics, clk) unaryInterceptors := append([]grpc.UnaryServerInterceptor{ - si.interceptUnary, - si.metrics.grpcMetrics.UnaryServerInterceptor(), + 
mi.metrics.grpcMetrics.UnaryServerInterceptor(), + ai.Unary, + mi.Unary, hnygrpc.UnaryServerInterceptor(), }, interceptors...) streamInterceptors := []grpc.StreamServerInterceptor{ - si.interceptStream, - si.metrics.grpcMetrics.StreamServerInterceptor(), + mi.metrics.grpcMetrics.StreamServerInterceptor(), + ai.Stream, + mi.Stream, // TODO(#6361): Get a tracing interceptor that works for gRPC streams. }