From 61a6a06b8879354998373f8bef56312ea07d6719 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Wed, 26 Jan 2022 11:02:23 -0800 Subject: [PATCH] server: handle context errors returned by service handler (#5156) --- interceptor.go | 9 ++++++--- server.go | 11 +++++++---- stream.go | 10 ++++++---- test/server_test.go | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 54 insertions(+), 11 deletions(-) diff --git a/interceptor.go b/interceptor.go index 668e0adcf..bb96ef57b 100644 --- a/interceptor.go +++ b/interceptor.go @@ -72,9 +72,12 @@ type UnaryServerInfo struct { } // UnaryHandler defines the handler invoked by UnaryServerInterceptor to complete the normal -// execution of a unary RPC. If a UnaryHandler returns an error, it should be produced by the -// status package, or else gRPC will use codes.Unknown as the status code and err.Error() as -// the status message of the RPC. +// execution of a unary RPC. +// +// If a UnaryHandler returns an error, it should either be produced by the +// status package, or be one of the context errors. Otherwise, gRPC will use +// codes.Unknown as the status code and err.Error() as the status message of the +// RPC. type UnaryHandler func(ctx context.Context, req interface{}) (interface{}, error) // UnaryServerInterceptor provides a hook to intercept the execution of a unary RPC on the server. info diff --git a/server.go b/server.go index eadf9e05f..b24b6d539 100644 --- a/server.go +++ b/server.go @@ -1283,9 +1283,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if appErr != nil { appStatus, ok := status.FromError(appErr) if !ok { - // Convert appErr if it is not a grpc status error. - appErr = status.Error(codes.Unknown, appErr.Error()) - appStatus, _ = status.FromError(appErr) + // Convert non-status application error to a status error with code + // Unknown, but handle context errors specifically. + appStatus = status.FromContextError(appErr) + appErr = appStatus.Err() } if trInfo != nil { trInfo.tr.LazyLog(stringer(appStatus.Message()), true) @@ -1549,7 +1550,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if appErr != nil { appStatus, ok := status.FromError(appErr) if !ok { - appStatus = status.New(codes.Unknown, appErr.Error()) + // Convert non-status application error to a status error with code + // Unknown, but handle context errors specifically. + appStatus = status.FromContextError(appErr) appErr = appStatus.Err() } if trInfo != nil { diff --git a/stream.go b/stream.go index 625d47b34..8cdd652e0 100644 --- a/stream.go +++ b/stream.go @@ -46,10 +46,12 @@ import ( ) // StreamHandler defines the handler called by gRPC server to complete the -// execution of a streaming RPC. If a StreamHandler returns an error, it -// should be produced by the status package, or else gRPC will use -// codes.Unknown as the status code and err.Error() as the status message -// of the RPC. +// execution of a streaming RPC. +// +// If a StreamHandler returns an error, it should either be produced by the +// status package, or be one of the context errors. Otherwise, gRPC will use +// codes.Unknown as the status code and err.Error() as the status message of the +// RPC. type StreamHandler func(srv interface{}, stream ServerStream) error // StreamDesc represents a streaming RPC service's method specification. Used diff --git a/test/server_test.go b/test/server_test.go index 97f352328..411e0aa3c 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -32,6 +32,41 @@ import ( type ctxKey string +// TestServerReturningContextError verifies that if a context error is returned +// by the service handler, the status will have the correct status code, not +// Unknown. +func (s) TestServerReturningContextError(t *testing.T) { + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return nil, context.DeadlineExceeded + }, + FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { + return context.DeadlineExceeded + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}) + if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded { + t.Fatalf("ss.Client.EmptyCall() got error %v; want ", err) + } + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("unexpected error starting the stream: %v", err) + } + _, err = stream.Recv() + if s, ok := status.FromError(err); !ok || s.Code() != codes.DeadlineExceeded { + t.Fatalf("ss.Client.FullDuplexCall().Recv() got error %v; want ", err) + } + +} + func (s) TestChainUnaryServerInterceptor(t *testing.T) { var ( firstIntKey = ctxKey("firstIntKey")