This commit is contained in:
Ashesh Vidyut 2025-06-11 15:32:33 -07:00 committed by GitHub
commit b588fade31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 156 additions and 94 deletions

View File

@ -702,7 +702,7 @@ func (s) TestResolverEmptyUpdateNotPanic(t *testing.T) {
} }
func (s) TestClientUpdatesParamsAfterGoAway(t *testing.T) { func (s) TestClientUpdatesParamsAfterGoAway(t *testing.T) {
grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"")
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {

View File

@ -112,7 +112,7 @@ func (c *errProtoCodec) Name() string {
// Tests the case where encoding fails on the server. Verifies that there is // Tests the case where encoding fails on the server. Verifies that there is
// no panic and that the encoding error is propagated to the client. // no panic and that the encoding error is propagated to the client.
func (s) TestEncodeDoesntPanicOnServer(t *testing.T) { func (s) TestEncodeDoesntPanicOnServer(t *testing.T) {
grpctest.TLogger.ExpectError("grpc: server failed to encode response") grpctest.ExpectError("grpc: server failed to encode response")
// Create a codec that errors when encoding messages. // Create a codec that errors when encoding messages.
encodingErr := errors.New("encoding failed") encodingErr := errors.New("encoding failed")

View File

@ -53,7 +53,7 @@ type Tester struct{}
// Setup updates the tlogger. // Setup updates the tlogger.
func (Tester) Setup(t *testing.T) { func (Tester) Setup(t *testing.T) {
TLogger.Update(t) Update(t)
// TODO: There is one final leak around closing connections without completely // TODO: There is one final leak around closing connections without completely
// draining the recvBuffer that has yet to be resolved. All other leaks have been // draining the recvBuffer that has yet to be resolved. All other leaks have been
// completely addressed, and this can be turned back on as soon as this issue is // completely addressed, and this can be turned back on as soon as this issue is
@ -75,7 +75,7 @@ func (Tester) Teardown(t *testing.T) {
if atomic.LoadUint32(&lcFailed) == 1 { if atomic.LoadUint32(&lcFailed) == 1 {
t.Log("Goroutine leak check disabled for future tests") t.Log("Goroutine leak check disabled for future tests")
} }
TLogger.EndTest(t) EndTest(t)
} }
// Interface defines Tester's methods for use in this package. // Interface defines Tester's methods for use in this package.

View File

@ -27,15 +27,16 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
) )
// TLogger serves as the grpclog logger and is the interface through which // tLoggerAtomic serves as the grpclog logger and is the interface through which
// expected errors are declared in tests. // expected errors are declared in tests.
var TLogger *tLogger var tLoggerAtomic atomic.Value
const callingFrame = 4 const callingFrame = 4
@ -73,13 +74,19 @@ type tLogger struct {
} }
func init() { func init() {
TLogger = &tLogger{errors: map[*regexp.Regexp]int{}} tLoggerAtomic.Store(&tLogger{errors: map[*regexp.Regexp]int{}})
vLevel := os.Getenv("GRPC_GO_LOG_VERBOSITY_LEVEL") vLevel := os.Getenv("GRPC_GO_LOG_VERBOSITY_LEVEL")
if vl, err := strconv.Atoi(vLevel); err == nil { if vl, err := strconv.Atoi(vLevel); err == nil {
TLogger.v = vl lgr := getLogger()
lgr.v = vl
} }
} }
// getLogger returns the current logger instance.
func getLogger() *tLogger {
return tLoggerAtomic.Load().(*tLogger)
}
// getCallingPrefix returns the <file:line> at the given depth from the stack. // getCallingPrefix returns the <file:line> at the given depth from the stack.
func getCallingPrefix(depth int) (string, error) { func getCallingPrefix(depth int) (string, error) {
_, file, line, ok := runtime.Caller(depth) _, file, line, ok := runtime.Caller(depth)
@ -90,61 +97,62 @@ func getCallingPrefix(depth int) (string, error) {
} }
// log logs the message with the specified parameters to the tLogger. // log logs the message with the specified parameters to the tLogger.
func (g *tLogger) log(ltype logType, depth int, format string, args ...any) { func (tl *tLogger) log(ltype logType, depth int, format string, args ...any) {
g.mu.Lock() tl.mu.Lock()
defer g.mu.Unlock() defer tl.mu.Unlock()
prefix, err := getCallingPrefix(callingFrame + depth) prefix, err := getCallingPrefix(callingFrame + depth)
if err != nil { if err != nil {
g.t.Error(err) tl.t.Error(err)
return return
} }
args = append([]any{ltype.String() + " " + prefix}, args...) args = append([]any{ltype.String() + " " + prefix}, args...)
args = append(args, fmt.Sprintf(" (t=+%s)", time.Since(g.start))) args = append(args, fmt.Sprintf(" (t=+%s)", time.Since(tl.start)))
if format == "" { if format == "" {
switch ltype { switch ltype {
case errorLog: case errorLog:
// fmt.Sprintln is used rather than fmt.Sprint because t.Log uses fmt.Sprintln behavior. // fmt.Sprintln is used rather than fmt.Sprint because tl.Log uses fmt.Sprintln behavior.
if g.expected(fmt.Sprintln(args...)) { if tl.expected(fmt.Sprintln(args...)) {
g.t.Log(args...) tl.t.Log(args...)
} else { } else {
g.t.Error(args...) tl.t.Error(args...)
} }
case fatalLog: case fatalLog:
panic(fmt.Sprint(args...)) panic(fmt.Sprint(args...))
default: default:
g.t.Log(args...) tl.t.Log(args...)
} }
} else { } else {
// Add formatting directives for the callingPrefix and timeSuffix. // Add formatting directives for the callingPrefix and timeSuffix.
format = "%v " + format + "%s" format = "%v " + format + "%s"
switch ltype { switch ltype {
case errorLog: case errorLog:
if g.expected(fmt.Sprintf(format, args...)) { if tl.expected(fmt.Sprintf(format, args...)) {
g.t.Logf(format, args...) tl.t.Logf(format, args...)
} else { } else {
g.t.Errorf(format, args...) tl.t.Errorf(format, args...)
} }
case fatalLog: case fatalLog:
panic(fmt.Sprintf(format, args...)) panic(fmt.Sprintf(format, args...))
default: default:
g.t.Logf(format, args...) tl.t.Logf(format, args...)
} }
} }
} }
// Update updates the testing.T that the testing logger logs to. Should be done // Update updates the testing.T that the testing logger logs to. Should be done
// before every test. It also initializes the tLogger if it has not already. // before every test. It also initializes the tLogger if it has not already.
func (g *tLogger) Update(t *testing.T) { func Update(t *testing.T) {
g.mu.Lock() tl := getLogger()
defer g.mu.Unlock() tl.mu.Lock()
if !g.initialized { defer tl.mu.Unlock()
grpclog.SetLoggerV2(TLogger) if !tl.initialized {
g.initialized = true grpclog.SetLoggerV2(tl)
tl.initialized = true
} }
g.t = t tl.t = t
g.start = time.Now() tl.start = time.Now()
g.errors = map[*regexp.Regexp]int{} tl.errors = map[*regexp.Regexp]int{}
} }
// ExpectError declares an error to be expected. For the next test, the first // ExpectError declares an error to be expected. For the next test, the first
@ -152,41 +160,43 @@ func (g *tLogger) Update(t *testing.T) {
// to fail. "For the next test" includes all the time until the next call to // to fail. "For the next test" includes all the time until the next call to
// Update(). Note that if an expected error is not encountered, this will cause // Update(). Note that if an expected error is not encountered, this will cause
// the test to fail. // the test to fail.
func (g *tLogger) ExpectError(expr string) { func ExpectError(expr string) {
g.ExpectErrorN(expr, 1) ExpectErrorN(expr, 1)
} }
// ExpectErrorN declares an error to be expected n times. // ExpectErrorN declares an error to be expected n times.
func (g *tLogger) ExpectErrorN(expr string, n int) { func ExpectErrorN(expr string, n int) {
g.mu.Lock() tl := getLogger()
defer g.mu.Unlock() tl.mu.Lock()
defer tl.mu.Unlock()
re, err := regexp.Compile(expr) re, err := regexp.Compile(expr)
if err != nil { if err != nil {
g.t.Error(err) tl.t.Error(err)
return return
} }
g.errors[re] += n tl.errors[re] += n
} }
// EndTest checks if expected errors were not encountered. // EndTest checks if expected errors were not encountered.
func (g *tLogger) EndTest(t *testing.T) { func EndTest(t *testing.T) {
g.mu.Lock() tl := getLogger()
defer g.mu.Unlock() tl.mu.Lock()
for re, count := range g.errors { defer tl.mu.Unlock()
for re, count := range tl.errors {
if count > 0 { if count > 0 {
t.Errorf("Expected error '%v' not encountered", re.String()) t.Errorf("Expected error '%v' not encountered", re.String())
} }
} }
g.errors = map[*regexp.Regexp]int{} tl.errors = map[*regexp.Regexp]int{}
} }
// expected determines if the error string is protected or not. // expected determines if the error string is protected or not.
func (g *tLogger) expected(s string) bool { func (tl *tLogger) expected(s string) bool {
for re, count := range g.errors { for re, count := range tl.errors {
if re.FindStringIndex(s) != nil { if re.FindStringIndex(s) != nil {
g.errors[re]-- tl.errors[re]--
if count <= 1 { if count <= 1 {
delete(g.errors, re) delete(tl.errors, re)
} }
return true return true
} }
@ -194,70 +204,70 @@ func (g *tLogger) expected(s string) bool {
return false return false
} }
func (g *tLogger) Info(args ...any) { func (tl *tLogger) Info(args ...any) {
g.log(infoLog, 0, "", args...) tl.log(infoLog, 0, "", args...)
} }
func (g *tLogger) Infoln(args ...any) { func (tl *tLogger) Infoln(args ...any) {
g.log(infoLog, 0, "", args...) tl.log(infoLog, 0, "", args...)
} }
func (g *tLogger) Infof(format string, args ...any) { func (tl *tLogger) Infof(format string, args ...any) {
g.log(infoLog, 0, format, args...) tl.log(infoLog, 0, format, args...)
} }
func (g *tLogger) InfoDepth(depth int, args ...any) { func (tl *tLogger) InfoDepth(depth int, args ...any) {
g.log(infoLog, depth, "", args...) tl.log(infoLog, depth, "", args...)
} }
func (g *tLogger) Warning(args ...any) { func (tl *tLogger) Warning(args ...any) {
g.log(warningLog, 0, "", args...) tl.log(warningLog, 0, "", args...)
} }
func (g *tLogger) Warningln(args ...any) { func (tl *tLogger) Warningln(args ...any) {
g.log(warningLog, 0, "", args...) tl.log(warningLog, 0, "", args...)
} }
func (g *tLogger) Warningf(format string, args ...any) { func (tl *tLogger) Warningf(format string, args ...any) {
g.log(warningLog, 0, format, args...) tl.log(warningLog, 0, format, args...)
} }
func (g *tLogger) WarningDepth(depth int, args ...any) { func (tl *tLogger) WarningDepth(depth int, args ...any) {
g.log(warningLog, depth, "", args...) tl.log(warningLog, depth, "", args...)
} }
func (g *tLogger) Error(args ...any) { func (tl *tLogger) Error(args ...any) {
g.log(errorLog, 0, "", args...) tl.log(errorLog, 0, "", args...)
} }
func (g *tLogger) Errorln(args ...any) { func (tl *tLogger) Errorln(args ...any) {
g.log(errorLog, 0, "", args...) tl.log(errorLog, 0, "", args...)
} }
func (g *tLogger) Errorf(format string, args ...any) { func (tl *tLogger) Errorf(format string, args ...any) {
g.log(errorLog, 0, format, args...) tl.log(errorLog, 0, format, args...)
} }
func (g *tLogger) ErrorDepth(depth int, args ...any) { func (tl *tLogger) ErrorDepth(depth int, args ...any) {
g.log(errorLog, depth, "", args...) tl.log(errorLog, depth, "", args...)
} }
func (g *tLogger) Fatal(args ...any) { func (tl *tLogger) Fatal(args ...any) {
g.log(fatalLog, 0, "", args...) tl.log(fatalLog, 0, "", args...)
} }
func (g *tLogger) Fatalln(args ...any) { func (tl *tLogger) Fatalln(args ...any) {
g.log(fatalLog, 0, "", args...) tl.log(fatalLog, 0, "", args...)
} }
func (g *tLogger) Fatalf(format string, args ...any) { func (tl *tLogger) Fatalf(format string, args ...any) {
g.log(fatalLog, 0, format, args...) tl.log(fatalLog, 0, format, args...)
} }
func (g *tLogger) FatalDepth(depth int, args ...any) { func (tl *tLogger) FatalDepth(depth int, args ...any) {
g.log(fatalLog, depth, "", args...) tl.log(fatalLog, depth, "", args...)
} }
func (g *tLogger) V(l int) bool { func (tl *tLogger) V(l int) bool {
return l <= g.v return l <= tl.v
} }

View File

@ -19,6 +19,7 @@
package grpctest package grpctest
import ( import (
"regexp"
"testing" "testing"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
@ -66,10 +67,10 @@ func (s) TestWarningDepth(*testing.T) {
func (s) TestError(*testing.T) { func (s) TestError(*testing.T) {
const numErrors = 10 const numErrors = 10
TLogger.ExpectError("Expected error") ExpectError("Expected error")
TLogger.ExpectError("Expected ln error") ExpectError("Expected ln error")
TLogger.ExpectError("Expected formatted error") ExpectError("Expected formatted error")
TLogger.ExpectErrorN("Expected repeated error", numErrors) ExpectErrorN("Expected repeated error", numErrors)
grpclog.Error("Expected", "error") grpclog.Error("Expected", "error")
grpclog.Errorln("Expected", "ln", "error") grpclog.Errorln("Expected", "ln", "error")
grpclog.Errorf("%v %v %v", "Expected", "formatted", "error") grpclog.Errorf("%v %v %v", "Expected", "formatted", "error")
@ -77,3 +78,54 @@ func (s) TestError(*testing.T) {
grpclog.Error("Expected repeated error") grpclog.Error("Expected repeated error")
} }
} }
func (s) TestInit(t *testing.T) {
// Reset the atomic value
tLoggerAtomic.Store(&tLogger{errors: map[*regexp.Regexp]int{}})
// Test initial state
tl := getLogger()
if tl == nil {
t.Fatal("getLogger() returned nil")
}
if tl.errors == nil {
t.Error("tl.errors is nil")
}
if len(tl.errors) != 0 {
t.Errorf("tl.errors = %v; want empty map", tl.errors)
}
if tl.initialized {
t.Error("tl.initialized = true; want false")
}
if tl.t != nil {
t.Error("tl.t is not nil")
}
if !tl.start.IsZero() {
t.Error("tl.start is not zero")
}
}
func (s) TestGetLogger(t *testing.T) {
// Save original logger
origLogger := getLogger()
defer tLoggerAtomic.Store(origLogger)
// Create new logger
newLogger := &tLogger{errors: map[*regexp.Regexp]int{}}
tLoggerAtomic.Store(newLogger)
// Verify new logger is retrieved
retrievedLogger := getLogger()
if retrievedLogger != newLogger {
t.Error("getLogger() did not return the newly stored logger")
}
// Restore original logger
tLoggerAtomic.Store(origLogger)
// Verify original logger is retrieved
retrievedLogger = getLogger()
if retrievedLogger != origLogger {
t.Error("getLogger() did not return the original logger after restore")
}
}

View File

@ -398,7 +398,7 @@ func (s) TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) {
// explicitly makes sure the fix works and the client sends a ping every [Time] // explicitly makes sure the fix works and the client sends a ping every [Time]
// period. // period.
func (s) TestKeepaliveClientFrequency(t *testing.T) { func (s) TestKeepaliveClientFrequency(t *testing.T) {
grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"")
serverConfig := &ServerConfig{ serverConfig := &ServerConfig{
KeepalivePolicy: keepalive.EnforcementPolicy{ KeepalivePolicy: keepalive.EnforcementPolicy{
@ -430,7 +430,7 @@ func (s) TestKeepaliveClientFrequency(t *testing.T) {
// (when there are no active streams), based on the configured // (when there are no active streams), based on the configured
// EnforcementPolicy. // EnforcementPolicy.
func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) {
grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"")
serverConfig := &ServerConfig{ serverConfig := &ServerConfig{
KeepalivePolicy: keepalive.EnforcementPolicy{ KeepalivePolicy: keepalive.EnforcementPolicy{
@ -461,7 +461,7 @@ func (s) TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) {
// (even when there is an active stream), based on the configured // (even when there is an active stream), based on the configured
// EnforcementPolicy. // EnforcementPolicy.
func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { func (s) TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) {
grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"")
serverConfig := &ServerConfig{ serverConfig := &ServerConfig{
KeepalivePolicy: keepalive.EnforcementPolicy{ KeepalivePolicy: keepalive.EnforcementPolicy{

View File

@ -266,7 +266,7 @@ func (f *fakeORCAService) StreamCoreMetrics(req *v3orcaservicepb.OrcaLoadReportR
// TestProducerBackoff verifies that the ORCA producer applies the proper // TestProducerBackoff verifies that the ORCA producer applies the proper
// backoff after stream failures. // backoff after stream failures.
func (s) TestProducerBackoff(t *testing.T) { func (s) TestProducerBackoff(t *testing.T) {
grpctest.TLogger.ExpectErrorN("injected error", 4) grpctest.ExpectErrorN("injected error", 4)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel() defer cancel()

View File

@ -5414,7 +5414,7 @@ func (s) TestStatusInvalidUTF8Message(t *testing.T) {
// will fail to marshal the status because of the invalid utf8 message. Details // will fail to marshal the status because of the invalid utf8 message. Details
// will be dropped when sending. // will be dropped when sending.
func (s) TestStatusInvalidUTF8Details(t *testing.T) { func (s) TestStatusInvalidUTF8Details(t *testing.T) {
grpctest.TLogger.ExpectError("Failed to marshal rpc status") grpctest.ExpectError("Failed to marshal rpc status")
var ( var (
origMsg = string([]byte{0xff, 0xfe, 0xfd}) origMsg = string([]byte{0xff, 0xfe, 0xfd})
@ -6425,7 +6425,7 @@ func (s) TestServerClosesConn(t *testing.T) {
// TestNilStatsHandler ensures we do not panic as a result of a nil stats // TestNilStatsHandler ensures we do not panic as a result of a nil stats
// handler. // handler.
func (s) TestNilStatsHandler(t *testing.T) { func (s) TestNilStatsHandler(t *testing.T) {
grpctest.TLogger.ExpectErrorN("ignoring nil parameter", 2) grpctest.ExpectErrorN("ignoring nil parameter", 2)
ss := &stubserver.StubServer{ ss := &stubserver.StubServer{
UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { UnaryCallF: func(context.Context, *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil return &testpb.SimpleResponse{}, nil

View File

@ -121,7 +121,7 @@ func (s) TestDetailedGoAwayErrorOnGracefulClosePropagatesToRPCError(t *testing.T
} }
func (s) TestDetailedGoAwayErrorOnAbruptClosePropagatesToRPCError(t *testing.T) { func (s) TestDetailedGoAwayErrorOnAbruptClosePropagatesToRPCError(t *testing.T) {
grpctest.TLogger.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"") grpctest.ExpectError("Client received GoAway with error code ENHANCE_YOUR_CALM and debug data equal to ASCII \"too_many_pings\"")
// set the min keepalive time very low so that this test can take // set the min keepalive time very low so that this test can take
// a reasonable amount of time // a reasonable amount of time
prev := internal.KeepaliveMinPingTime prev := internal.KeepaliveMinPingTime

View File

@ -254,7 +254,7 @@ func (s) TestHealthCheckWatchStateChange(t *testing.T) {
// If Watch returns Unimplemented, then the ClientConn should go into READY state. // If Watch returns Unimplemented, then the ClientConn should go into READY state.
func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) { func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) {
grpctest.TLogger.ExpectError("Subchannel health check is unimplemented at server side, thus health check is disabled") grpctest.ExpectError("Subchannel health check is unimplemented at server side, thus health check is disabled")
s := grpc.NewServer() s := grpc.NewServer()
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {

View File

@ -37,7 +37,7 @@ import (
) )
func (s) TestInvalidMetadata(t *testing.T) { func (s) TestInvalidMetadata(t *testing.T) {
grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 5) grpctest.ExpectErrorN("stream: failed to validate md when setting trailer", 5)
tests := []struct { tests := []struct {
name string name string