mirror of https://github.com/grpc/grpc-go.git
server: fix ChainUnaryInterceptor and ChainStreamInterceptor to allow retrying handlers (#5666)
This commit is contained in:
parent
e0a9f1112a
commit
adfb9155e4
50
server.go
50
server.go
|
@ -1150,21 +1150,16 @@ func chainUnaryServerInterceptors(s *Server) {
|
|||
|
||||
func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
|
||||
// the struct ensures the variables are allocated together, rather than separately, since we
|
||||
// know they should be garbage collected together. This saves 1 allocation and decreases
|
||||
// time/call by about 10% on the microbenchmark.
|
||||
var state struct {
|
||||
i int
|
||||
next UnaryHandler
|
||||
}
|
||||
state.next = func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
if state.i == len(interceptors)-1 {
|
||||
return interceptors[state.i](ctx, req, info, handler)
|
||||
}
|
||||
state.i++
|
||||
return interceptors[state.i-1](ctx, req, info, state.next)
|
||||
}
|
||||
return state.next(ctx, req)
|
||||
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
|
||||
}
|
||||
}
|
||||
|
||||
func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
|
||||
if curr == len(interceptors)-1 {
|
||||
return finalHandler
|
||||
}
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1470,21 +1465,16 @@ func chainStreamServerInterceptors(s *Server) {
|
|||
|
||||
func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
|
||||
return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
|
||||
// the struct ensures the variables are allocated together, rather than separately, since we
|
||||
// know they should be garbage collected together. This saves 1 allocation and decreases
|
||||
// time/call by about 10% on the microbenchmark.
|
||||
var state struct {
|
||||
i int
|
||||
next StreamHandler
|
||||
}
|
||||
state.next = func(srv interface{}, ss ServerStream) error {
|
||||
if state.i == len(interceptors)-1 {
|
||||
return interceptors[state.i](srv, ss, info, handler)
|
||||
}
|
||||
state.i++
|
||||
return interceptors[state.i-1](srv, ss, info, state.next)
|
||||
}
|
||||
return state.next(srv, ss)
|
||||
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
|
||||
}
|
||||
}
|
||||
|
||||
func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
|
||||
if curr == len(interceptors)-1 {
|
||||
return finalHandler
|
||||
}
|
||||
return func(srv interface{}, stream ServerStream) error {
|
||||
return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc/internal/transport"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
@ -130,6 +131,34 @@ func (s) TestGetServiceInfo(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s) TestRetryChainedInterceptor(t *testing.T) {
|
||||
var records []int
|
||||
i1 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
|
||||
records = append(records, 1)
|
||||
// call handler twice to simulate a retry here.
|
||||
handler(ctx, req)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
i2 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
|
||||
records = append(records, 2)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
i3 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
|
||||
records = append(records, 3)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3})
|
||||
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
ii(context.Background(), nil, nil, handler)
|
||||
if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
|
||||
t.Fatalf("retry failed on chained interceptors: %v", records)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestStreamContext(t *testing.T) {
|
||||
expectedStream := &transport.Stream{}
|
||||
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
|
||||
|
|
Loading…
Reference in New Issue