diff --git a/status/status.go b/status/status.go index 3a42dc6de..95420ba9b 100644 --- a/status/status.go +++ b/status/status.go @@ -46,7 +46,7 @@ func (se *statusError) Error() string { return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(p.GetCode()), p.GetMessage()) } -func (se *statusError) status() *Status { +func (se *statusError) Status() *Status { return &Status{s: (*spb.Status)(se)} } @@ -120,14 +120,14 @@ func FromProto(s *spb.Status) *Status { } // FromError returns a Status representing err if it was produced from this -// package. Otherwise, ok is false and a Status is returned with codes.Unknown -// and the original error message. +// package or has a method Status() *Status. Otherwise, ok is false and a +// Status is returned with codes.Unknown and the original error message. func FromError(err error) (s *Status, ok bool) { if err == nil { return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true } - if se, ok := err.(*statusError); ok { - return se.status(), true + if se, ok := err.(interface{ Status() *Status }); ok { + return se.Status(), true } return New(codes.Unknown, err.Error()), false } @@ -182,8 +182,8 @@ func Code(err error) codes.Code { if err == nil { return codes.OK } - if se, ok := err.(*statusError); ok { - return se.status().Code() + if se, ok := err.(interface{ Status() *Status }); ok { + return se.Status().Code() } return codes.Unknown } diff --git a/status/status_test.go b/status/status_test.go index 8b74c27d6..b9db5372e 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -24,6 +24,8 @@ import ( "reflect" "testing" + "github.com/golang/protobuf/ptypes/any" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" apb "github.com/golang/protobuf/ptypes/any" @@ -119,6 +121,47 @@ func TestFromErrorOK(t *testing.T) { } } +type customError struct { + Code codes.Code + Message string + Details []*any.Any +} + +func (c customError) Error() string { + return fmt.Sprintf("rpc error: code = %s desc = %s", c.Code, c.Message) +} + +func (c customError) Status() *Status { + return &Status{ + s: &spb.Status{ + Code: int32(c.Code), + Message: c.Message, + Details: c.Details, + }, + } +} + +func TestFromErrorImplementsInterface(t *testing.T) { + code, message := codes.Internal, "test description" + details := []*any.Any{{ + TypeUrl: "testUrl", + Value: []byte("testValue"), + }} + err := customError{ + Code: code, + Message: message, + Details: details, + } + s, ok := FromError(err) + if !ok || s.Code() != code || s.Message() != message || s.Err() == nil { + t.Fatalf("FromError(%v) = %v, %v; want , true", err, s, ok, code, message) + } + pd := s.Proto().GetDetails() + if len(pd) != 1 || !reflect.DeepEqual(pd[0], details[0]) { + t.Fatalf("s.Proto.GetDetails() = %v; want ", pd, details) + } +} + func TestFromErrorUnknownError(t *testing.T) { code, message := codes.Unknown, "unknown error" err := errors.New("unknown error")