mirror of https://github.com/grpc/grpc-go.git
client side, use user context, and change sent time
This commit is contained in:
parent
1896a21fb3
commit
aa5b5c7e2a
13
call.go
13
call.go
|
@ -50,7 +50,8 @@ import (
|
|||
// On error, it returns the error and indicates whether the call should be retried.
|
||||
//
|
||||
// TODO(zhaoq): Check whether the received message sequence is valid.
|
||||
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
|
||||
// TODO ctx is userCtx, not stream.Context. It is used for stats handling. Change this later if necessary.
|
||||
func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
|
||||
// Try to acquire header metadata from the server if there is any.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
|
@ -81,7 +82,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s
|
|||
if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
|
||||
// TODO in the current implementation, inTrailer may be handled before inStats in some cases.
|
||||
// Fix the order if necessary.
|
||||
stats.Handle(stream.Context(), inPayload)
|
||||
stats.Handle(ctx, inPayload)
|
||||
}
|
||||
c.trailerMD = stream.Trailer()
|
||||
return nil
|
||||
|
@ -117,10 +118,12 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
|||
if err != nil {
|
||||
return nil, Errorf(codes.Internal, "grpc: %v", err)
|
||||
}
|
||||
err = t.Write(stream, outBuf, opts)
|
||||
if outPayload != nil {
|
||||
outPayload.SentTime = time.Now()
|
||||
stats.Handle(stream.Context(), outPayload)
|
||||
}
|
||||
err = t.Write(stream, outBuf, opts)
|
||||
if outPayload != nil {
|
||||
stats.Handle(ctx, outPayload)
|
||||
}
|
||||
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
|
||||
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
|
||||
|
@ -247,7 +250,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||
}
|
||||
return toRPCErr(err)
|
||||
}
|
||||
err = recvResponse(cc.dopts, t, &c, stream, reply)
|
||||
err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
|
||||
if err != nil {
|
||||
if put != nil {
|
||||
put()
|
||||
|
|
17
server.go
17
server.go
|
@ -552,16 +552,16 @@ func (s *Server) removeConn(c io.Closer) {
|
|||
|
||||
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
|
||||
var (
|
||||
cbuf *bytes.Buffer
|
||||
outStats *stats.OutPayload
|
||||
cbuf *bytes.Buffer
|
||||
outPayload *stats.OutPayload
|
||||
)
|
||||
if cp != nil {
|
||||
cbuf = new(bytes.Buffer)
|
||||
}
|
||||
if stats.On() {
|
||||
outStats = &stats.OutPayload{}
|
||||
outPayload = &stats.OutPayload{}
|
||||
}
|
||||
p, err := encode(s.opts.codec, msg, cp, cbuf, outStats)
|
||||
p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
|
||||
if err != nil {
|
||||
// This typically indicates a fatal issue (e.g., memory
|
||||
// corruption or hardware faults) the application program
|
||||
|
@ -572,11 +572,12 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
|
|||
// the optimal option.
|
||||
grpclog.Fatalf("grpc: Server failed to encode response %v", err)
|
||||
}
|
||||
if outPayload != nil {
|
||||
outPayload.SentTime = time.Now()
|
||||
}
|
||||
err = t.Write(stream, p, opts)
|
||||
if outStats != nil {
|
||||
outStats.SentTime = time.Now()
|
||||
|
||||
stats.Handle(stream.Context(), outStats)
|
||||
if outPayload != nil {
|
||||
stats.Handle(stream.Context(), outPayload)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
42
stream.go
42
stream.go
|
@ -213,6 +213,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
|
||||
tracing: EnableTracing,
|
||||
trInfo: trInfo,
|
||||
|
||||
userCtx: ctx,
|
||||
}
|
||||
if cc.dopts.cp != nil {
|
||||
cs.cbuf = new(bytes.Buffer)
|
||||
|
@ -265,6 +267,10 @@ type clientStream struct {
|
|||
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
|
||||
// and is set to nil when the clientStream's finish method is called.
|
||||
trInfo traceInfo
|
||||
|
||||
// Keep the user context for stats handling.
|
||||
// All stats handling should use the user context instead of the stream context.
|
||||
userCtx context.Context
|
||||
}
|
||||
|
||||
func (cs *clientStream) Context() context.Context {
|
||||
|
@ -280,7 +286,7 @@ func (cs *clientStream) Header() (_ metadata.MD, err error) {
|
|||
EndTime: time.Now(),
|
||||
Error: err,
|
||||
}
|
||||
stats.Handle(cs.s.Context(), end)
|
||||
stats.Handle(cs.userCtx, end)
|
||||
}
|
||||
}()
|
||||
m, err := cs.s.Header()
|
||||
|
@ -311,7 +317,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||
Client: true,
|
||||
Error: err,
|
||||
}
|
||||
stats.Handle(cs.s.Context(), end)
|
||||
stats.Handle(cs.userCtx, end)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
|
@ -336,13 +342,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||
}
|
||||
err = toRPCErr(err)
|
||||
}()
|
||||
var outStats *stats.OutPayload
|
||||
var outPayload *stats.OutPayload
|
||||
if stats.On() {
|
||||
outStats = &stats.OutPayload{
|
||||
outPayload = &stats.OutPayload{
|
||||
Client: true,
|
||||
}
|
||||
}
|
||||
out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outStats)
|
||||
out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
|
||||
defer func() {
|
||||
if cs.cbuf != nil {
|
||||
cs.cbuf.Reset()
|
||||
|
@ -351,10 +357,12 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||
if err != nil {
|
||||
return Errorf(codes.Internal, "grpc: %v", err)
|
||||
}
|
||||
if outPayload != nil {
|
||||
outPayload.SentTime = time.Now()
|
||||
}
|
||||
err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
|
||||
if outStats != nil {
|
||||
outStats.SentTime = time.Now()
|
||||
stats.Handle(cs.s.Context(), outStats)
|
||||
if outPayload != nil {
|
||||
stats.Handle(cs.userCtx, outPayload)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -371,7 +379,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||
EndTime: time.Now(),
|
||||
Error: e,
|
||||
}
|
||||
stats.Handle(cs.s.Context(), end)
|
||||
stats.Handle(cs.userCtx, end)
|
||||
}
|
||||
}()
|
||||
var inStats *stats.InPayload
|
||||
|
@ -396,7 +404,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||
cs.mu.Unlock()
|
||||
}
|
||||
if inStats != nil {
|
||||
stats.Handle(cs.s.Context(), inStats)
|
||||
stats.Handle(cs.userCtx, inStats)
|
||||
}
|
||||
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
|
||||
return
|
||||
|
@ -557,11 +565,11 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||
ss.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
var outStats *stats.OutPayload
|
||||
var outPayload *stats.OutPayload
|
||||
if stats.On() {
|
||||
outStats = &stats.OutPayload{}
|
||||
outPayload = &stats.OutPayload{}
|
||||
}
|
||||
out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outStats)
|
||||
out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
|
||||
defer func() {
|
||||
if ss.cbuf != nil {
|
||||
ss.cbuf.Reset()
|
||||
|
@ -571,12 +579,14 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||
err = Errorf(codes.Internal, "grpc: %v", err)
|
||||
return err
|
||||
}
|
||||
if outPayload != nil {
|
||||
outPayload.SentTime = time.Now()
|
||||
}
|
||||
if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
if outStats != nil {
|
||||
outStats.SentTime = time.Now()
|
||||
stats.Handle(ss.s.Context(), outStats)
|
||||
if outPayload != nil {
|
||||
stats.Handle(ss.s.Context(), outPayload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -277,6 +277,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
if t.authInfo != nil {
|
||||
pr.AuthInfo = t.authInfo
|
||||
}
|
||||
userCtx := ctx
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
authData := make(map[string]string)
|
||||
for _, c := range t.creds {
|
||||
|
@ -348,6 +349,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
return nil, ErrConnClosing
|
||||
}
|
||||
s := t.newStream(ctx, callHdr)
|
||||
s.userCtx = userCtx
|
||||
t.activeStreams[s.id] = s
|
||||
|
||||
// This stream is not counted when applySetings(...) initialize t.streamsQuota.
|
||||
|
@ -459,7 +461,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
Encryption: callHdr.SendCompress,
|
||||
FailFast: callHdr.FailFast,
|
||||
}
|
||||
stats.Handle(s.Context(), outHeader)
|
||||
stats.Handle(s.userCtx, outHeader)
|
||||
}
|
||||
t.writableChan <- 0
|
||||
return s, nil
|
||||
|
@ -896,13 +898,13 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
|||
Client: true,
|
||||
WireLength: int(frame.Header().Length),
|
||||
}
|
||||
stats.Handle(s.ctx, inHeader)
|
||||
stats.Handle(s.userCtx, inHeader)
|
||||
} else {
|
||||
inTrailer := &stats.InTrailer{
|
||||
Client: true,
|
||||
WireLength: int(frame.Header().Length),
|
||||
}
|
||||
stats.Handle(s.ctx, inTrailer)
|
||||
stats.Handle(s.userCtx, inTrailer)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -168,6 +168,9 @@ type Stream struct {
|
|||
id uint32
|
||||
// nil for client side Stream.
|
||||
st ServerTransport
|
||||
// Keep the user context for stats handling.
|
||||
// All stats handling should use the user context instead of the stream context.
|
||||
userCtx context.Context
|
||||
// ctx is the associated context of the stream.
|
||||
ctx context.Context
|
||||
// cancel is always nil for client side Stream.
|
||||
|
|
Loading…
Reference in New Issue