Remove the 2nd RecvProto in the generated code for client streaming

This commit is contained in:
iamqizhao 2015-02-18 22:15:13 -08:00
parent 43c0bbeb1f
commit 634392a1c6
8 changed files with 96 additions and 76 deletions

View File

@ -38,9 +38,9 @@ import (
"sync" "sync"
"time" "time"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
"golang.org/x/net/context"
) )
type dialOptions struct { type dialOptions struct {

View File

@ -151,7 +151,8 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer,
const google::protobuf::MethodDescriptor* method, const google::protobuf::MethodDescriptor* method,
map<string, string>* vars, map<string, string>* vars,
const set<string>& imports, const set<string>& imports,
const map<string, string>& import_alias) { const map<string, string>& import_alias,
int* stream_ind) {
(*vars)["Method"] = method->name(); (*vars)["Method"] = method->name();
(*vars)["Request"] = (*vars)["Request"] =
GetFullMessageQualifiedName(method->input_type(), imports, import_alias); GetFullMessageQualifiedName(method->input_type(), imports, import_alias);
@ -171,12 +172,15 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer,
printer->Print("\t}\n"); printer->Print("\t}\n");
printer->Print("\treturn out, nil\n"); printer->Print("\treturn out, nil\n");
printer->Print("}\n\n"); printer->Print("}\n\n");
} else if (BidiStreaming(method)) { return;
}
(*vars)["StreamInd"] = std::to_string(*stream_ind);
if (BidiStreaming(method)) {
printer->Print( printer->Print(
*vars, *vars,
"func (c *$ServiceStruct$Client) $Method$(ctx context.Context, opts " "func (c *$ServiceStruct$Client) $Method$(ctx context.Context, opts "
"...grpc.CallOption) ($Service$_$Method$Client, error) {\n" "...grpc.CallOption) ($Service$_$Method$Client, error) {\n"
"\tstream, err := grpc.NewClientStream(ctx, c.cc, " "\tstream, err := grpc.NewClientStream(ctx, &_$Service$_serviceDesc.Streams[$StreamInd$], c.cc, "
"\"/$Package$$Service$/$Method$\", opts...)\n" "\"/$Package$$Service$/$Method$\", opts...)\n"
"\tif err != nil {\n" "\tif err != nil {\n"
"\t\treturn nil, err\n" "\t\treturn nil, err\n"
@ -214,7 +218,7 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer,
"func (c *$ServiceStruct$Client) $Method$(ctx context.Context, m " "func (c *$ServiceStruct$Client) $Method$(ctx context.Context, m "
"*$Request$, " "*$Request$, "
"opts ...grpc.CallOption) ($Service$_$Method$Client, error) {\n" "opts ...grpc.CallOption) ($Service$_$Method$Client, error) {\n"
"\tstream, err := grpc.NewClientStream(ctx, c.cc, " "\tstream, err := grpc.NewClientStream(ctx, &_$Service$_serviceDesc.Streams[$StreamInd$], c.cc, "
"\"/$Package$$Service$/$Method$\", opts...)\n" "\"/$Package$$Service$/$Method$\", opts...)\n"
"\tif err != nil {\n" "\tif err != nil {\n"
"\t\treturn nil, err\n" "\t\treturn nil, err\n"
@ -252,7 +256,7 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer,
*vars, *vars,
"func (c *$ServiceStruct$Client) $Method$(ctx context.Context, opts " "func (c *$ServiceStruct$Client) $Method$(ctx context.Context, opts "
"...grpc.CallOption) ($Service$_$Method$Client, error) {\n" "...grpc.CallOption) ($Service$_$Method$Client, error) {\n"
"\tstream, err := grpc.NewClientStream(ctx, c.cc, " "\tstream, err := grpc.NewClientStream(ctx, &_$Service$_serviceDesc.Streams[$StreamInd$], c.cc, "
"\"/$Package$$Service$/$Method$\", opts...)\n" "\"/$Package$$Service$/$Method$\", opts...)\n"
"\tif err != nil {\n" "\tif err != nil {\n"
"\t\treturn nil, err\n" "\t\treturn nil, err\n"
@ -282,18 +286,13 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer,
"\t\treturn nil, err\n" "\t\treturn nil, err\n"
"\t}\n" "\t}\n"
"\tm := new($Response$)\n" "\tm := new($Response$)\n"
"\tif err := x.ClientStream.RecvProto(m); err != nil {\n" "\tif err := x.ClientStream.RecvProto(m); err != io.EOF {\n"
"\t\treturn nil, err\n" "\t\treturn nil, err\n"
"\t}\n" "\t}\n"
"\t// Read EOF.\n" "\treturn m, nil\n"
"\tif err := x.ClientStream.RecvProto(m); err == io.EOF {\n"
"\t\treturn m, nil\n"
"\t}\n"
"\t// gRPC protocol violation.\n"
"\treturn m, fmt.Errorf(\"Violate gRPC client streaming protocol: no "
"EOF after the response.\")\n"
"}\n\n"); "}\n\n");
} }
(*stream_ind)++;
} }
void PrintClient(google::protobuf::io::Printer* printer, void PrintClient(google::protobuf::io::Printer* printer,
@ -318,8 +317,10 @@ void PrintClient(google::protobuf::io::Printer* printer,
"func New$Service$Client(cc *grpc.ClientConn) $Service$Client {\n" "func New$Service$Client(cc *grpc.ClientConn) $Service$Client {\n"
"\treturn &$ServiceStruct$Client{cc}\n" "\treturn &$ServiceStruct$Client{cc}\n"
"}\n\n"); "}\n\n");
int stream_ind = 0;
for (int i = 0; i < service->method_count(); ++i) { for (int i = 0; i < service->method_count(); ++i) {
PrintClientMethodImpl(printer, service->method(i), vars, imports, import_alias); PrintClientMethodImpl(
printer, service->method(i), vars, imports, import_alias, &stream_ind);
} }
} }
@ -489,6 +490,12 @@ void PrintServerStreamingMethodDesc(
printer->Print("\t\t{\n"); printer->Print("\t\t{\n");
printer->Print(*vars, "\t\t\tStreamName:\t\"$Method$\",\n"); printer->Print(*vars, "\t\t\tStreamName:\t\"$Method$\",\n");
printer->Print(*vars, "\t\t\tHandler:\t_$Service$_$Method$_Handler,\n"); printer->Print(*vars, "\t\t\tHandler:\t_$Service$_$Method$_Handler,\n");
if (method->client_streaming()) {
printer->Print(*vars, "\t\t\tClientStreams:\ttrue,\n");
}
if (method->server_streaming()) {
printer->Print(*vars, "\t\t\tServerStreams:\ttrue,\n");
}
printer->Print("\t\t},\n"); printer->Print("\t\t},\n");
} }
@ -505,7 +512,7 @@ void PrintServer(google::protobuf::io::Printer* printer,
printer->Print("}\n\n"); printer->Print("}\n\n");
printer->Print(*vars, printer->Print(*vars,
"func RegisterService(s *grpc.Server, srv $Service$Server) {\n" "func Register$Service$Server(s *grpc.Server, srv $Service$Server) {\n"
"\ts.RegisterService(&_$Service$_serviceDesc, srv)\n" "\ts.RegisterService(&_$Service$_serviceDesc, srv)\n"
"}\n\n"); "}\n\n");
@ -613,7 +620,6 @@ string GetServices(const google::protobuf::FileDescriptor* file,
printer.Print("import (\n"); printer.Print("import (\n");
if (HasClientOnlyStreaming(file)) { if (HasClientOnlyStreaming(file)) {
printer.Print( printer.Print(
"\t\"fmt\"\n"
"\t\"io\"\n"); "\t\"io\"\n");
} }
printer.Print( printer.Print(

View File

@ -59,9 +59,9 @@ import math "math"
import ( import (
errors "errors" errors "errors"
io "io"
context "golang.org/x/net/context" context "golang.org/x/net/context"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
io "io"
) )
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.
@ -430,7 +430,7 @@ func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, op
} }
func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) { func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -462,7 +462,7 @@ func (x *testServiceStreamingOutputCallClient) Recv() (*StreamingOutputCallRespo
} }
func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) { func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -489,20 +489,14 @@ func (x *testServiceStreamingInputCallClient) CloseAndRecv() (*StreamingInputCal
return nil, err return nil, err
} }
m := new(StreamingInputCallResponse) m := new(StreamingInputCallResponse)
if err := x.ClientStream.RecvProto(m); err != nil { if err := x.ClientStream.RecvProto(m); err != io.EOF {
return nil, err return nil, err
} }
// Read EOF.
dummy := new(StreamingInputCallResponse)
if err := x.ClientStream.RecvProto(dummy); err != io.EOF {
// gRPC protocol violation.
return nil, errors.New("gRPC client streaming protocol violation: no EOF after final response")
}
return m, nil return m, nil
} }
func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) { func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[2], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -533,7 +527,7 @@ func (x *testServiceFullDuplexCallClient) Recv() (*StreamingOutputCallResponse,
} }
func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) { func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[3], c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -732,18 +726,24 @@ var _TestService_serviceDesc = grpc.ServiceDesc{
{ {
StreamName: "StreamingOutputCall", StreamName: "StreamingOutputCall",
Handler: _TestService_StreamingOutputCall_Handler, Handler: _TestService_StreamingOutputCall_Handler,
ServerStreams: true,
}, },
{ {
StreamName: "StreamingInputCall", StreamName: "StreamingInputCall",
Handler: _TestService_StreamingInputCall_Handler, Handler: _TestService_StreamingInputCall_Handler,
ClientStreams: true,
}, },
{ {
StreamName: "FullDuplexCall", StreamName: "FullDuplexCall",
Handler: _TestService_FullDuplexCall_Handler, Handler: _TestService_FullDuplexCall_Handler,
ClientStreams: true,
ServerStreams: true,
}, },
{ {
StreamName: "HalfDuplexCall", StreamName: "HalfDuplexCall",
Handler: _TestService_HalfDuplexCall_Handler, Handler: _TestService_HalfDuplexCall_Handler,
ClientStreams: true,
ServerStreams: true,
}, },
}, },
} }

View File

@ -43,10 +43,10 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
"golang.org/x/net/context"
) )
// CallOption configures a Call before it starts or extracts information from // CallOption configures a Call before it starts or extracts information from

View File

@ -42,9 +42,9 @@ import (
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
"golang.org/x/net/context"
) )
func TestSimpleParsing(t *testing.T) { func TestSimpleParsing(t *testing.T) {

View File

@ -43,10 +43,10 @@ import (
"sync" "sync"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
"golang.org/x/net/context"
) )
type methodHandler func(srv interface{}, ctx context.Context, buf []byte) (proto.Message, error) type methodHandler func(srv interface{}, ctx context.Context, buf []byte) (proto.Message, error)
@ -57,14 +57,6 @@ type MethodDesc struct {
Handler methodHandler Handler methodHandler
} }
type streamHandler func(srv interface{}, stream ServerStream) error
// StreamDesc represents a streaming RPC service's method specification.
type StreamDesc struct {
StreamName string
Handler streamHandler
}
// ServiceDesc represents an RPC service's specification. // ServiceDesc represents an RPC service's specification.
type ServiceDesc struct { type ServiceDesc struct {
ServiceName string ServiceName string

View File

@ -34,6 +34,7 @@
package grpc package grpc
import ( import (
"fmt"
"io" "io"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -43,6 +44,18 @@ import (
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
type streamHandler func(srv interface{}, stream ServerStream) error
// StreamDesc represents a streaming RPC service's method specification.
type StreamDesc struct {
StreamName string
Handler streamHandler
// At least one of these is true.
ServerStreams bool
ClientStreams bool
}
// Stream defines the common interface a client or server stream has to satisfy. // Stream defines the common interface a client or server stream has to satisfy.
type Stream interface { type Stream interface {
// Context returns the context for this stream. // Context returns the context for this stream.
@ -80,7 +93,7 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called // NewClientStream creates a new Stream for the client side. This is called
// by generated code. // by generated code.
func NewClientStream(ctx context.Context, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
// TODO(zhaoq): CallOption is omitted. Add support when it is needed. // TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.target, Host: cc.target,
@ -98,6 +111,7 @@ func NewClientStream(ctx context.Context, cc *ClientConn, method string, opts ..
t: t, t: t,
s: s, s: s,
p: &parser{s: s}, p: &parser{s: s},
desc: desc,
}, nil }, nil
} }
@ -106,6 +120,7 @@ type clientStream struct {
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
desc *StreamDesc
} }
func (cs *clientStream) Context() context.Context { func (cs *clientStream) Context() context.Context {
@ -146,8 +161,15 @@ func (cs *clientStream) SendProto(m proto.Message) (err error) {
func (cs *clientStream) RecvProto(m proto.Message) (err error) { func (cs *clientStream) RecvProto(m proto.Message) (err error) {
err = recvProto(cs.p, m) err = recvProto(cs.p, m)
if err == nil { if err == nil {
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
return return
} }
// Special handling for client streaming rpc.
if err = recvProto(cs.p, m); err != io.EOF {
cs.t.CloseStream(cs.s, err)
return fmt.Errorf("gRPC client streaming protocol violation: %v, want <EOF>", err)
}
}
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
cs.t.CloseStream(cs.s, err) cs.t.CloseStream(cs.s, err)
} }

View File

@ -59,9 +59,9 @@ import math "math"
import ( import (
errors "errors" errors "errors"
io "io"
context "golang.org/x/net/context" context "golang.org/x/net/context"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
io "io"
) )
// Reference imports to suppress errors if they are not otherwise used. // Reference imports to suppress errors if they are not otherwise used.
@ -430,7 +430,7 @@ func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, op
} }
func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) { func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -462,7 +462,7 @@ func (x *testServiceStreamingOutputCallClient) Recv() (*StreamingOutputCallRespo
} }
func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) { func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -489,20 +489,14 @@ func (x *testServiceStreamingInputCallClient) CloseAndRecv() (*StreamingInputCal
return nil, err return nil, err
} }
m := new(StreamingInputCallResponse) m := new(StreamingInputCallResponse)
if err := x.ClientStream.RecvProto(m); err != nil { if err := x.ClientStream.RecvProto(m); err != io.EOF {
return nil, err return nil, err
} }
// Read EOF.
dummy := new(StreamingInputCallResponse)
if err := x.ClientStream.RecvProto(dummy); err != io.EOF {
// gRPC protocol violation.
return nil, errors.New("gRPC client streaming protocol violation: no EOF after final response")
}
return m, nil return m, nil
} }
func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) { func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[2], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -533,7 +527,7 @@ func (x *testServiceFullDuplexCallClient) Recv() (*StreamingOutputCallResponse,
} }
func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) { func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) {
stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[3], c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -732,18 +726,24 @@ var _TestService_serviceDesc = grpc.ServiceDesc{
{ {
StreamName: "StreamingOutputCall", StreamName: "StreamingOutputCall",
Handler: _TestService_StreamingOutputCall_Handler, Handler: _TestService_StreamingOutputCall_Handler,
ServerStreams: true,
}, },
{ {
StreamName: "StreamingInputCall", StreamName: "StreamingInputCall",
Handler: _TestService_StreamingInputCall_Handler, Handler: _TestService_StreamingInputCall_Handler,
ClientStreams: true,
}, },
{ {
StreamName: "FullDuplexCall", StreamName: "FullDuplexCall",
Handler: _TestService_FullDuplexCall_Handler, Handler: _TestService_FullDuplexCall_Handler,
ClientStreams: true,
ServerStreams: true,
}, },
{ {
StreamName: "HalfDuplexCall", StreamName: "HalfDuplexCall",
Handler: _TestService_HalfDuplexCall_Handler, Handler: _TestService_HalfDuplexCall_Handler,
ClientStreams: true,
ServerStreams: true,
}, },
}, },
} }