diff --git a/clientconn_test.go b/clientconn_test.go index 9cca5a8eb..68c224140 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -702,7 +702,7 @@ func (s) TestResolverEmptyUpdateNotPanic(t *testing.T) { } func (s) TestClientUpdatesParamsAfterGoAway(t *testing.T) { - grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") + grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") lis, err := net.Listen("tcp", "localhost:0") if err != nil { diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go index 2dce42ddd..ab22429b6 100644 --- a/encoding/encoding_test.go +++ b/encoding/encoding_test.go @@ -112,7 +112,7 @@ func (c *errProtoCodec) Name() string { // Tests the case where encoding fails on the server. Verifies that there is // no panic and that the encoding error is propagated to the client. func (s) TestEncodeDoesntPanicOnServer(t *testing.T) { - grpctest.TLogger.ExpectError("grpc: server failed to encode response") + grpctest.ExpectError("grpc: server failed to encode response") // Create a codec that errors when encoding messages. encodingErr := errors.New("encoding failed") diff --git a/internal/grpctest/grpctest.go b/internal/grpctest/grpctest.go index b8bc38580..be62c2a3c 100644 --- a/internal/grpctest/grpctest.go +++ b/internal/grpctest/grpctest.go @@ -53,7 +53,7 @@ type Tester struct{} // Setup updates the tlogger. func (Tester) Setup(t *testing.T) { - TLogger.Update(t) + Update(t) // TODO: There is one final leak around closing connections without completely // draining the recvBuffer that has yet to be resolved. All other leaks have been // completely addressed, and this can be turned back on as soon as this issue is @@ -75,7 +75,7 @@ func (Tester) Teardown(t *testing.T) { if atomic.LoadUint32(&lcFailed) == 1 { t.Log("Goroutine leak check disabled for future tests") } - TLogger.EndTest(t) + EndTest(t) } // Interface defines Tester's methods for use in this package. diff --git a/internal/grpctest/tlogger.go b/internal/grpctest/tlogger.go index f7f6da152..ae8af3fa9 100644 --- a/internal/grpctest/tlogger.go +++ b/internal/grpctest/tlogger.go @@ -27,15 +27,16 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" "testing" "time" "google.golang.org/grpc/grpclog" ) -// TLogger serves as the grpclog logger and is the interface through which +// tLoggerAtomic serves as the grpclog logger and is the interface through which // expected errors are declared in tests. -var TLogger *tLogger +var tLoggerAtomic atomic.Value const callingFrame = 4 @@ -73,13 +74,19 @@ type tLogger struct { } func init() { - TLogger = &tLogger{errors: map[*regexp.Regexp]int{}} + tLoggerAtomic.Store(&tLogger{errors: map[*regexp.Regexp]int{}}) vLevel := os.Getenv("GRPC_GO_LOG_VERBOSITY_LEVEL") if vl, err := strconv.Atoi(vLevel); err == nil { - TLogger.v = vl + lgr := getLogger() + lgr.v = vl } } +// getLogger returns the current logger instance. +func getLogger() *tLogger { + return tLoggerAtomic.Load().(*tLogger) +} + // getCallingPrefix returns the at the given depth from the stack. func getCallingPrefix(depth int) (string, error) { _, file, line, ok := runtime.Caller(depth) @@ -90,61 +97,62 @@ func getCallingPrefix(depth int) (string, error) { } // log logs the message with the specified parameters to the tLogger. -func (g *tLogger) log(ltype logType, depth int, format string, args ...any) { - g.mu.Lock() - defer g.mu.Unlock() +func (tl *tLogger) log(ltype logType, depth int, format string, args ...any) { + tl.mu.Lock() + defer tl.mu.Unlock() prefix, err := getCallingPrefix(callingFrame + depth) if err != nil { - g.t.Error(err) + tl.t.Error(err) return } args = append([]any{ltype.String() + " " + prefix}, args...) - args = append(args, fmt.Sprintf(" (t=+%s)", time.Since(g.start))) + args = append(args, fmt.Sprintf(" (t=+%s)", time.Since(tl.start))) if format == "" { switch ltype { case errorLog: - // fmt.Sprintln is used rather than fmt.Sprint because t.Log uses fmt.Sprintln behavior. - if g.expected(fmt.Sprintln(args...)) { - g.t.Log(args...) + // fmt.Sprintln is used rather than fmt.Sprint because tl.Log uses fmt.Sprintln behavior. + if tl.expected(fmt.Sprintln(args...)) { + tl.t.Log(args...) } else { - g.t.Error(args...) + tl.t.Error(args...) } case fatalLog: panic(fmt.Sprint(args...)) default: - g.t.Log(args...) + tl.t.Log(args...) } } else { // Add formatting directives for the callingPrefix and timeSuffix. format = "%v " + format + "%s" switch ltype { case errorLog: - if g.expected(fmt.Sprintf(format, args...)) { - g.t.Logf(format, args...) + if tl.expected(fmt.Sprintf(format, args...)) { + tl.t.Logf(format, args...) } else { - g.t.Errorf(format, args...) + tl.t.Errorf(format, args...) } case fatalLog: panic(fmt.Sprintf(format, args...)) default: - g.t.Logf(format, args...) + tl.t.Logf(format, args...) } } } // Update updates the testing.T that the testing logger logs to. Should be done // before every test. It also initializes the tLogger if it has not already. -func (g *tLogger) Update(t *testing.T) { - g.mu.Lock() - defer g.mu.Unlock() - if !g.initialized { - grpclog.SetLoggerV2(TLogger) - g.initialized = true +func Update(t *testing.T) { + tl := getLogger() + tl.mu.Lock() + defer tl.mu.Unlock() + if !tl.initialized { + grpclog.SetLoggerV2(tl) + tl.initialized = true } - g.t = t - g.start = time.Now() - g.errors = map[*regexp.Regexp]int{} + tl.t = t + tl.start = time.Now() + tl.errors = map[*regexp.Regexp]int{} } // ExpectError declares an error to be expected. For the next test, the first @@ -152,41 +160,43 @@ func (g *tLogger) Update(t *testing.T) { // to fail. "For the next test" includes all the time until the next call to // Update(). Note that if an expected error is not encountered, this will cause // the test to fail. -func (g *tLogger) ExpectError(expr string) { - g.ExpectErrorN(expr, 1) +func ExpectError(expr string) { + ExpectErrorN(expr, 1) } // ExpectErrorN declares an error to be expected n times. -func (g *tLogger) ExpectErrorN(expr string, n int) { - g.mu.Lock() - defer g.mu.Unlock() +func ExpectErrorN(expr string, n int) { + tl := getLogger() + tl.mu.Lock() + defer tl.mu.Unlock() re, err := regexp.Compile(expr) if err != nil { - g.t.Error(err) + tl.t.Error(err) return } - g.errors[re] += n + tl.errors[re] += n } // EndTest checks if expected errors were not encountered. -func (g *tLogger) EndTest(t *testing.T) { - g.mu.Lock() - defer g.mu.Unlock() - for re, count := range g.errors { +func EndTest(t *testing.T) { + tl := getLogger() + tl.mu.Lock() + defer tl.mu.Unlock() + for re, count := range tl.errors { if count > 0 { t.Errorf("Expected error '%v' not encountered", re.String()) } } - g.errors = map[*regexp.Regexp]int{} + tl.errors = map[*regexp.Regexp]int{} } // expected determines if the error string is protected or not. -func (g *tLogger) expected(s string) bool { - for re, count := range g.errors { +func (tl *tLogger) expected(s string) bool { + for re, count := range tl.errors { if re.FindStringIndex(s) != nil { - g.errors[re]-- + tl.errors[re]-- if count <= 1 { - delete(g.errors, re) + delete(tl.errors, re) } return true } @@ -194,70 +204,70 @@ func (g *tLogger) expected(s string) bool { return false } -func (g *tLogger) Info(args ...any) { - g.log(infoLog, 0, "", args...) +func (tl *tLogger) Info(args ...any) { + tl.log(infoLog, 0, "", args...) } -func (g *tLogger) Infoln(args ...any) { - g.log(infoLog, 0, "", args...) +func (tl *tLogger) Infoln(args ...any) { + tl.log(infoLog, 0, "", args...) } -func (g *tLogger) Infof(format string, args ...any) { - g.log(infoLog, 0, format, args...) +func (tl *tLogger) Infof(format string, args ...any) { + tl.log(infoLog, 0, format, args...) } -func (g *tLogger) InfoDepth(depth int, args ...any) { - g.log(infoLog, depth, "", args...) +func (tl *tLogger) InfoDepth(depth int, args ...any) { + tl.log(infoLog, depth, "", args...) } -func (g *tLogger) Warning(args ...any) { - g.log(warningLog, 0, "", args...) +func (tl *tLogger) Warning(args ...any) { + tl.log(warningLog, 0, "", args...) } -func (g *tLogger) Warningln(args ...any) { - g.log(warningLog, 0, "", args...) +func (tl *tLogger) Warningln(args ...any) { + tl.log(warningLog, 0, "", args...) } -func (g *tLogger) Warningf(format string, args ...any) { - g.log(warningLog, 0, format, args...) +func (tl *tLogger) Warningf(format string, args ...any) { + tl.log(warningLog, 0, format, args...) } -func (g *tLogger) WarningDepth(depth int, args ...any) { - g.log(warningLog, depth, "", args...) +func (tl *tLogger) WarningDepth(depth int, args ...any) { + tl.log(warningLog, depth, "", args...) } -func (g *tLogger) Error(args ...any) { - g.log(errorLog, 0, "", args...) +func (tl *tLogger) Error(args ...any) { + tl.log(errorLog, 0, "", args...) } -func (g *tLogger) Errorln(args ...any) { - g.log(errorLog, 0, "", args...) +func (tl *tLogger) Errorln(args ...any) { + tl.log(errorLog, 0, "", args...) } -func (g *tLogger) Errorf(format string, args ...any) { - g.log(errorLog, 0, format, args...) +func (tl *tLogger) Errorf(format string, args ...any) { + tl.log(errorLog, 0, format, args...) } -func (g *tLogger) ErrorDepth(depth int, args ...any) { - g.log(errorLog, depth, "", args...) +func (tl *tLogger) ErrorDepth(depth int, args ...any) { + tl.log(errorLog, depth, "", args...) } -func (g *tLogger) Fatal(args ...any) { - g.log(fatalLog, 0, "", args...) +func (tl *tLogger) Fatal(args ...any) { + tl.log(fatalLog, 0, "", args...) } -func (g *tLogger) Fatalln(args ...any) { - g.log(fatalLog, 0, "", args...) +func (tl *tLogger) Fatalln(args ...any) { + tl.log(fatalLog, 0, "", args...) } -func (g *tLogger) Fatalf(format string, args ...any) { - g.log(fatalLog, 0, format, args...) +func (tl *tLogger) Fatalf(format string, args ...any) { + tl.log(fatalLog, 0, format, args...) } -func (g *tLogger) FatalDepth(depth int, args ...any) { - g.log(fatalLog, depth, "", args...) +func (tl *tLogger) FatalDepth(depth int, args ...any) { + tl.log(fatalLog, depth, "", args...) } -func (g *tLogger) V(l int) bool { - return l <= g.v +func (tl *tLogger) V(l int) bool { + return l <= tl.v } diff --git a/internal/grpctest/tlogger_test.go b/internal/grpctest/tlogger_test.go index 364f1432e..475ca7844 100644 --- a/internal/grpctest/tlogger_test.go +++ b/internal/grpctest/tlogger_test.go @@ -19,6 +19,7 @@ package grpctest import ( + "regexp" "testing" "google.golang.org/grpc/grpclog" @@ -66,10 +67,10 @@ func (s) TestWarningDepth(*testing.T) { func (s) TestError(*testing.T) { const numErrors = 10 - TLogger.ExpectError("Expected error") - TLogger.ExpectError("Expected ln error") - TLogger.ExpectError("Expected formatted error") - TLogger.ExpectErrorN("Expected repeated error", numErrors) + ExpectError("Expected error") + ExpectError("Expected ln error") + ExpectError("Expected formatted error") + ExpectErrorN("Expected repeated error", numErrors) grpclog.Error("Expected", "error") grpclog.Errorln("Expected", "ln", "error") grpclog.Errorf("%v %v %v", "Expected", "formatted", "error") @@ -77,3 +78,54 @@ func (s) TestError(*testing.T) { grpclog.Error("Expected repeated error") } } + +func (s) TestInit(t *testing.T) { + // Reset the atomic value + tLoggerAtomic.Store(&tLogger{errors: map[*regexp.Regexp]int{}}) + + // Test initial state + tl := getLogger() + if tl == nil { + t.Fatal("getLogger() returned nil") + } + if tl.errors == nil { + t.Error("tl.errors is nil") + } + if len(tl.errors) != 0 { + t.Errorf("tl.errors = %v; want empty map", tl.errors) + } + if tl.initialized { + t.Error("tl.initialized = true; want false") + } + if tl.t != nil { + t.Error("tl.t is not nil") + } + if !tl.start.IsZero() { + t.Error("tl.start is not zero") + } +} + +func (s) TestGetLogger(t *testing.T) { + // Save original logger + origLogger := getLogger() + defer tLoggerAtomic.Store(origLogger) + + // Create new logger + newLogger := &tLogger{errors: map[*regexp.Regexp]int{}} + tLoggerAtomic.Store(newLogger) + + // Verify new logger is retrieved + retrievedLogger := getLogger() + if retrievedLogger != newLogger { + t.Error("getLogger() did not return the newly stored logger") + } + + // Restore original logger + tLoggerAtomic.Store(origLogger) + + // Verify original logger is retrieved + retrievedLogger = getLogger() + if retrievedLogger != origLogger { + t.Error("getLogger() did not return the original logger after restore") + } +} diff --git a/internal/transport/keepalive_test.go b/internal/transport/keepalive_test.go index 037b0b1c1..0bd7ba356 100644 --- a/internal/transport/keepalive_test.go +++ b/internal/transport/keepalive_test.go @@ -398,7 +398,7 @@ func (s) TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { // explicitly makes sure the fix works and the client sends a ping every [Time] // period. func (s) TestKeepaliveClientFrequency(t *testing.T) { - grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") + grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") serverConfig := &ServerConfig{ KeepalivePolicy: keepalive.EnforcementPolicy{ @@ -430,7 +430,7 @@ func (s) TestKeepaliveClientFrequency(t *testing.T) { // (when there are no active streams), based on the configured // EnforcementPolicy. func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { - grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") + grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") serverConfig := &ServerConfig{ KeepalivePolicy: keepalive.EnforcementPolicy{ @@ -461,7 +461,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { // (even when there is an active stream), based on the configured // EnforcementPolicy. func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { - grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") + grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") serverConfig := &ServerConfig{ KeepalivePolicy: keepalive.EnforcementPolicy{ diff --git a/orca/producer_test.go b/orca/producer_test.go index 130b21a72..41758b867 100644 --- a/orca/producer_test.go +++ b/orca/producer_test.go @@ -266,7 +266,7 @@ func (f *fakeORCAService) StreamCoreMetrics(req *v3orcaservicepb.OrcaLoadReportR // TestProducerBackoff verifies that the ORCA producer applies the proper // backoff after stream failures. func (s) TestProducerBackoff(t *testing.T) { - grpctest.TLogger.ExpectErrorN("injected error", 4) + grpctest.ExpectErrorN("injected error", 4) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() diff --git a/test/end2end_test.go b/test/end2end_test.go index ab2517d5f..7fa45ebcd 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5414,7 +5414,7 @@ func (s) TestStatusInvalidUTF8Message(t *testing.T) { // will fail to marshal the status because of the invalid utf8 message. Details // will be dropped when sending. func (s) TestStatusInvalidUTF8Details(t *testing.T) { - grpctest.TLogger.ExpectError("Failed to marshal rpc status") + grpctest.ExpectError("Failed to marshal rpc status") var ( origMsg = string([]byte{0xff, 0xfe, 0xfd}) @@ -6425,7 +6425,7 @@ func (s) TestServerClosesConn(t *testing.T) { // TestNilStatsHandler ensures we do not panic as a result of a nil stats // handler. func (s) TestNilStatsHandler(t *testing.T) { - grpctest.TLogger.ExpectErrorN("ignoring nil parameter", 2) + grpctest.ExpectErrorN("ignoring nil parameter", 2) ss := &stubserver.StubServer{ UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { return &testpb.SimpleResponse{}, nil diff --git a/test/goaway_test.go b/test/goaway_test.go index 076412da7..6f90eccef 100644 --- a/test/goaway_test.go +++ b/test/goaway_test.go @@ -121,7 +121,7 @@ func (s) TestDetailedGoAwayErrorOnGracefulClosePropagatesToRPCError(t *testing.T } func (s) TestDetailedGoAwayErrorOnAbruptClosePropagatesToRPCError(t *testing.T) { - grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") + grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") // set the min keepalive time very low so that this test can take // a reasonable amount of time prev := internal.KeepaliveMinPingTime diff --git a/test/healthcheck_test.go b/test/healthcheck_test.go index 0f7ec54a8..29fa3228b 100644 --- a/test/healthcheck_test.go +++ b/test/healthcheck_test.go @@ -254,7 +254,7 @@ func (s) TestHealthCheckWatchStateChange(t *testing.T) { // If Watch returns Unimplemented, then the ClientConn should go into READY state. func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) { - grpctest.TLogger.ExpectError("Subchannel health check is unimplemented at server side, thus health check is disabled") + grpctest.ExpectError("Subchannel health check is unimplemented at server side, thus health check is disabled") s := grpc.NewServer() lis, err := net.Listen("tcp", "localhost:0") if err != nil { diff --git a/test/metadata_test.go b/test/metadata_test.go index 57139a8d9..6c469d2f7 100644 --- a/test/metadata_test.go +++ b/test/metadata_test.go @@ -37,7 +37,7 @@ import ( ) func (s) TestInvalidMetadata(t *testing.T) { - grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 5) + grpctest.ExpectErrorN("stream: failed to validate md when setting trailer", 5) tests := []struct { name string