diff --git a/call.go b/call.go index 504a6e18a..bf13b7d22 100644 --- a/call.go +++ b/call.go @@ -185,6 +185,6 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if lastErr != nil { return toRPCErr(lastErr) } - return Errorf(stream.StatusCode(), stream.StatusDesc()) + return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) } } diff --git a/call_test.go b/call_test.go index 7e7f743b7..feeeb7eff 100644 --- a/call_test.go +++ b/call_test.go @@ -52,6 +52,7 @@ import ( var ( expectedRequest = "ping" expectedResponse = "pong" + weirdError = "format verbs: %v%s" sizeLargeErr = 1024 * 1024 ) @@ -95,6 +96,10 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { t.Errorf("Failed to unmarshal the received message: %v", err) return } + if v == "weird error" { + h.t.WriteStatus(s, codes.Internal, weirdError) + return + } if v != expectedRequest { h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr)) return @@ -223,3 +228,19 @@ func TestInvokeLargeErr(t *testing.T) { cc.Close() server.stop() } + +// TestInvokeErrorSpecialChars checks that error messages don't get mangled. +func TestInvokeErrorSpecialChars(t *testing.T) { + server, cc := setUp(t, 0, math.MaxUint32) + var reply string + req := "weird error" + err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) + if _, ok := err.(rpcError); !ok { + t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") + } + if got, want := ErrorDesc(err), weirdError; got != want { + t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want) + } + cc.Close() + server.stop() +} diff --git a/stream.go b/stream.go index dba7f6c42..16d30b435 100644 --- a/stream.go +++ b/stream.go @@ -257,7 +257,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { cs.finish(err) return nil } - return Errorf(cs.s.StatusCode(), cs.s.StatusDesc()) + return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc()) } return toRPCErr(err) } @@ -269,7 +269,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { // Returns io.EOF to indicate the end of the stream. return } - return Errorf(cs.s.StatusCode(), cs.s.StatusDesc()) + return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc()) } return toRPCErr(err) }