mirror of https://github.com/grpc/grpc-go.git
grpc: Add a pointer of server to ctx passed into stats handler (#6750)
This commit is contained in:
parent
8190d883e0
commit
8cb98464e5
|
|
@ -73,6 +73,11 @@ var (
|
|||
// xDS-enabled server invokes this method on a grpc.Server when a particular
|
||||
// listener moves to "not-serving" mode.
|
||||
DrainServerTransports any // func(*grpc.Server, string)
|
||||
// IsRegisteredMethod returns whether the passed in method is registered as
|
||||
// a method on the server.
|
||||
IsRegisteredMethod any // func(*grpc.Server, string) bool
|
||||
// ServerFromContext returns the server from the context.
|
||||
ServerFromContext any // func(context.Context) *grpc.Server
|
||||
// AddGlobalServerOptions adds an array of ServerOption that will be
|
||||
// effective globally for newly created servers. The priority will be: 1.
|
||||
// user-provided; 2. this method; 3. default values.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2023 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/stats"
|
||||
)
|
||||
|
||||
// StubStatsHandler is a stats handler that is easy to customize within
|
||||
// individual test cases. It is a stubbable implementation of
|
||||
// google.golang.org/grpc/stats.Handler for testing purposes.
|
||||
type StubStatsHandler struct {
|
||||
TagRPCF func(ctx context.Context, info *stats.RPCTagInfo) context.Context
|
||||
HandleRPCF func(ctx context.Context, info stats.RPCStats)
|
||||
TagConnF func(ctx context.Context, info *stats.ConnTagInfo) context.Context
|
||||
HandleConnF func(ctx context.Context, info stats.ConnStats)
|
||||
}
|
||||
|
||||
// TagRPC calls the StubStatsHandler's TagRPCF, if set.
|
||||
func (ssh *StubStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
||||
if ssh.TagRPCF != nil {
|
||||
return ssh.TagRPCF(ctx, info)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// HandleRPC calls the StubStatsHandler's HandleRPCF, if set.
|
||||
func (ssh *StubStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
|
||||
if ssh.HandleRPCF != nil {
|
||||
ssh.HandleRPCF(ctx, rs)
|
||||
}
|
||||
}
|
||||
|
||||
// TagConn calls the StubStatsHandler's TagConnF, if set.
|
||||
func (ssh *StubStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
||||
if ssh.TagConnF != nil {
|
||||
return ssh.TagConnF(ctx, info)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// HandleConn calls the StubStatsHandler's HandleConnF, if set.
|
||||
func (ssh *StubStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
|
||||
if ssh.HandleConnF != nil {
|
||||
ssh.HandleConnF(ctx, cs)
|
||||
}
|
||||
}
|
||||
43
server.go
43
server.go
|
|
@ -70,6 +70,10 @@ func init() {
|
|||
internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials {
|
||||
return srv.opts.creds
|
||||
}
|
||||
internal.IsRegisteredMethod = func(srv *Server, method string) bool {
|
||||
return srv.isRegisteredMethod(method)
|
||||
}
|
||||
internal.ServerFromContext = serverFromContext
|
||||
internal.DrainServerTransports = func(srv *Server, addr string) {
|
||||
srv.drainServerTransports(addr)
|
||||
}
|
||||
|
|
@ -1707,6 +1711,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
|
|||
|
||||
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
|
||||
ctx := stream.Context()
|
||||
ctx = contextWithServer(ctx, s)
|
||||
var ti *traceInfo
|
||||
if EnableTracing {
|
||||
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
|
||||
|
|
@ -1953,6 +1958,44 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
|
|||
return codec
|
||||
}
|
||||
|
||||
type serverKey struct{}
|
||||
|
||||
// serverFromContext gets the Server from the context.
|
||||
func serverFromContext(ctx context.Context) *Server {
|
||||
s, _ := ctx.Value(serverKey{}).(*Server)
|
||||
return s
|
||||
}
|
||||
|
||||
// contextWithServer sets the Server in the context.
|
||||
func contextWithServer(ctx context.Context, server *Server) context.Context {
|
||||
return context.WithValue(ctx, serverKey{}, server)
|
||||
}
|
||||
|
||||
// isRegisteredMethod returns whether the passed in method is registered as a
|
||||
// method on the server. /service/method and service/method will match if the
|
||||
// service and method are registered on the server.
|
||||
func (s *Server) isRegisteredMethod(serviceMethod string) bool {
|
||||
if serviceMethod != "" && serviceMethod[0] == '/' {
|
||||
serviceMethod = serviceMethod[1:]
|
||||
}
|
||||
pos := strings.LastIndex(serviceMethod, "/")
|
||||
if pos == -1 { // Invalid method name syntax.
|
||||
return false
|
||||
}
|
||||
service := serviceMethod[:pos]
|
||||
method := serviceMethod[pos+1:]
|
||||
srv, knownService := s.services[service]
|
||||
if knownService {
|
||||
if _, ok := srv.methods[method]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := srv.streams[method]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetHeader sets the header metadata to be sent from the server to the client.
|
||||
// The context provided must be the context passed to the server's handler.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -31,7 +31,10 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/internal"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/internal/stubserver"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/stats"
|
||||
"google.golang.org/grpc/status"
|
||||
|
|
@ -1457,3 +1460,61 @@ func (s) TestMultipleServerStatsHandler(t *testing.T) {
|
|||
t.Fatalf("h.gotConn: unexpected amount of ConnStats: %v != %v", len(h.gotConn), 4)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
|
||||
// gets access to a Server on the server side, and thus the method that the
|
||||
// server owns which specifies whether a method is made or not. The test sets up
|
||||
// a server with a unary call and full duplex call configured, and makes an RPC.
|
||||
// Within the stats handler, asking the server whether unary or duplex method
|
||||
// names are registered should return true, and any other query should return
|
||||
// false.
|
||||
func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
stubStatsHandler := &testutils.StubStatsHandler{
|
||||
TagRPCF: func(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
|
||||
// OpenTelemetry instrumentation needs the passed in Server to determine if
|
||||
// methods are registered in different handle calls in to record metrics.
|
||||
// This tag RPC call context gets passed into every handle call, so can
|
||||
// assert once here, since it maps to all the handle RPC calls that come
|
||||
// after. These internal calls will be how the OpenTelemetry instrumentation
|
||||
// component accesses this server and the subsequent helper on the server.
|
||||
server := internal.ServerFromContext.(func(context.Context) *grpc.Server)(ctx)
|
||||
if server == nil {
|
||||
t.Errorf("stats handler received ctx has no server present")
|
||||
}
|
||||
isRegisteredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)
|
||||
// /s/m and s/m are valid.
|
||||
if !isRegisteredMethod(server, "/grpc.testing.TestService/UnaryCall") {
|
||||
t.Errorf("UnaryCall should be a registered method according to server")
|
||||
}
|
||||
if !isRegisteredMethod(server, "grpc.testing.TestService/FullDuplexCall") {
|
||||
t.Errorf("FullDuplexCall should be a registered method according to server")
|
||||
}
|
||||
if isRegisteredMethod(server, "/grpc.testing.TestService/DoesNotExistCall") {
|
||||
t.Errorf("DoesNotExistCall should not be a registered method according to server")
|
||||
}
|
||||
if isRegisteredMethod(server, "/unknownService/UnaryCall") {
|
||||
t.Errorf("/unknownService/UnaryCall should not be a registered method according to server")
|
||||
}
|
||||
wg.Done()
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
ss := &stubserver.StubServer{
|
||||
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
return &testpb.SimpleResponse{}, nil
|
||||
},
|
||||
}
|
||||
if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(stubStatsHandler)}); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
|
||||
t.Fatalf("Unexpected error from UnaryCall: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue