mirror of https://github.com/grpc/grpc-go.git
transport: allow InTapHandle to return status errors (#4365)
This commit is contained in:
parent
aff517ba8a
commit
328b1d171a
|
|
@ -20,13 +20,17 @@ package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
|
"google.golang.org/grpc/internal/grpcutil"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
|
var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
|
||||||
|
|
@ -128,6 +132,14 @@ type cleanupStream struct {
|
||||||
|
|
||||||
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
|
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
|
||||||
|
|
||||||
|
type earlyAbortStream struct {
|
||||||
|
streamID uint32
|
||||||
|
contentSubtype string
|
||||||
|
status *status.Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*earlyAbortStream) isTransportResponseFrame() bool { return false }
|
||||||
|
|
||||||
type dataFrame struct {
|
type dataFrame struct {
|
||||||
streamID uint32
|
streamID uint32
|
||||||
endStream bool
|
endStream bool
|
||||||
|
|
@ -749,6 +761,24 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
|
||||||
|
if l.side == clientSide {
|
||||||
|
return errors.New("earlyAbortStream not handled on client")
|
||||||
|
}
|
||||||
|
|
||||||
|
headerFields := []hpack.HeaderField{
|
||||||
|
{Name: ":status", Value: "200"},
|
||||||
|
{Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)},
|
||||||
|
{Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))},
|
||||||
|
{Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
|
func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
|
||||||
if l.side == clientSide {
|
if l.side == clientSide {
|
||||||
l.draining = true
|
l.draining = true
|
||||||
|
|
@ -787,6 +817,8 @@ func (l *loopyWriter) handle(i interface{}) error {
|
||||||
return l.registerStreamHandler(i)
|
return l.registerStreamHandler(i)
|
||||||
case *cleanupStream:
|
case *cleanupStream:
|
||||||
return l.cleanupStreamHandler(i)
|
return l.cleanupStreamHandler(i)
|
||||||
|
case *earlyAbortStream:
|
||||||
|
return l.earlyAbortStreamHandler(i)
|
||||||
case *incomingGoAway:
|
case *incomingGoAway:
|
||||||
return l.incomingGoAwayHandler(i)
|
return l.incomingGoAwayHandler(i)
|
||||||
case *dataFrame:
|
case *dataFrame:
|
||||||
|
|
|
||||||
|
|
@ -356,26 +356,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||||
if state.data.statsTrace != nil {
|
if state.data.statsTrace != nil {
|
||||||
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
|
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
|
||||||
}
|
}
|
||||||
if t.inTapHandle != nil {
|
|
||||||
var err error
|
|
||||||
info := &tap.Info{
|
|
||||||
FullMethodName: state.data.method,
|
|
||||||
}
|
|
||||||
s.ctx, err = t.inTapHandle(s.ctx, info)
|
|
||||||
if err != nil {
|
|
||||||
if logger.V(logLevel) {
|
|
||||||
logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
|
|
||||||
}
|
|
||||||
t.controlBuf.put(&cleanupStream{
|
|
||||||
streamID: s.id,
|
|
||||||
rst: true,
|
|
||||||
rstCode: http2.ErrCodeRefusedStream,
|
|
||||||
onWrite: func() {},
|
|
||||||
})
|
|
||||||
s.cancel()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
if t.state != reachable {
|
if t.state != reachable {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
|
|
@ -417,6 +397,25 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||||
s.cancel()
|
s.cancel()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if t.inTapHandle != nil {
|
||||||
|
var err error
|
||||||
|
if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: state.data.method}); err != nil {
|
||||||
|
t.mu.Unlock()
|
||||||
|
if logger.V(logLevel) {
|
||||||
|
logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
|
||||||
|
}
|
||||||
|
stat, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
stat = status.New(codes.PermissionDenied, err.Error())
|
||||||
|
}
|
||||||
|
t.controlBuf.put(&earlyAbortStream{
|
||||||
|
streamID: s.id,
|
||||||
|
contentSubtype: s.contentSubtype,
|
||||||
|
status: stat,
|
||||||
|
})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
t.activeStreams[streamID] = s
|
t.activeStreams[streamID] = s
|
||||||
if len(t.activeStreams) == 1 {
|
if len(t.activeStreams) == 1 {
|
||||||
t.idle = time.Time{}
|
t.idle = time.Time{}
|
||||||
|
|
|
||||||
|
|
@ -418,6 +418,11 @@ func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOptio
|
||||||
|
|
||||||
// InTapHandle returns a ServerOption that sets the tap handle for all the server
|
// InTapHandle returns a ServerOption that sets the tap handle for all the server
|
||||||
// transport to be created. Only one can be installed.
|
// transport to be created. Only one can be installed.
|
||||||
|
//
|
||||||
|
// Experimental
|
||||||
|
//
|
||||||
|
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
||||||
|
// later release.
|
||||||
func InTapHandle(h tap.ServerInHandle) ServerOption {
|
func InTapHandle(h tap.ServerInHandle) ServerOption {
|
||||||
return newFuncServerOption(func(o *serverOptions) {
|
return newFuncServerOption(func(o *serverOptions) {
|
||||||
if o.inTapHandle != nil {
|
if o.inTapHandle != nil {
|
||||||
|
|
|
||||||
16
tap/tap.go
16
tap/tap.go
|
|
@ -37,16 +37,16 @@ type Info struct {
|
||||||
// TODO: More to be added.
|
// TODO: More to be added.
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerInHandle defines the function which runs before a new stream is created
|
// ServerInHandle defines the function which runs before a new stream is
|
||||||
// on the server side. If it returns a non-nil error, the stream will not be
|
// created on the server side. If it returns a non-nil error, the stream will
|
||||||
// created and a RST_STREAM will be sent back to the client with REFUSED_STREAM.
|
// not be created and an error will be returned to the client. If the error
|
||||||
// The client will receive an RPC error "code = Unavailable, desc = stream
|
// returned is a status error, that status code and message will be used,
|
||||||
// terminated by RST_STREAM with error code: REFUSED_STREAM".
|
// otherwise PermissionDenied will be the code and err.Error() will be the
|
||||||
|
// message.
|
||||||
//
|
//
|
||||||
// It's intended to be used in situations where you don't want to waste the
|
// It's intended to be used in situations where you don't want to waste the
|
||||||
// resources to accept the new stream (e.g. rate-limiting). And the content of
|
// resources to accept the new stream (e.g. rate-limiting). For other general
|
||||||
// the error will be ignored and won't be sent back to the client. For other
|
// usages, please use interceptors.
|
||||||
// general usages, please use interceptors.
|
|
||||||
//
|
//
|
||||||
// Note that it is executed in the per-connection I/O goroutine(s) instead of
|
// Note that it is executed in the per-connection I/O goroutine(s) instead of
|
||||||
// per-RPC goroutine. Therefore, users should NOT have any
|
// per-RPC goroutine. Therefore, users should NOT have any
|
||||||
|
|
|
||||||
|
|
@ -2507,10 +2507,13 @@ type myTap struct {
|
||||||
|
|
||||||
func (t *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, error) {
|
func (t *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, error) {
|
||||||
if info != nil {
|
if info != nil {
|
||||||
if info.FullMethodName == "/grpc.testing.TestService/EmptyCall" {
|
switch info.FullMethodName {
|
||||||
|
case "/grpc.testing.TestService/EmptyCall":
|
||||||
t.cnt++
|
t.cnt++
|
||||||
} else if info.FullMethodName == "/grpc.testing.TestService/UnaryCall" {
|
case "/grpc.testing.TestService/UnaryCall":
|
||||||
return nil, fmt.Errorf("tap error")
|
return nil, fmt.Errorf("tap error")
|
||||||
|
case "/grpc.testing.TestService/FullDuplexCall":
|
||||||
|
return nil, status.Errorf(codes.FailedPrecondition, "test custom error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
|
|
@ -2550,8 +2553,15 @@ func testTap(t *testing.T, e env) {
|
||||||
ResponseSize: 45,
|
ResponseSize: 45,
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.Unavailable {
|
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.PermissionDenied {
|
||||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.Unavailable)
|
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.PermissionDenied)
|
||||||
|
}
|
||||||
|
str, err := tc.FullDuplexCall(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error creating stream: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := str.Recv(); status.Code(err) != codes.FailedPrecondition {
|
||||||
|
t.Fatalf("FullDuplexCall Recv() = _, %v, want _, %s", err, codes.FailedPrecondition)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3639,66 +3649,77 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests that the client transparently retries correctly when receiving a
|
||||||
|
// RST_STREAM with code REFUSED_STREAM.
|
||||||
func (s) TestTransparentRetry(t *testing.T) {
|
func (s) TestTransparentRetry(t *testing.T) {
|
||||||
for _, e := range listTestEnv() {
|
|
||||||
if e.name == "handler-tls" {
|
|
||||||
// Fails with RST_STREAM / FLOW_CONTROL_ERROR
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
testTransparentRetry(t, e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This test makes sure RPCs are retried times when they receive a RST_STREAM
|
|
||||||
// with the REFUSED_STREAM error code, which the InTapHandle provokes.
|
|
||||||
func testTransparentRetry(t *testing.T, e env) {
|
|
||||||
te := newTest(t, e)
|
|
||||||
attempts := 0
|
|
||||||
successAttempt := 2
|
|
||||||
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
|
|
||||||
attempts++
|
|
||||||
if attempts < successAttempt {
|
|
||||||
return nil, errors.New("not now")
|
|
||||||
}
|
|
||||||
return ctx, nil
|
|
||||||
}
|
|
||||||
te.startServer(&testServer{security: e.security})
|
|
||||||
defer te.tearDown()
|
|
||||||
|
|
||||||
cc := te.clientConn()
|
|
||||||
tsc := testpb.NewTestServiceClient(cc)
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
successAttempt int
|
|
||||||
failFast bool
|
failFast bool
|
||||||
errCode codes.Code
|
errCode codes.Code
|
||||||
}{{
|
}{{
|
||||||
successAttempt: 1,
|
// success attempt: 1, (stream ID 1)
|
||||||
}, {
|
}, {
|
||||||
successAttempt: 2,
|
// success attempt: 2, (stream IDs 3, 5)
|
||||||
}, {
|
}, {
|
||||||
successAttempt: 3,
|
// no success attempt (stream IDs 7, 9)
|
||||||
errCode: codes.Unavailable,
|
errCode: codes.Unavailable,
|
||||||
}, {
|
}, {
|
||||||
successAttempt: 1,
|
// success attempt: 1 (stream ID 11),
|
||||||
failFast: true,
|
failFast: true,
|
||||||
}, {
|
}, {
|
||||||
successAttempt: 2,
|
// success attempt: 2 (stream IDs 13, 15),
|
||||||
failFast: true,
|
failFast: true,
|
||||||
}, {
|
}, {
|
||||||
successAttempt: 3,
|
// no success attempt (stream IDs 17, 19)
|
||||||
failFast: true,
|
failFast: true,
|
||||||
errCode: codes.Unavailable,
|
errCode: codes.Unavailable,
|
||||||
}}
|
}}
|
||||||
for _, tc := range testCases {
|
|
||||||
attempts = 0
|
|
||||||
successAttempt = tc.successAttempt
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
_, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!tc.failFast))
|
if err != nil {
|
||||||
cancel()
|
t.Fatalf("Failed to listen. Err: %v", err)
|
||||||
if status.Code(err) != tc.errCode {
|
|
||||||
t.Errorf("%+v: tsc.EmptyCall(_, _) = _, %v, want _, Code=%v", tc, err, tc.errCode)
|
|
||||||
}
|
}
|
||||||
|
defer lis.Close()
|
||||||
|
server := &httpServer{
|
||||||
|
headerFields: [][]string{{
|
||||||
|
":status", "200",
|
||||||
|
"content-type", "application/grpc",
|
||||||
|
"grpc-status", "0",
|
||||||
|
}},
|
||||||
|
refuseStream: func(i uint32) bool {
|
||||||
|
switch i {
|
||||||
|
case 1, 5, 11, 15: // these stream IDs succeed
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true // these are refused
|
||||||
|
},
|
||||||
|
}
|
||||||
|
server.start(t, lis)
|
||||||
|
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to dial due to err: %v", err)
|
||||||
|
}
|
||||||
|
defer cc.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client := testpb.NewTestServiceClient(cc)
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
stream, err := client.FullDuplexCall(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating stream due to err: %v", err)
|
||||||
|
}
|
||||||
|
code := func(err error) codes.Code {
|
||||||
|
if err == io.EOF {
|
||||||
|
return codes.OK
|
||||||
|
}
|
||||||
|
return status.Code(err)
|
||||||
|
}
|
||||||
|
if _, err := stream.Recv(); code(err) != tc.errCode {
|
||||||
|
t.Fatalf("%v: stream.Recv() = _, %v, want error code: %v", i, err, tc.errCode)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -7191,6 +7212,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) {
|
||||||
|
|
||||||
type httpServer struct {
|
type httpServer struct {
|
||||||
headerFields [][]string
|
headerFields [][]string
|
||||||
|
refuseStream func(uint32) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error {
|
func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error {
|
||||||
|
|
@ -7238,17 +7260,25 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
|
||||||
writer.Flush() // necessary since client is expecting preface before declaring connection fully setup.
|
writer.Flush() // necessary since client is expecting preface before declaring connection fully setup.
|
||||||
|
|
||||||
var sid uint32
|
var sid uint32
|
||||||
|
// Loop until conn is closed and framer returns io.EOF
|
||||||
|
for {
|
||||||
// Read frames until a header is received.
|
// Read frames until a header is received.
|
||||||
for {
|
for {
|
||||||
frame, err := framer.ReadFrame()
|
frame, err := framer.ReadFrame()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
t.Errorf("Error at server-side while reading frame. Err: %v", err)
|
t.Errorf("Error at server-side while reading frame. Err: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if hframe, ok := frame.(*http2.HeadersFrame); ok {
|
if hframe, ok := frame.(*http2.HeadersFrame); ok {
|
||||||
sid = hframe.Header().StreamID
|
sid = hframe.Header().StreamID
|
||||||
|
if s.refuseStream == nil || !s.refuseStream(sid) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
framer.WriteRSTStream(sid, http2.ErrCodeRefusedStream)
|
||||||
|
writer.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for i, headers := range s.headerFields {
|
for i, headers := range s.headerFields {
|
||||||
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
|
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
|
||||||
|
|
@ -7257,6 +7287,7 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
|
||||||
}
|
}
|
||||||
writer.Flush()
|
writer.Flush()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue