server: fix ChainUnaryInterceptor and ChainStreamInterceptor to allow retrying handlers (#5666)

This commit is contained in:
Yimin Chen 2022-11-22 12:58:04 -08:00 committed by GitHub
parent e0a9f1112a
commit adfb9155e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 30 deletions

View File

@ -1150,21 +1150,16 @@ func chainUnaryServerInterceptors(s *Server) {
func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
// the struct ensures the variables are allocated together, rather than separately, since we
// know they should be garbage collected together. This saves 1 allocation and decreases
// time/call by about 10% on the microbenchmark.
var state struct {
i int
next UnaryHandler
}
state.next = func(ctx context.Context, req interface{}) (interface{}, error) {
if state.i == len(interceptors)-1 {
return interceptors[state.i](ctx, req, info, handler)
}
state.i++
return interceptors[state.i-1](ctx, req, info, state.next)
}
return state.next(ctx, req)
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
}
}
func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
if curr == len(interceptors)-1 {
return finalHandler
}
return func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
}
}
@ -1470,21 +1465,16 @@ func chainStreamServerInterceptors(s *Server) {
func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
// the struct ensures the variables are allocated together, rather than separately, since we
// know they should be garbage collected together. This saves 1 allocation and decreases
// time/call by about 10% on the microbenchmark.
var state struct {
i int
next StreamHandler
}
state.next = func(srv interface{}, ss ServerStream) error {
if state.i == len(interceptors)-1 {
return interceptors[state.i](srv, ss, info, handler)
}
state.i++
return interceptors[state.i-1](srv, ss, info, state.next)
}
return state.next(srv, ss)
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
}
}
func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
if curr == len(interceptors)-1 {
return finalHandler
}
return func(srv interface{}, stream ServerStream) error {
return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
}
}

View File

@ -27,6 +27,7 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
)
@ -130,6 +131,34 @@ func (s) TestGetServiceInfo(t *testing.T) {
}
}
func (s) TestRetryChainedInterceptor(t *testing.T) {
var records []int
i1 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 1)
// call handler twice to simulate a retry here.
handler(ctx, req)
return handler(ctx, req)
}
i2 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 2)
return handler(ctx, req)
}
i3 := func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (resp interface{}, err error) {
records = append(records, 3)
return handler(ctx, req)
}
ii := chainUnaryInterceptors([]UnaryServerInterceptor{i1, i2, i3})
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
}
ii(context.Background(), nil, nil, handler)
if !cmp.Equal(records, []int{1, 2, 3, 2, 3}) {
t.Fatalf("retry failed on chained interceptors: %v", records)
}
}
func (s) TestStreamContext(t *testing.T) {
expectedStream := &transport.Stream{}
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)