package grpc

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"log"
	"net"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/jmhodges/clock"
	"github.com/prometheus/client_golang/prometheus"
	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer/roundrobin"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/durationpb"

	"github.com/letsencrypt/boulder/grpc/test_proto"
	"github.com/letsencrypt/boulder/metrics"
	"github.com/letsencrypt/boulder/test"
	"github.com/letsencrypt/boulder/web"
)

var fc = clock.NewFake()

func testHandler(_ context.Context, i interface{}) (interface{}, error) {
	if i != nil {
		return nil, errors.New("")
	}
	fc.Sleep(time.Second)
	return nil, nil
}

func testInvoker(_ context.Context, method string, _, _ interface{}, _ *grpc.ClientConn, opts ...grpc.CallOption) error {
	switch method {
	case "-service-brokeTest":
		return errors.New("")
	case "-service-requesterCanceledTest":
		return status.Error(1, context.Canceled.Error())
	}
	fc.Sleep(time.Second)
	return nil
}

func TestServerInterceptor(t *testing.T) {
	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating server metrics")
	si := newServerMetadataInterceptor(serverMetrics, clock.NewFake())

	md := metadata.New(map[string]string{clientRequestTimeKey: "0"})
	ctxWithMetadata := metadata.NewIncomingContext(context.Background(), md)

	_, err = si.Unary(context.Background(), nil, nil, testHandler)
	test.AssertError(t, err, "si.intercept didn't fail with a context missing metadata")

	_, err = si.Unary(ctxWithMetadata, nil, nil, testHandler)
	test.AssertError(t, err, "si.intercept didn't fail with a nil grpc.UnaryServerInfo")

	_, err = si.Unary(ctxWithMetadata, nil, &grpc.UnaryServerInfo{FullMethod: "-service-test"}, testHandler)
	test.AssertNotError(t, err, "si.intercept failed with a non-nil grpc.UnaryServerInfo")

	_, err = si.Unary(ctxWithMetadata, 0, &grpc.UnaryServerInfo{FullMethod: "brokeTest"}, testHandler)
	test.AssertError(t, err, "si.intercept didn't fail when the handler returned an error")
}

func TestClientInterceptor(t *testing.T) {
	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating client metrics")
	ci := clientMetadataInterceptor{
		timeout: time.Second,
		metrics: clientMetrics,
		clk:     clock.NewFake(),
	}

	err = ci.Unary(context.Background(), "-service-test", nil, nil, nil, testInvoker)
	test.AssertNotError(t, err, "ci.intercept failed when the invoker succeeded")

	err = ci.Unary(context.Background(), "-service-brokeTest", nil, nil, nil, testInvoker)
	test.AssertError(t, err, "ci.intercept didn't fail when the invoker returned an error")
}
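// TestMetadataRoundTrip is a minimal added sketch (not part of the original
// suite) of the metadata plumbing the interceptors above rely on: a value
// attached under clientRequestTimeKey with metadata.NewIncomingContext must be
// recoverable via metadata.FromIncomingContext, which is presumably how the
// server interceptor reads the client's request time.
func TestMetadataRoundTrip(t *testing.T) {
	md := metadata.New(map[string]string{clientRequestTimeKey: "0"})
	ctx := metadata.NewIncomingContext(context.Background(), md)

	got, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		t.Fatal("expected metadata in incoming context")
	}
	vals := got.Get(clientRequestTimeKey)
	if len(vals) != 1 || vals[0] != "0" {
		t.Errorf("got %v for %q, want [\"0\"]", vals, clientRequestTimeKey)
	}
}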
// TestWaitForReadyTrue configures a gRPC client with waitForReady: true and
// sends a request to a backend that is unavailable. It ensures that the
// request doesn't error out until the timeout is reached, i.e. that FailFast
// is set to false.
// https://github.com/grpc/grpc/blob/main/doc/wait-for-ready.md
func TestWaitForReadyTrue(t *testing.T) {
	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating client metrics")
	ci := &clientMetadataInterceptor{
		timeout:      100 * time.Millisecond,
		metrics:      clientMetrics,
		clk:          clock.NewFake(),
		waitForReady: true,
	}
	conn, err := grpc.NewClient("localhost:19876", // random, probably unused port
		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithUnaryInterceptor(ci.Unary))
	if err != nil {
		t.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	c := test_proto.NewChillerClient(conn)

	start := time.Now()
	_, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)})
	if err == nil {
		t.Errorf("Successful Chill when we expected failure.")
	}
	if time.Since(start) < 90*time.Millisecond {
		t.Errorf("Chill failed fast, when WaitForReady should be enabled.")
	}
}

// TestWaitForReadyFalse configures a gRPC client with waitForReady: false and
// sends a request to a backend that is unavailable, and ensures that the
// request errors out promptly.
func TestWaitForReadyFalse(t *testing.T) {
	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating client metrics")
	ci := &clientMetadataInterceptor{
		timeout:      time.Second,
		metrics:      clientMetrics,
		clk:          clock.NewFake(),
		waitForReady: false,
	}
	conn, err := grpc.NewClient("localhost:19876", // random, probably unused port
		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, roundrobin.Name)),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithUnaryInterceptor(ci.Unary))
	if err != nil {
		t.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	c := test_proto.NewChillerClient(conn)

	start := time.Now()
	_, err = c.Chill(context.Background(), &test_proto.Time{Duration: durationpb.New(time.Second)})
	if err == nil {
		t.Errorf("Successful Chill when we expected failure.")
	}
	if time.Since(start) > 200*time.Millisecond {
		t.Errorf("Chill failed slow, when WaitForReady should be disabled.")
	}
}
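// chillWithWaitForReady is a hedged sketch (a hypothetical helper, not used by
// the tests above) showing the per-call spelling of the behavior the two
// WaitForReady tests exercise through the interceptor's waitForReady field:
// grpc.WaitForReady(true) makes an RPC block until a backend is available (or
// the deadline expires), while grpc.WaitForReady(false) lets it fail fast.
func chillWithWaitForReady(ctx context.Context, c test_proto.ChillerClient, wait bool) (*test_proto.Time, error) {
	return c.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Millisecond)}, grpc.WaitForReady(wait))
}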
// testTimeoutServer is used to implement TestTimeouts, and will attempt to
// sleep for the given amount of time (unless it hits a timeout or cancel).
type testTimeoutServer struct {
	test_proto.UnimplementedChillerServer
}

// Chill implements ChillerServer.Chill
func (s *testTimeoutServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) {
	start := time.Now()
	// Sleep for the requested amount of time, or until the context times out
	// or is canceled, whichever comes first.
	select {
	case <-time.After(in.Duration.AsDuration()):
		spent := time.Since(start)
		return &test_proto.Time{Duration: durationpb.New(spent)}, nil
	case <-ctx.Done():
		return nil, errors.New("unique error indicating that the server's shortened context timed itself out")
	}
}

func TestTimeouts(t *testing.T) {
	server := new(testTimeoutServer)
	client, _, stop := setup(t, server, clock.NewFake())
	defer stop()

	testCases := []struct {
		timeout             time.Duration
		expectedErrorPrefix string
	}{
		{250 * time.Millisecond, "rpc error: code = Unknown desc = unique error indicating that the server's shortened context timed itself out"},
		{100 * time.Millisecond, "Chiller.Chill timed out after 0 ms"},
		{10 * time.Millisecond, "Chiller.Chill timed out after 0 ms"},
	}
	for _, tc := range testCases {
		t.Run(tc.timeout.String(), func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
			defer cancel()
			_, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)})
			if err == nil {
				t.Fatal("Got no error, expected a timeout")
			}
			if !strings.HasPrefix(err.Error(), tc.expectedErrorPrefix) {
				t.Errorf("Wrong error. Got %s, expected %s", err.Error(), tc.expectedErrorPrefix)
			}
		})
	}
}

func TestRequestTimeTagging(t *testing.T) {
	server := new(testTimeoutServer)
	serverMetrics, err := newServerMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating server metrics")
	client, _, stop := setup(t, server, serverMetrics)
	defer stop()

	// Make an RPC request with the ChillerClient with a timeout higher than the
	// requested ChillerServer delay so that the RPC completes normally.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if _, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil {
		t.Fatalf("Unexpected error calling Chill RPC: %s", err)
	}

	// There should be one histogram sample in the serverInterceptor rpcLag stat.
	test.AssertMetricWithLabelsEquals(t, serverMetrics.rpcLag, prometheus.Labels{}, 1)
}
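// remainingDeadline is a hedged sketch (a hypothetical helper, not used by the
// tests above) of how a handler can inspect the deadline that gRPC propagates
// from the client's context; testTimeoutServer's select on ctx.Done() fires
// once this remaining time is exhausted.
func remainingDeadline(ctx context.Context, clk clock.Clock) (time.Duration, bool) {
	deadline, ok := ctx.Deadline()
	if !ok {
		return 0, false
	}
	return deadline.Sub(clk.Now()), true
}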
func TestClockSkew(t *testing.T) {
	// Create two separate clocks for the client and server.
	serverClk := clock.NewFake()
	serverClk.Set(time.Now())
	clientClk := clock.NewFake()
	clientClk.Set(time.Now())

	_, serverPort, stop := setup(t, &testTimeoutServer{}, serverClk)
	defer stop()

	clientMetrics, err := newClientMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating client metrics")
	ci := &clientMetadataInterceptor{
		timeout: 30 * time.Second,
		metrics: clientMetrics,
		clk:     clientClk,
	}
	conn, err := grpc.NewClient(
		net.JoinHostPort("localhost", strconv.Itoa(serverPort)),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithUnaryInterceptor(ci.Unary))
	if err != nil {
		t.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()
	client := test_proto.NewChillerClient(conn)

	// Create a context with plenty of timeout.
	ctx, cancel := context.WithDeadline(context.Background(), clientClk.Now().Add(10*time.Second))
	defer cancel()

	// Attempt a gRPC request which should succeed.
	_, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)})
	test.AssertNotError(t, err, "should succeed with no skew")

	// Skew the client clock forward; the request should fail due to skew.
	clientClk.Add(time.Hour)
	_, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)})
	test.AssertError(t, err, "should fail with positive client skew")
	test.AssertContains(t, err.Error(), "very different time")

	// Skew the server clock forward; the request should fail due to skew.
	serverClk.Add(2 * time.Hour)
	_, err = client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(100 * time.Millisecond)})
	test.AssertError(t, err, "should fail with negative client skew")
	test.AssertContains(t, err.Error(), "very different time")
}

// blockedServer implements a ChillerServer with a Chill method that:
//  1. Calls Done() on the received waitgroup when receiving an RPC
//  2. Blocks the RPC on the roadblock waitgroup
//
// This is used by TestInFlightRPCStat to test that the gauge for in-flight
// RPCs is incremented and decremented as expected.
type blockedServer struct {
	test_proto.UnimplementedChillerServer
	roadblock, received sync.WaitGroup
}

// Chill implements ChillerServer.Chill
func (s *blockedServer) Chill(_ context.Context, _ *test_proto.Time) (*test_proto.Time, error) {
	// Note that a client RPC arrived.
	s.received.Done()
	// Wait for the roadblock to be cleared.
	s.roadblock.Wait()
	// Return a dummy spent value to adhere to the chiller protocol.
	return &test_proto.Time{Duration: durationpb.New(time.Millisecond)}, nil
}

func TestInFlightRPCStat(t *testing.T) {
	// Create a new blockedServer to act as a ChillerServer.
	server := &blockedServer{}

	metrics, err := newClientMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating client metrics")

	client, _, stop := setup(t, server, metrics)
	defer stop()

	// Increment the roadblock waitgroup - this will cause all chill RPCs to
	// the server to block until we call Done()!
	server.roadblock.Add(1)

	// Increment the received waitgroup - we use this to find out when all the
	// RPCs we want to send have been received, so we can check the in-flight
	// gauge.
	numRPCs := 5
	server.received.Add(numRPCs)

	// Fire off a few RPCs. They will block on the blockedServer's roadblock wg.
	for range numRPCs {
		go func() {
			// Ignore errors, just chilllll.
			_, _ = client.Chill(context.Background(), &test_proto.Time{})
		}()
	}

	// Wait until all of the client RPCs have been sent and are blocking. We can
	// now check the gauge.
	server.received.Wait()

	// Specify the labels for the RPCs we're interested in.
	labels := prometheus.Labels{
		"service": "Chiller",
		"method":  "Chill",
	}

	// We expect the inFlightRPCs gauge for the Chiller.Chill RPCs to be equal
	// to numRPCs.
	test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, float64(numRPCs))

	// Unblock the blockedServer to let all of the Chiller.Chill RPCs complete.
	server.roadblock.Done()
	// Sleep for a little bit to let all the RPCs complete.
	time.Sleep(1 * time.Second)

	// Check the gauge value again; with all RPCs finished it should be zero.
	test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, 0)
}
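// A hedged aside on TestInFlightRPCStat: the fixed one-second sleep leaves a
// small race window before the final gauge check. One sketch of a
// deterministic alternative (assuming the client interceptor decrements the
// gauge before Unary returns) is to track completion with a second WaitGroup:
//
//	var done sync.WaitGroup
//	done.Add(numRPCs)
//	for range numRPCs {
//		go func() {
//			defer done.Done()
//			_, _ = client.Chill(context.Background(), &test_proto.Time{})
//		}()
//	}
//	// ...check the gauge, then unblock the server...
//	server.roadblock.Done()
//	done.Wait()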
func TestServiceAuthChecker(t *testing.T) {
	ac := authInterceptor{
		map[string]map[string]struct{}{
			"package.ServiceName": {
				"allowed.client": {},
				"also.allowed":   {},
			},
		},
	}

	// No allowlist is a bad configuration.
	ctx := context.Background()
	err := ac.checkContextAuth(ctx, "/package.OtherService/Method/")
	test.AssertError(t, err, "checking empty allowlist")

	// Context with no peering information is disallowed.
	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
	test.AssertError(t, err, "checking un-peered context")

	// Context with no auth info is disallowed.
	ctx = peer.NewContext(ctx, &peer.Peer{})
	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
	test.AssertError(t, err, "checking peer with no auth")

	// Context with no verified chains is disallowed.
	ctx = peer.NewContext(ctx, &peer.Peer{
		AuthInfo: credentials.TLSInfo{
			State: tls.ConnectionState{},
		},
	})
	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
	test.AssertError(t, err, "checking TLS with no valid chains")

	// Context with a cert with the wrong name is disallowed.
	ctx = peer.NewContext(ctx, &peer.Peer{
		AuthInfo: credentials.TLSInfo{
			State: tls.ConnectionState{
				VerifiedChains: [][]*x509.Certificate{
					{
						&x509.Certificate{
							DNSNames: []string{
								"disallowed.client",
							},
						},
					},
				},
			},
		},
	})
	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
	test.AssertError(t, err, "checking disallowed cert")

	// Context with a cert with a good name is allowed.
	ctx = peer.NewContext(ctx, &peer.Peer{
		AuthInfo: credentials.TLSInfo{
			State: tls.ConnectionState{
				VerifiedChains: [][]*x509.Certificate{
					{
						&x509.Certificate{
							DNSNames: []string{
								"disallowed.client",
								"also.allowed",
							},
						},
					},
				},
			},
		},
	})
	err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/")
	test.AssertNotError(t, err, "checking allowed cert")
}

// testUserAgentServer stores the last value it saw in the user agent field of
// its context.
type testUserAgentServer struct {
	test_proto.UnimplementedChillerServer
	lastSeenUA string
}

// Chill implements ChillerServer.Chill
func (s *testUserAgentServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) {
	s.lastSeenUA = web.UserAgent(ctx)
	return nil, nil
}

func TestUserAgentMetadata(t *testing.T) {
	server := new(testUserAgentServer)
	client, _, stop := setup(t, server)
	defer stop()

	testUA := "test UA"
	ctx := web.WithUserAgent(context.Background(), testUA)
	_, err := client.Chill(ctx, &test_proto.Time{})
	if err != nil {
		t.Fatalf("calling c.Chill: %s", err)
	}

	if server.lastSeenUA != testUA {
		t.Errorf("last seen User-Agent on server side was %q, want %q", server.lastSeenUA, testUA)
	}
}

// setup creates a server and client, returning the created client, the
// running server's port, and a stop function.
func setup(t *testing.T, server test_proto.ChillerServer, opts ...any) (test_proto.ChillerClient, int, func()) {
	clk := clock.NewFake()
	serverMetricsVal, err := newServerMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating server metrics")
	clientMetricsVal, err := newClientMetrics(metrics.NoopRegisterer)
	test.AssertNotError(t, err, "creating client metrics")

	for _, opt := range opts {
		switch optTyped := opt.(type) {
		case clock.FakeClock:
			clk = optTyped
		case clientMetrics:
			clientMetricsVal = optTyped
		case serverMetrics:
			serverMetricsVal = optTyped
		default:
			t.Fatalf("setup called with unrecognized option %#v", opt)
		}
	}

	lis, err := net.Listen("tcp", ":0")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	port := lis.Addr().(*net.TCPAddr).Port

	si := newServerMetadataInterceptor(serverMetricsVal, clk)
	s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary))
	test_proto.RegisterChillerServer(s, server)
	go func() {
		start := time.Now()
		err := s.Serve(lis)
		if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") {
			t.Logf("s.Serve: %v after %s", err, time.Since(start))
		}
	}()

	ci := &clientMetadataInterceptor{
		timeout: 30 * time.Second,
		metrics: clientMetricsVal,
		clk:     clock.NewFake(),
	}
	conn, err := grpc.NewClient(
		net.JoinHostPort("localhost", strconv.Itoa(port)),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithUnaryInterceptor(ci.Unary))
	if err != nil {
		t.Fatalf("did not connect: %v", err)
	}

	return test_proto.NewChillerClient(conn), port, s.Stop
}
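// Example usage of setup (illustrative only): callers may override the fake
// clock or either metrics bundle by passing them as trailing options, in any
// order.
//
//	clk := clock.NewFake()
//	client, port, stop := setup(t, &testTimeoutServer{}, clk)
//	defer stop()
//	_ = port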