Add interceptor for per-rpc client auth (#6488)

Add a new gRPC server interceptor (both unary and streaming) which
verifies that the mTLS info set on the persistent connection has a
client cert which contains a name which is allowlisted for the
particular service being called, not just for the overall server.

This will allow us to make more services -- particularly the CA and the
SA -- more similar to the VA. We will be able to run multiple services
on the same port, while still being able to control access to those
services on a per-client basis. It will also let us split those services
(e.g. into read-only and read-write subsets) much more easily, because a
client will be able to switch which service it is calling without also
having to be reconfigured to call a different address. And finally, it
will allow us to simplify configuration for clients (such as the RA)
which maintain connections to multiple different services on the same
server, as they'll be able to re-use the same address configuration.
This commit is contained in:
Aaron Gable 2022-11-07 13:47:47 -08:00 committed by GitHub
parent 3cc84589a8
commit 257136779c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 286 additions and 74 deletions

View File

@ -28,7 +28,7 @@ import (
// It dials the remote service and returns a grpc.ClientConn if successful.
func ClientSetup(c *cmd.GRPCClientConfig, tlsConfig *tls.Config, statsRegistry prometheus.Registerer, clk clock.Clock, interceptors ...grpc.UnaryClientInterceptor) (*grpc.ClientConn, error) {
if c == nil {
return nil, errors.New("nil gRPC client config provided. JSON config is probably missing a fooService section.")
return nil, errors.New("nil gRPC client config provided: JSON config is probably missing a fooService section")
}
if tlsConfig == nil {
return nil, errNilTLS
@ -39,17 +39,17 @@ func ClientSetup(c *cmd.GRPCClientConfig, tlsConfig *tls.Config, statsRegistry p
return nil, err
}
ci := clientInterceptor{c.Timeout.Duration, metrics, clk}
cmi := clientMetadataInterceptor{c.Timeout.Duration, metrics, clk}
unaryInterceptors := append(interceptors, []grpc.UnaryClientInterceptor{
ci.interceptUnary,
ci.metrics.grpcMetrics.UnaryClientInterceptor(),
cmi.Unary,
cmi.metrics.grpcMetrics.UnaryClientInterceptor(),
hnygrpc.UnaryClientInterceptor(),
}...)
streamInterceptors := []grpc.StreamClientInterceptor{
ci.interceptStream,
ci.metrics.grpcMetrics.StreamClientInterceptor(),
cmi.Stream,
cmi.metrics.grpcMetrics.StreamClientInterceptor(),
// TODO(#6361): Get a tracing interceptor that works for gRPC streams.
}

View File

@ -31,11 +31,11 @@ func (s *errorServer) Chill(_ context.Context, _ *test_proto.Time) (*test_proto.
func TestErrorWrapping(t *testing.T) {
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerInterceptor(serverMetrics, clock.NewFake())
smi := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := clientInterceptor{time.Second, clientMetrics, clock.NewFake()}
srv := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary))
cmi := clientMetadataInterceptor{time.Second, clientMetrics, clock.NewFake()}
srv := grpc.NewServer(grpc.UnaryInterceptor(smi.Unary))
es := &errorServer{}
test_proto.RegisterChillerServer(srv, es)
lis, err := net.Listen("tcp", "127.0.0.1:")
@ -46,7 +46,7 @@ func TestErrorWrapping(t *testing.T) {
conn, err := grpc.Dial(
lis.Addr().String(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.interceptUnary),
grpc.WithUnaryInterceptor(cmi.Unary),
)
test.AssertNotError(t, err, "Failed to dial grpc test server")
client := test_proto.NewChillerClient(conn)
@ -74,11 +74,11 @@ func TestErrorWrapping(t *testing.T) {
func TestSubErrorWrapping(t *testing.T) {
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerInterceptor(serverMetrics, clock.NewFake())
smi := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := clientInterceptor{time.Second, clientMetrics, clock.NewFake()}
srv := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary))
cmi := clientMetadataInterceptor{time.Second, clientMetrics, clock.NewFake()}
srv := grpc.NewServer(grpc.UnaryInterceptor(smi.Unary))
es := &errorServer{}
test_proto.RegisterChillerServer(srv, es)
lis, err := net.Listen("tcp", "127.0.0.1:")
@ -89,7 +89,7 @@ func TestSubErrorWrapping(t *testing.T) {
conn, err := grpc.Dial(
lis.Addr().String(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.interceptUnary),
grpc.WithUnaryInterceptor(cmi.Unary),
)
test.AssertNotError(t, err, "Failed to dial grpc test server")
client := test_proto.NewChillerClient(conn)

View File

@ -11,9 +11,12 @@ import (
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/letsencrypt/boulder/cmd"
berrors "github.com/letsencrypt/boulder/errors"
"github.com/letsencrypt/boulder/probs"
)
@ -24,6 +27,33 @@ const (
clientRequestTimeKey = "client-request-time"
)
type serverInterceptor interface {
Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error)
Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error
}
// noopServerInterceptor provides no-op interceptors. It can be substituted for
// an interceptor that has been disabled.
type noopServerInterceptor struct{}
// Unary is a gRPC unary interceptor.
func (n *noopServerInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return handler(ctx, req)
}
// Stream is a gRPC stream interceptor.
func (n *noopServerInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return handler(srv, ss)
}
// Ensure noopServerInterceptor matches the serverInterceptor interface.
var _ serverInterceptor = &noopServerInterceptor{}
type clientInterceptor interface {
Unary(ctx context.Context, method string, req interface{}, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error
Stream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error)
}
// NoCancelInterceptor is a gRPC interceptor that creates a new context,
// separate from the original context, that has the same deadline but does
// not propagate cancellation. This is used by SA.
@ -42,23 +72,23 @@ func NoCancelInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryS
return handler(ctx, req)
}
// serverInterceptor is a gRPC interceptor that adds Prometheus
// serverMetadataInterceptor is a gRPC interceptor that adds Prometheus
// metrics to requests handled by a gRPC server, and wraps Boulder-specific
// errors for transmission in a grpc/metadata trailer (see bcodes.go).
type serverInterceptor struct {
type serverMetadataInterceptor struct {
metrics serverMetrics
clk clock.Clock
}
func newServerInterceptor(metrics serverMetrics, clk clock.Clock) serverInterceptor {
return serverInterceptor{
func newServerMetadataInterceptor(metrics serverMetrics, clk clock.Clock) serverMetadataInterceptor {
return serverMetadataInterceptor{
metrics: metrics,
clk: clk,
}
}
// interceptUnary implements the grpc.UnaryServerInterceptor interface.
func (si *serverInterceptor) interceptUnary(
// Unary implements the grpc.UnaryServerInterceptor interface.
func (smi *serverMetadataInterceptor) Unary(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
@ -71,7 +101,7 @@ func (si *serverInterceptor) interceptUnary(
// a `clientRequestTimeKey` field, and it has a value, then observe the RPC
// latency with Prometheus.
if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 {
err := si.observeLatency(md[clientRequestTimeKey][0])
err := smi.observeLatency(md[clientRequestTimeKey][0])
if err != nil {
return nil, err
}
@ -118,8 +148,8 @@ func (iss interceptedServerStream) Context() context.Context {
return iss.ctx
}
// interceptStream implements the grpc.StreamServerInterceptor interface.
func (si *serverInterceptor) interceptStream(
// Stream implements the grpc.StreamServerInterceptor interface.
func (smi *serverMetadataInterceptor) Stream(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
@ -130,7 +160,7 @@ func (si *serverInterceptor) interceptStream(
// a `clientRequestTimeKey` field, and it has a value, then observe the RPC
// latency with Prometheus.
if md, ok := metadata.FromIncomingContext(ctx); ok && len(md[clientRequestTimeKey]) > 0 {
err := si.observeLatency(md[clientRequestTimeKey][0])
err := smi.observeLatency(md[clientRequestTimeKey][0])
if err != nil {
return err
}
@ -184,7 +214,7 @@ func splitMethodName(fullMethodName string) (string, string) {
// used to calculate the latency between send and receive time. The latency is
// published to the server interceptor's rpcLag prometheus histogram. An error
// is returned if the `clientReqTime` string is not a valid timestamp.
func (si *serverInterceptor) observeLatency(clientReqTime string) error {
func (smi *serverMetadataInterceptor) observeLatency(clientReqTime string) error {
// Convert the metadata request time into an int64
reqTimeUnixNanos, err := strconv.ParseInt(clientReqTime, 10, 64)
if err != nil {
@ -193,27 +223,30 @@ func (si *serverInterceptor) observeLatency(clientReqTime string) error {
}
// Calculate the elapsed time since the client sent the RPC
reqTime := time.Unix(0, reqTimeUnixNanos)
elapsed := si.clk.Since(reqTime)
elapsed := smi.clk.Since(reqTime)
// Publish an RPC latency observation to the histogram
si.metrics.rpcLag.Observe(elapsed.Seconds())
smi.metrics.rpcLag.Observe(elapsed.Seconds())
return nil
}
// clientInterceptor is a gRPC interceptor that adds Prometheus
// Ensure serverMetadataInterceptor matches the serverInterceptor interface.
var _ serverInterceptor = (*serverMetadataInterceptor)(nil)
// clientMetadataInterceptor is a gRPC interceptor that adds Prometheus
// metrics to sent requests, and disables FailFast. We disable FailFast because
// non-FailFast mode is most similar to the old AMQP RPC layer: If a client
// makes a request while all backends are briefly down (e.g. for a restart), the
// request doesn't necessarily fail. A backend can service the request if it
// comes back up within the timeout. Under gRPC the same effect is achieved by
// retries up to the Context deadline.
type clientInterceptor struct {
type clientMetadataInterceptor struct {
timeout time.Duration
metrics clientMetrics
clk clock.Clock
}
// interceptUnary implements the grpc.UnaryClientInterceptor interface.
func (ci *clientInterceptor) interceptUnary(
// Unary implements the grpc.UnaryClientInterceptor interface.
func (cmi *clientMetadataInterceptor) Unary(
ctx context.Context,
fullMethod string,
req,
@ -223,16 +256,16 @@ func (ci *clientInterceptor) interceptUnary(
opts ...grpc.CallOption) error {
// This should not occur but fail fast with a clear error if it does (e.g.
// because of buggy unit test code) instead of a generic nil panic later!
if ci.metrics.inFlightRPCs == nil {
if cmi.metrics.inFlightRPCs == nil {
return berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge")
}
// Ensure that the context has a deadline set.
localCtx, cancel := context.WithTimeout(ctx, ci.timeout)
localCtx, cancel := context.WithTimeout(ctx, cmi.timeout)
defer cancel()
// Convert the current unix nano timestamp to a string for embedding in the grpc metadata
nowTS := strconv.FormatInt(ci.clk.Now().UnixNano(), 10)
nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10)
// Create a grpc/metadata.Metadata instance for the request metadata.
// Initialize it with the request time.
reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS})
@ -259,12 +292,12 @@ func (ci *clientInterceptor) interceptUnary(
"service": service,
}
// Increment the inFlightRPCs gauge for this method/service
ci.metrics.inFlightRPCs.With(labels).Inc()
cmi.metrics.inFlightRPCs.With(labels).Inc()
// And defer decrementing it when we're done
defer ci.metrics.inFlightRPCs.With(labels).Dec()
defer cmi.metrics.inFlightRPCs.With(labels).Dec()
// Handle the RPC
begin := ci.clk.Now()
begin := cmi.clk.Now()
err := invoker(localCtx, fullMethod, req, reply, cc, opts...)
if err != nil {
err = unwrapError(err, respMD)
@ -272,7 +305,7 @@ func (ci *clientInterceptor) interceptUnary(
return deadlineDetails{
service: service,
method: method,
latency: ci.clk.Since(begin),
latency: cmi.clk.Since(begin),
}
}
}
@ -322,8 +355,8 @@ func (ics interceptedClientStream) CloseSend() error {
return err
}
// interceptUnary implements the grpc.StreamClientInterceptor interface.
func (ci *clientInterceptor) interceptStream(
// Stream implements the grpc.StreamClientInterceptor interface.
func (cmi *clientMetadataInterceptor) Stream(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
@ -332,16 +365,16 @@ func (ci *clientInterceptor) interceptStream(
opts ...grpc.CallOption) (grpc.ClientStream, error) {
// This should not occur but fail fast with a clear error if it does (e.g.
// because of buggy unit test code) instead of a generic nil panic later!
if ci.metrics.inFlightRPCs == nil {
if cmi.metrics.inFlightRPCs == nil {
return nil, berrors.InternalServerError("clientInterceptor has nil inFlightRPCs gauge")
}
// We don't defer cancel() here, because this function is going to return
// immediately. Instead we store it in the interceptedClientStream.
localCtx, cancel := context.WithTimeout(ctx, ci.timeout)
localCtx, cancel := context.WithTimeout(ctx, cmi.timeout)
// Convert the current unix nano timestamp to a string for embedding in the grpc metadata
nowTS := strconv.FormatInt(ci.clk.Now().UnixNano(), 10)
nowTS := strconv.FormatInt(cmi.clk.Now().UnixNano(), 10)
// Create a grpc/metadata.Metadata instance for the request metadata.
// Initialize it with the request time.
reqMD := metadata.New(map[string]string{clientRequestTimeKey: nowTS})
@ -368,21 +401,21 @@ func (ci *clientInterceptor) interceptStream(
"service": service,
}
// Increment the inFlightRPCs gauge for this method/service
ci.metrics.inFlightRPCs.With(labels).Inc()
begin := ci.clk.Now()
cmi.metrics.inFlightRPCs.With(labels).Inc()
begin := cmi.clk.Now()
// Cancel the local context and decrement the metric when we're done. Also
// transform the error into a more usable form, if necessary.
finish := func(err error) error {
cancel()
ci.metrics.inFlightRPCs.With(labels).Dec()
cmi.metrics.inFlightRPCs.With(labels).Dec()
if err != nil {
err = unwrapError(err, respMD)
if status.Code(err) == codes.DeadlineExceeded {
return deadlineDetails{
service: service,
method: method,
latency: ci.clk.Since(begin),
latency: cmi.clk.Since(begin),
}
}
}
@ -395,6 +428,8 @@ func (ci *clientInterceptor) interceptStream(
return ics, err
}
var _ clientInterceptor = (*clientMetadataInterceptor)(nil)
// CancelTo408Interceptor calls the underlying invoker, checks to see if the
// resulting error was a gRPC Canceled error (because this client cancelled
// the request, likely because the ACME client itself canceled the HTTP
@ -421,3 +456,94 @@ func (dd deadlineDetails) Error() string {
return fmt.Sprintf("%s.%s timed out after %d ms",
dd.service, dd.method, int64(dd.latency/time.Millisecond))
}
// authInterceptor provides two server interceptors (Unary and Stream) which can
// check that every request for a given gRPC service is being made over an mTLS
// connection from a client which is allow-listed for that particular service.
type authInterceptor struct {
// serviceClientNames is a map of gRPC service names (e.g. "ca.CertificateAuthority")
// to allowed client certificate SANs (e.g. "ra.boulder") which are allowed to
// make RPCs to that service. The set of client names is implemented as a map
// of names to empty structs for easy lookup.
serviceClientNames map[string]map[string]struct{}
}
// newServiceAuthChecker takes a GRPCServerConfig and uses its Service stanzas
// to construct a serviceAuthChecker which enforces the service/client mappings
// contained in the config.
func newServiceAuthChecker(c *cmd.GRPCServerConfig) *authInterceptor {
names := make(map[string]map[string]struct{})
for serviceName, service := range c.Services {
names[serviceName] = make(map[string]struct{})
for _, clientName := range service.ClientNames {
names[serviceName][clientName] = struct{}{}
}
}
return &authInterceptor{names}
}
// Unary is a gRPC unary interceptor.
func (ac *authInterceptor) Unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
err := ac.checkContextAuth(ctx, info.FullMethod)
if err != nil {
return nil, err
}
return handler(ctx, req)
}
// Stream is a gRPC stream interceptor.
func (ac *authInterceptor) Stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
err := ac.checkContextAuth(ss.Context(), info.FullMethod)
if err != nil {
return err
}
return handler(srv, ss)
}
// checkContextAuth does most of the heavy lifting. It extracts TLS information
// from the incoming context, gets the set of DNS names contained in the client
// mTLS cert, and returns nil if at least one of those names appears in the set
// of allowed client names for given service (or if the set of allowed client
// names is empty).
func (ac *authInterceptor) checkContextAuth(ctx context.Context, fullMethod string) error {
serviceName, _ := splitMethodName(fullMethod)
allowedClientNames, ok := ac.serviceClientNames[serviceName]
if !ok || len(allowedClientNames) == 0 {
return fmt.Errorf("service %q has no allowed client names", serviceName)
}
p, ok := peer.FromContext(ctx)
if !ok {
return fmt.Errorf("unable to fetch peer info from grpc context")
}
if p.AuthInfo == nil {
return fmt.Errorf("grpc connection appears to be plaintext")
}
tlsAuth, ok := p.AuthInfo.(credentials.TLSInfo)
if !ok {
return fmt.Errorf("connection is not TLS authed")
}
if len(tlsAuth.State.VerifiedChains) == 0 || len(tlsAuth.State.VerifiedChains[0]) == 0 {
return fmt.Errorf("connection auth not verified")
}
cert := tlsAuth.State.VerifiedChains[0][0]
for _, clientName := range cert.DNSNames {
_, ok := allowedClientNames[clientName]
if ok {
return nil
}
}
return fmt.Errorf(
"client names %v are not authorized for service %q (%v)",
cert.DNSNames, serviceName, allowedClientNames)
}
// Ensure authInterceptor matches the serverInterceptor interface.
var _ serverInterceptor = (*authInterceptor)(nil)

View File

@ -2,6 +2,8 @@ package grpc
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"log"
@ -18,8 +20,10 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/letsencrypt/boulder/grpc/test_proto"
@ -52,37 +56,37 @@ func testInvoker(_ context.Context, method string, _, _ interface{}, _ *grpc.Cli
func TestServerInterceptor(t *testing.T) {
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerInterceptor(serverMetrics, clock.NewFake())
si := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
md := metadata.New(map[string]string{clientRequestTimeKey: "0"})
ctxWithMetadata := metadata.NewIncomingContext(context.Background(), md)
_, err = si.interceptUnary(context.Background(), nil, nil, testHandler)
_, err = si.Unary(context.Background(), nil, nil, testHandler)
test.AssertError(t, err, "si.intercept didn't fail with a context missing metadata")
_, err = si.interceptUnary(ctxWithMetadata, nil, nil, testHandler)
_, err = si.Unary(ctxWithMetadata, nil, nil, testHandler)
test.AssertError(t, err, "si.intercept didn't fail with a nil grpc.UnaryServerInfo")
_, err = si.interceptUnary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler)
_, err = si.Unary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler)
test.AssertNotError(t, err, "si.intercept failed with a non-nil grpc.UnaryServerInfo")
_, err = si.interceptUnary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler)
_, err = si.Unary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler)
test.AssertError(t, err, "si.intercept didn't fail when handler returned a error")
}
func TestClientInterceptor(t *testing.T) {
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := clientInterceptor{
ci := clientMetadataInterceptor{
timeout: time.Second,
metrics: clientMetrics,
clk: clock.NewFake(),
}
err = ci.interceptUnary(context.Background(), "-service-test", nil, nil, nil, testInvoker)
err = ci.Unary(context.Background(), "-service-test", nil, nil, nil, testInvoker)
test.AssertNotError(t, err, "ci.intercept failed with a non-nil grpc.UnaryServerInfo")
err = ci.interceptUnary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker)
err = ci.Unary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker)
test.AssertError(t, err, "ci.intercept didn't fail when handler returned a error")
}
@ -106,7 +110,7 @@ func TestCancelTo408Interceptor(t *testing.T) {
func TestFailFastFalse(t *testing.T) {
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientInterceptor{
ci := &clientMetadataInterceptor{
timeout: 100 * time.Millisecond,
metrics: clientMetrics,
clk: clock.NewFake(),
@ -114,7 +118,7 @@ func TestFailFastFalse(t *testing.T) {
conn, err := grpc.Dial("localhost:19876", // random, probably unused port
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.interceptUnary))
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
@ -161,8 +165,8 @@ func TestTimeouts(t *testing.T) {
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerInterceptor(serverMetrics, clock.NewFake())
s := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary))
si := newServerMetadataInterceptor(serverMetrics, clock.NewFake())
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, &testServer{})
go func() {
start := time.Now()
@ -176,14 +180,14 @@ func TestTimeouts(t *testing.T) {
// make client
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientInterceptor{
ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetrics,
clk: clock.NewFake(),
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.interceptUnary))
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
@ -225,8 +229,8 @@ func TestRequestTimeTagging(t *testing.T) {
// Create a new ChillerServer
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerInterceptor(serverMetrics, clk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary))
si := newServerMetadataInterceptor(serverMetrics, clk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, &testServer{})
// Chill until ill
go func() {
@ -241,14 +245,14 @@ func TestRequestTimeTagging(t *testing.T) {
// Dial the ChillerServer
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientInterceptor{
ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetrics,
clk: clk,
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.interceptUnary))
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
@ -314,8 +318,8 @@ func TestInFlightRPCStat(t *testing.T) {
serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating server metrics")
si := newServerInterceptor(serverMetrics, clk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.interceptUnary))
si := newServerMetadataInterceptor(serverMetrics, clk)
s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
test_proto.RegisterChillerServer(s, server)
// Chill until ill
go func() {
@ -330,14 +334,14 @@ func TestInFlightRPCStat(t *testing.T) {
// Dial the ChillerServer
clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
test.AssertNotError(t, err, "creating client metrics")
ci := &clientInterceptor{
ci := &clientMetadataInterceptor{
timeout: 30 * time.Second,
metrics: clientMetrics,
clk: clk,
}
conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(ci.interceptUnary))
grpc.WithUnaryInterceptor(ci.Unary))
if err != nil {
t.Fatalf("did not connect: %v", err)
}
@ -395,3 +399,76 @@ func TestNoCancelInterceptor(t *testing.T) {
t.Error(err)
}
}
func TestServiceAuthChecker(t *testing.T) {
ac := authInterceptor{
map[string]map[string]struct{}{
"package.ServiceName": {
"allowed.client": {},
"also.allowed": {},
},
},
}
// No allowlist is a bad configuration.
ctx := context.Background()
err := ac.checkContextAuth(ctx, "/package.OtherService/Method/")
test.AssertError(t, err, "checking empty allowlist")
// Context with no peering information is disallowed.
err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
test.AssertError(t, err, "checking un-peered context")
// Context with no auth info is disallowed.
ctx = peer.NewContext(ctx, &peer.Peer{})
err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
test.AssertError(t, err, "checking peer with no auth")
// Context with no verified chains is disallowed.
ctx = peer.NewContext(ctx, &peer.Peer{
AuthInfo: credentials.TLSInfo{
State: tls.ConnectionState{},
},
})
err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
test.AssertError(t, err, "checking TLS with no valid chains")
// Context with cert with wrong name is disallowed.
ctx = peer.NewContext(ctx, &peer.Peer{
AuthInfo: credentials.TLSInfo{
State: tls.ConnectionState{
VerifiedChains: [][]*x509.Certificate{
{
&x509.Certificate{
DNSNames: []string{
"disallowed.client",
},
},
},
},
},
},
})
err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
test.AssertError(t, err, "checking disallowed cert")
// Context with cert with good name is allowed.
ctx = peer.NewContext(ctx, &peer.Peer{
AuthInfo: credentials.TLSInfo{
State: tls.ConnectionState{
VerifiedChains: [][]*x509.Certificate{
{
&x509.Certificate{
DNSNames: []string{
"disallowed.client",
"also.allowed",
},
},
},
},
},
},
})
err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
test.AssertNotError(t, err, "checking allowed cert")
}

View File

@ -118,17 +118,26 @@ func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.R
return nil, nil, err
}
si := newServerInterceptor(metrics, clk)
var ai serverInterceptor
if len(sb.cfg.Services) > 0 {
ai = newServiceAuthChecker(sb.cfg)
} else {
ai = &noopServerInterceptor{}
}
mi := newServerMetadataInterceptor(metrics, clk)
unaryInterceptors := append([]grpc.UnaryServerInterceptor{
si.interceptUnary,
si.metrics.grpcMetrics.UnaryServerInterceptor(),
mi.metrics.grpcMetrics.UnaryServerInterceptor(),
ai.Unary,
mi.Unary,
hnygrpc.UnaryServerInterceptor(),
}, interceptors...)
streamInterceptors := []grpc.StreamServerInterceptor{
si.interceptStream,
si.metrics.grpcMetrics.StreamServerInterceptor(),
mi.metrics.grpcMetrics.StreamServerInterceptor(),
ai.Stream,
mi.Stream,
// TODO(#6361): Get a tracing interceptor that works for gRPC streams.
}