package grpc

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strings"
	"time"

	grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
	"github.com/jmhodges/clock"
	"github.com/prometheus/client_golang/prometheus"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/filters"
	"google.golang.org/grpc"
	"google.golang.org/grpc/health"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/status"

	"github.com/letsencrypt/boulder/cmd"
	bcreds "github.com/letsencrypt/boulder/grpc/creds"
	blog "github.com/letsencrypt/boulder/log"
)

// CodedError is an alias required to appease go vet
var CodedError = status.Errorf

var errNilTLS = errors.New("boulder/grpc: received nil tls.Config")

// checker is an interface for checking the health of a gRPC service
// implementation.
type checker interface {
	// Health returns nil if the service is healthy, or an error if it is not.
	// If the passed context is canceled, it should return immediately with an
	// error.
	Health(context.Context) error
}
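
// As an illustrative sketch (not part of this package), a service backed by a
// database could satisfy checker by pinging its connection; dbService and its
// db field are hypothetical names:
//
//	type dbService struct {
//		db *sql.DB
//	}
//
//	func (s *dbService) Health(ctx context.Context) error {
//		// PingContext returns promptly with an error if ctx is canceled.
//		return s.db.PingContext(ctx)
//	}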

// service represents a single gRPC service that can be registered with a gRPC
// server.
type service struct {
	desc *grpc.ServiceDesc
	impl any
}

// serverBuilder implements a builder pattern for constructing new gRPC servers
// and registering gRPC services on those servers.
type serverBuilder struct {
	cfg           *cmd.GRPCServerConfig
	services      map[string]service
	healthSrv     *health.Server
	checkInterval time.Duration
	logger        blog.Logger
	err           error
}

// NewServer returns an object which can be used to build gRPC servers. It takes
// the server's configuration to perform initialization and a logger for deep
// health checks.
func NewServer(c *cmd.GRPCServerConfig, logger blog.Logger) *serverBuilder {
	return &serverBuilder{cfg: c, services: make(map[string]service), logger: logger}
}

// WithCheckInterval sets the interval at which the server will check the health
// of its registered services. If this is not called, a default interval of 5
// seconds will be used.
func (sb *serverBuilder) WithCheckInterval(i time.Duration) *serverBuilder {
	sb.checkInterval = i
	return sb
}

// Add registers a new service (consisting of its description and its
// implementation) to the set of services which will be exposed by this server.
// It returns the modified-in-place serverBuilder so that calls can be chained.
// If there is an error adding this service, it will be exposed when .Build() is
// called.
func (sb *serverBuilder) Add(desc *grpc.ServiceDesc, impl any) *serverBuilder {
	if _, found := sb.services[desc.ServiceName]; found {
		// We've already registered a service with this same name, error out.
		sb.err = fmt.Errorf("attempted double-registration of gRPC service %q", desc.ServiceName)
		return sb
	}
	sb.services[desc.ServiceName] = service{desc: desc, impl: impl}
	return sb
}

// Build creates a gRPC server that uses the provided *tls.Config and exposes
// all of the services added to the builder. It also exposes a health check
// service. It returns one function, start(), which should be used to start
// the server. It spawns a goroutine which will listen for OS signals and
// gracefully stop the server if one is caught, causing the start() function to
// exit.
func (sb *serverBuilder) Build(tlsConfig *tls.Config, statsRegistry prometheus.Registerer, clk clock.Clock) (func() error, error) {
	// Register the health service with the server.
	sb.healthSrv = health.NewServer()
	sb.Add(&healthpb.Health_ServiceDesc, sb.healthSrv)

	// Check to see if any of the calls to .Add() resulted in an error.
	if sb.err != nil {
		return nil, sb.err
	}

	// Ensure that every configured service also got added.
	var registeredServices []string
	for r := range sb.services {
		registeredServices = append(registeredServices, r)
	}
	for serviceName := range sb.cfg.Services {
		_, ok := sb.services[serviceName]
		if !ok {
			return nil, fmt.Errorf("gRPC service %q in config does not match any service: %s", serviceName, strings.Join(registeredServices, ", "))
		}
	}

	if tlsConfig == nil {
		return nil, errNilTLS
	}

	// Collect all names which should be allowed to connect to the server at
	// all: the union of all names which are allowlisted for any individual
	// service.
	acceptedSANs := make(map[string]struct{})
	for _, service := range sb.cfg.Services {
		for _, name := range service.ClientNames {
			acceptedSANs[name] = struct{}{}
		}
	}

	creds, err := bcreds.NewServerCredentials(tlsConfig, acceptedSANs)
	if err != nil {
		return nil, err
	}

	// Set up all of our interceptors which handle metrics, traces, error
	// propagation, and more.
	metrics, err := newServerMetrics(statsRegistry)
	if err != nil {
		return nil, err
	}

	var ai serverInterceptor
	if len(sb.cfg.Services) > 0 {
		ai = newServiceAuthChecker(sb.cfg)
	} else {
		ai = &noopServerInterceptor{}
	}

	mi := newServerMetadataInterceptor(metrics, clk)

	unaryInterceptors := []grpc.UnaryServerInterceptor{
		mi.metrics.grpcMetrics.UnaryServerInterceptor(),
		ai.Unary,
		mi.Unary,
	}

	streamInterceptors := []grpc.StreamServerInterceptor{
		mi.metrics.grpcMetrics.StreamServerInterceptor(),
		ai.Stream,
		mi.Stream,
	}

	options := []grpc.ServerOption{
		grpc.Creds(creds),
		grpc.ChainUnaryInterceptor(unaryInterceptors...),
		grpc.ChainStreamInterceptor(streamInterceptors...),
		grpc.StatsHandler(otelgrpc.NewServerHandler(otelgrpc.WithFilter(filters.Not(filters.HealthCheck())))),
	}
	if sb.cfg.MaxConnectionAge.Duration > 0 {
		options = append(options,
			grpc.KeepaliveParams(keepalive.ServerParameters{
				MaxConnectionAge: sb.cfg.MaxConnectionAge.Duration,
			}))
	}

	// Create the server itself and register all of our services on it.
	server := grpc.NewServer(options...)
	for _, service := range sb.services {
		server.RegisterService(service.desc, service.impl)
	}

	if sb.cfg.Address == "" {
		return nil, errors.New("GRPC listen address not configured")
	}
	sb.logger.Infof("grpc listening on %s", sb.cfg.Address)

	// Create the listener and the start() function which will be returned to
	// the caller.
	listener, err := net.Listen("tcp", sb.cfg.Address)
	if err != nil {
		return nil, err
	}

	start := func() error {
		return server.Serve(listener)
	}

	// Initialize long-running health checks of all services which implement the
	// checker interface.
	if sb.checkInterval <= 0 {
		sb.checkInterval = 5 * time.Second
	}
	healthCtx, stopHealthChecks := context.WithCancel(context.Background())
	for _, s := range sb.services {
		check, ok := s.impl.(checker)
		if !ok {
			continue
		}
		sb.initLongRunningCheck(healthCtx, s.desc.ServiceName, check.Health)
	}

	// Start a goroutine which listens for a termination signal, and then
	// gracefully stops the gRPC server. This in turn causes the start() function
	// to exit, allowing its caller (generally a main() function) to exit.
	go cmd.CatchSignals(func() {
		stopHealthChecks()
		sb.healthSrv.Shutdown()
		server.GracefulStop()
	})

	return start, nil
}
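
// As an illustrative sketch (not part of this package), a caller typically
// chains the builder methods and then blocks on start(). Here bgrpc is an
// assumed import alias for this package, and c, mypb, and myImpl are
// hypothetical:
//
//	start, err := bgrpc.NewServer(c.GRPC, logger).
//		Add(&mypb.MyService_ServiceDesc, myImpl).
//		WithCheckInterval(10 * time.Second).
//		Build(tlsConfig, statsRegistry, clk)
//	if err != nil {
//		// handle configuration or registration errors
//	}
//	err = start() // blocks until the server is gracefully stopped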

// initLongRunningCheck initializes a goroutine which will periodically check
// the health of the provided service and update the health server accordingly.
func (sb *serverBuilder) initLongRunningCheck(shutdownCtx context.Context, service string, checkImpl func(context.Context) error) {
	// Set the initial health status for the service.
	sb.healthSrv.SetServingStatus(service, healthpb.HealthCheckResponse_NOT_SERVING)

	// checkAndMaybeUpdate is a helper function that checks the health of the
	// service and, if necessary, updates its status in the health server.
	checkAndMaybeUpdate := func(checkCtx context.Context, last healthpb.HealthCheckResponse_ServingStatus) healthpb.HealthCheckResponse_ServingStatus {
		// Make a context with a timeout at 90% of the interval.
		checkImplCtx, cancel := context.WithTimeout(checkCtx, sb.checkInterval*9/10)
		defer cancel()

		var next healthpb.HealthCheckResponse_ServingStatus
		err := checkImpl(checkImplCtx)
		if err != nil {
			next = healthpb.HealthCheckResponse_NOT_SERVING
		} else {
			next = healthpb.HealthCheckResponse_SERVING
		}

		if last == next {
			// No change in health status.
			return next
		}

		if next != healthpb.HealthCheckResponse_SERVING {
			sb.logger.Errf("transitioning health of %q from %q to %q, due to: %s", service, last, next, err)
		} else {
			sb.logger.Infof("transitioning health of %q from %q to %q", service, last, next)
		}
		sb.healthSrv.SetServingStatus(service, next)
		return next
	}

	go func() {
		ticker := time.NewTicker(sb.checkInterval)
		defer ticker.Stop()

		// Assume the service is not healthy to start.
		last := healthpb.HealthCheckResponse_NOT_SERVING

		// Check immediately, and then at the specified interval.
		last = checkAndMaybeUpdate(shutdownCtx, last)
		for {
			select {
			case <-shutdownCtx.Done():
				// The server is shutting down.
				return
			case <-ticker.C:
				last = checkAndMaybeUpdate(shutdownCtx, last)
			}
		}
	}()
}
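
// As an illustrative sketch (not part of this package), a client can observe
// the statuses maintained above through the standard gRPC health service; conn
// is a hypothetical *grpc.ClientConn to this server:
//
//	client := healthpb.NewHealthClient(conn)
//	resp, err := client.Check(ctx, &healthpb.HealthCheckRequest{
//		// Use a registered service name, or "" for the server as a whole.
//		Service: "my.package.MyService",
//	})
//	// On success, resp.Status reflects the most recent SetServingStatus call.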

// serverMetrics is a struct type used to return a few registered metrics from
// `newServerMetrics`
type serverMetrics struct {
	grpcMetrics *grpc_prometheus.ServerMetrics
	rpcLag      prometheus.Histogram
}

// newServerMetrics registers metrics with a registry. It constructs and
// registers a *grpc_prometheus.ServerMetrics with timing histogram enabled as
// well as a prometheus Histogram for RPC latency. If called more than once on a
// single registry, it will gracefully avoid registering duplicate metrics.
func newServerMetrics(stats prometheus.Registerer) (serverMetrics, error) {
	// Create the grpc prometheus server metrics instance and register it
	grpcMetrics := grpc_prometheus.NewServerMetrics(
		grpc_prometheus.WithServerHandlingTimeHistogram(
			grpc_prometheus.WithHistogramBuckets([]float64{.01, .025, .05, .1, .5, 1, 2.5, 5, 10, 45, 90}),
		),
	)
	err := stats.Register(grpcMetrics)
	if err != nil {
		are := prometheus.AlreadyRegisteredError{}
		if errors.As(err, &are) {
			grpcMetrics = are.ExistingCollector.(*grpc_prometheus.ServerMetrics)
		} else {
			return serverMetrics{}, err
		}
	}

	// rpcLag is a prometheus histogram tracking the difference between the time
	// the client sent an RPC and the time the server received it. Create and
	// register it.
	rpcLag := prometheus.NewHistogram(
		prometheus.HistogramOpts{
			Name: "grpc_lag",
			Help: "Delta between client RPC send time and server RPC receipt time",
		})
	err = stats.Register(rpcLag)
	if err != nil {
		are := prometheus.AlreadyRegisteredError{}
		if errors.As(err, &are) {
			rpcLag = are.ExistingCollector.(prometheus.Histogram)
		} else {
			return serverMetrics{}, err
		}
	}

	return serverMetrics{
		grpcMetrics: grpcMetrics,
		rpcLag:      rpcLag,
	}, nil
}
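
// As an illustrative sketch of the re-registration behavior documented on
// newServerMetrics, calling it twice with the same registry yields the
// collectors registered by the first call rather than a duplicate-registration
// error:
//
//	reg := prometheus.NewRegistry()
//	m1, _ := newServerMetrics(reg)
//	m2, _ := newServerMetrics(reg)
//	// m1.grpcMetrics and m2.grpcMetrics refer to the same collector.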