diff --git a/internal/internal.go b/internal/internal.go index 64cdbe50c..386527ba2 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -40,6 +40,11 @@ var ( // ParseServiceConfig is a function to parse JSON service configs into // opaque data structures. ParseServiceConfig func(sc string) (interface{}, error) + // StatusRawProto is exported by status/status.go. This func returns a + // pointer to the wrapped Status proto for a given status.Status without a + // call to proto.Clone(). The returned Status proto should not be mutated by + // the caller. + StatusRawProto interface{} // func (*status.Status) *spb.Status ) // HealthChecker defines the signature of the client-side LB channel health checking function. diff --git a/internal/testutils/status_equal.go b/internal/testutils/status_equal.go new file mode 100644 index 000000000..dfd647336 --- /dev/null +++ b/internal/testutils/status_equal.go @@ -0,0 +1,38 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testutils + +import ( + "github.com/golang/protobuf/proto" + "google.golang.org/grpc/status" +) + +// StatusErrEqual returns true iff both err1 and err2 wrap status.Status errors +// and their underlying status protos are equal. +func StatusErrEqual(err1, err2 error) bool { + status1, ok := status.FromError(err1) + if !ok { + return false + } + status2, ok := status.FromError(err2) + if !ok { + return false + } + return proto.Equal(status1.Proto(), status2.Proto()) +} diff --git a/internal/testutils/status_equal_test.go b/internal/testutils/status_equal_test.go new file mode 100644 index 000000000..b3b412ec9 --- /dev/null +++ b/internal/testutils/status_equal_test.go @@ -0,0 +1,57 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testutils + +import ( + "testing" + + anypb "github.com/golang/protobuf/ptypes/any" + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var statusErr = status.ErrorProto(&spb.Status{ + Code: int32(codes.DataLoss), + Message: "error for testing", + Details: []*anypb.Any{{ + TypeUrl: "url", + Value: []byte{6, 0, 0, 6, 1, 3}, + }}, +}) + +func TestStatusErrEqual(t *testing.T) { + tests := []struct { + name string + err1 error + err2 error + wantEqual bool + }{ + {"nil errors", nil, nil, true}, + {"equal OK status", status.New(codes.OK, "").Err(), status.New(codes.OK, "").Err(), true}, + {"equal status errors", statusErr, statusErr, true}, + {"different status errors", statusErr, status.New(codes.OK, "").Err(), false}, + } + + for _, test := range tests { + if gotEqual := StatusErrEqual(test.err1, test.err2); gotEqual != test.wantEqual { + t.Errorf("%v: StatusErrEqual(%v, %v) = %v, want %v", test.name, test.err1, test.err2, gotEqual, test.wantEqual) + } + } +} diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 435092e5c..0756a6b52 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -35,9 +35,11 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/keepalive" @@ -55,6 +57,9 @@ var ( // ErrHeaderListSizeLimitViolation indicates that the header list size is larger // than the limit set by peer. ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") + // statusRawProto is a function to get to the raw status proto wrapped in a + // status.Status without a proto.Clone(). + statusRawProto = internal.StatusRawProto.(func(*status.Status) *spb.Status) ) // http2Server implements the ServerTransport interface with HTTP2. @@ -817,7 +822,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) - if p := st.Proto(); p != nil && len(p.Details) > 0 { + if p := statusRawProto(st); p != nil && len(p.Details) > 0 { stBytes, err := proto.Marshal(p) if err != nil { // TODO: return error instead, when callers are able to handle it. diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index ba4174d3c..8f58b393f 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -27,7 +27,6 @@ import ( "io" "math" "net" - "reflect" "runtime" "strconv" "strings" @@ -40,6 +39,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/syscall" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" ) @@ -1690,7 +1690,7 @@ func TestEncodingRequiredStatus(t *testing.T) { if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF { t.Fatalf("Read got error %v, want %v", err, io.EOF) } - if !reflect.DeepEqual(s.Status(), encodingTestStatus) { + if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) { t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) } ct.Close() diff --git a/rpc_util_test.go b/rpc_util_test.go index ee641f51a..2449c2381 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/encoding" protoenc "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/status" perfpb "google.golang.org/grpc/test/codec_perf" @@ -182,10 +183,10 @@ func (s) TestToRPCErr(t *testing.T) { } { err := toRPCErr(test.errIn) if _, ok := status.FromError(err); !ok { - t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, "")) + t.Errorf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error) } - if !reflect.DeepEqual(err, test.errOut) { - t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) + if !testutils.StatusErrEqual(err, test.errOut) { + t.Errorf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) } } } diff --git a/status/status.go b/status/status.go index ed36681bb..641c45c6f 100644 --- a/status/status.go +++ b/status/status.go @@ -36,8 +36,15 @@ import ( "github.com/golang/protobuf/ptypes" spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal" ) +func init() { + internal.StatusRawProto = statusRawProto +} + +func statusRawProto(s *Status) *spb.Status { return s.s } + // statusError is an alias of a status proto. It implements error and Status, // and a nil statusError should never be returned by this package. type statusError spb.Status diff --git a/status/status_test.go b/status/status_test.go index 011cb0aea..b7db2d3d8 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -35,11 +35,25 @@ import ( "google.golang.org/grpc/codes" ) +// errEqual is essentially a copy of testutils.StatusErrEqual(), to avoid a +// cyclic dependency. +func errEqual(err1, err2 error) bool { + status1, ok := FromError(err1) + if !ok { + return false + } + status2, ok := FromError(err2) + if !ok { + return false + } + return proto.Equal(status1.Proto(), status2.Proto()) +} + func TestErrorsWithSameParameters(t *testing.T) { const description = "some description" e1 := Errorf(codes.AlreadyExists, description) e2 := Errorf(codes.AlreadyExists, description) - if e1 == e2 || !reflect.DeepEqual(e1, e2) { + if e1 == e2 || !errEqual(e1, e2) { t.Fatalf("Errors should be equivalent but unique - e1: %v, %v e2: %p, %v", e1.(*statusError), e1, e2.(*statusError), e2) } } @@ -156,7 +170,7 @@ func TestFromErrorImplementsInterface(t *testing.T) { t.Fatalf("FromError(%v) = %v, %v; want , true", err, s, ok, code, message) } pd := s.Proto().GetDetails() - if len(pd) != 1 || !reflect.DeepEqual(pd[0], details[0]) { + if len(pd) != 1 || !proto.Equal(pd[0], details[0]) { t.Fatalf("s.Proto.GetDetails() = %v; want ", pd, details) } } diff --git a/test/balancer_test.go b/test/balancer_test.go index b01dd2b73..f877bfc00 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal/balancerload" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" testpb "google.golang.org/grpc/test/grpc_testing" @@ -162,14 +163,14 @@ func testDoneInfo(t *testing.T, e env) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() wantErr := detailedError - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } - if len(b.doneInfo) < 1 || !reflect.DeepEqual(b.doneInfo[0].Err, wantErr) { + if len(b.doneInfo) < 1 || !testutils.StatusErrEqual(b.doneInfo[0].Err, wantErr) { t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr) } if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) { diff --git a/test/end2end_test.go b/test/end2end_test.go index 70e60870d..a7f0fd470 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2496,7 +2496,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) { cc := te.clientConn() wantErr := status.Error(codes.DeadlineExceeded, "context deadline exceeded") - if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) { + if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !testutils.StatusErrEqual(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.DeadlineExceeded) } awaitNewConnLogOutput() @@ -2517,7 +2517,7 @@ func testHealthCheckOff(t *testing.T, e env) { te.startServer(&testServer{security: e.security}) defer te.tearDown() want := status.Error(codes.Unimplemented, "unknown service grpc.health.v1.Health") - if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) { + if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !testutils.StatusErrEqual(err, want) { t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) } } @@ -2791,7 +2791,7 @@ func testUnknownHandler(t *testing.T, e env, unknownHandler grpc.StreamHandler) te.startServer(&testServer{security: e.security}) defer te.tearDown() want := status.Error(codes.Unauthenticated, "user unauthenticated") - if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) { + if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !testutils.StatusErrEqual(err, want) { t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) } } @@ -2818,7 +2818,7 @@ func testHealthCheckServingStatus(t *testing.T, e env) { t.Fatalf("Got the serving status %v, want SERVING", out.Status) } wantErr := status.Error(codes.NotFound, "unknown service") - if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) { + if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !testutils.StatusErrEqual(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.NotFound) } hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING) @@ -2886,7 +2886,7 @@ func testFailedEmptyUnary(t *testing.T, e env) { ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) wantErr := detailedError - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } }