Add server.GetServiceInfo().

To replace server.ServiceMetadata() and server.AllServiceNames().
This commit is contained in:
Menghan Li 2016-06-23 16:37:55 -07:00
parent 439f11e63d
commit 26d2db5487
4 changed files with 254 additions and 154 deletions

View File

@ -69,7 +69,8 @@ import (
type serverReflectionServer struct {
s *grpc.Server
// TODO add cache if necessary
// TODO add more cache if necessary
serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo()
}
// Register registers the server reflection service on the given gRPC server.
@ -188,6 +189,46 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte
return proto.Marshal(fd)
}
// serviceMetadataForSymbol finds the metadata for name in s.serviceInfo.
// name should be a service name or a method name.
func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interface{}, error) {
if s.serviceInfo == nil {
s.serviceInfo = s.s.GetServiceInfo()
}
// Check if it's a service name.
if info, ok := s.serviceInfo[name]; ok {
return info.Metadata, nil
}
// Check if it's a method name.
pos := strings.LastIndex(name, ".")
// Not a valid method name.
if pos == -1 {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
info, ok := s.serviceInfo[name[:pos]]
// Substring before last "." is not a service name.
if !ok {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
// Search for method in info.
var found bool
for _, m := range info.Methods {
if m == name[pos+1:] {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
return info.Metadata, nil
}
// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
// does marshalling on it and returns the marshalled result.
// The given symbol can be a type, a service or a method.
@ -201,28 +242,26 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) (
if err != nil {
return nil, err
}
} else {
// Check if it's a service name.
meta := s.s.ServiceMetadata(name, "")
// Check if it's a method name.
if meta == nil {
if pos := strings.LastIndex(name, "."); pos != -1 {
meta = s.s.ServiceMetadata(name[:pos], name[pos+1:])
}
} else { // Check if it's a service name or a method name.
meta, err := s.serviceMetadataForSymbol(name)
// Metadata not found.
if err != nil {
return nil, err
}
if meta != nil {
if enc, ok := meta.([]byte); ok {
fd, err = s.decodeFileDesc(enc)
if err != nil {
return nil, err
}
}
// Metadata not valid.
enc, ok := meta.([]byte)
if !ok {
return nil, fmt.Errorf("invalid file descriptor for symbol: %v")
}
fd, err = s.decodeFileDesc(enc)
if err != nil {
return nil, err
}
}
if fd == nil {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
return proto.Marshal(fd)
}
@ -331,12 +370,14 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
}
case *rpb.ServerReflectionRequest_ListServices:
services := s.s.AllServiceNames()
serviceResponses := make([]*rpb.ServiceResponse, len(services))
for i, s := range services {
serviceResponses[i] = &rpb.ServiceResponse{
Name: s,
}
if s.serviceInfo == nil {
s.serviceInfo = s.s.GetServiceInfo()
}
serviceResponses := make([]*rpb.ServiceResponse, 0, len(s.serviceInfo))
for n := range s.serviceInfo {
serviceResponses = append(serviceResponses, &rpb.ServiceResponse{
Name: n,
})
}
out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &rpb.ListServiceResponse{

View File

@ -92,7 +92,7 @@ func TestFileDescForType(t *testing.T) {
} {
fd, err := s.fileDescForType(test.st)
if err != nil || !reflect.DeepEqual(fd, test.wantFd) {
t.Fatalf("fileDescForType(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.wantFd)
t.Errorf("fileDescForType(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.wantFd)
}
}
}
@ -106,7 +106,7 @@ func TestTypeForName(t *testing.T) {
} {
r, err := s.typeForName(test.name)
if err != nil || r != test.want {
t.Fatalf("typeForName(%q) = %q, %v, want %q, <nil>", test.name, r, err, test.want)
t.Errorf("typeForName(%q) = %q, %v, want %q, <nil>", test.name, r, err, test.want)
}
}
}
@ -117,7 +117,7 @@ func TestTypeForNameNotFound(t *testing.T) {
} {
_, err := s.typeForName(test)
if err == nil {
t.Fatalf("typeForName(%q) = _, %v, want _, <non-nil>", test, err)
t.Errorf("typeForName(%q) = _, %v, want _, <non-nil>", test, err)
}
}
}
@ -132,7 +132,7 @@ func TestFileDescContainingExtension(t *testing.T) {
} {
fd, err := s.fileDescContainingExtension(test.st, test.extNum)
if err != nil || !reflect.DeepEqual(fd, test.want) {
t.Fatalf("fileDescContainingExtension(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.want)
t.Errorf("fileDescContainingExtension(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.want)
}
}
}
@ -154,7 +154,7 @@ func TestAllExtensionNumbersForType(t *testing.T) {
r, err := s.allExtensionNumbersForType(test.st)
sort.Sort(intArray(r))
if err != nil || !reflect.DeepEqual(r, test.want) {
t.Fatalf("allExtensionNumbersForType(%q) = %v, %v, want %v, <nil>", test.st, r, err, test.want)
t.Errorf("allExtensionNumbersForType(%q) = %v, %v, want %v, <nil>", test.st, r, err, test.want)
}
}
}
@ -194,9 +194,13 @@ func TestReflectionEnd2end(t *testing.T) {
stream, err := c.ServerReflectionInfo(context.Background())
testFileByFilename(t, stream)
testFileByFilenameError(t, stream)
testFileContainingSymbol(t, stream)
testFileContainingSymbolError(t, stream)
testFileContainingExtension(t, stream)
testFileContainingExtensionError(t, stream)
testAllExtensionNumbersOfType(t, stream)
testAllExtensionNumbersOfTypeError(t, stream)
testListServices(t, stream)
s.Stop()
@ -227,10 +231,37 @@ func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflecti
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) {
t.Fatalf("FileByFilename\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", test.filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
}
default:
t.Fatalf("FileByFilename = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", r.MessageResponse)
t.Errorf("FileByFilename(%v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.filename, r.MessageResponse)
}
}
}
func testFileByFilenameError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []string{
"test.poto",
"proo2.proto",
"proto2_et.proto",
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
FileByFilename: test,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("FileByFilename(%v) = %v, want type <ServerReflectionResponse_ErrorResponse>", test, r.MessageResponse)
}
}
}
@ -261,10 +292,38 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) {
t.Fatalf("FileContainingSymbol\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
t.Errorf("FileContainingSymbol(%v)\nreceived: %q,\nwant: %q", test.symbol, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
}
default:
t.Fatalf("FileContainingSymbol = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", r.MessageResponse)
t.Errorf("FileContainingSymbol(%v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.symbol, r.MessageResponse)
}
}
}
func testFileContainingSymbolError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []string{
"grpc.testing.SerchService",
"grpc.testing.SearchService.SearchE",
"grpc.tesing.SearchResponse",
"gpc.testing.ToBeExtened",
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: test,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("FileContainingSymbol(%v) = %v, want type <ServerReflectionResponse_ErrorResponse>", test, r.MessageResponse)
}
}
}
@ -296,10 +355,42 @@ func testFileContainingExtension(t *testing.T, stream rpb.ServerReflection_Serve
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) {
t.Fatalf("FileContainingExtension\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
t.Errorf("FileContainingExtension(%v, %v)\nreceived: %q,\nwant: %q", test.typeName, test.extNum, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
}
default:
t.Fatalf("FileContainingExtension = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", r.MessageResponse)
t.Errorf("FileContainingExtension(%v, %v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.typeName, test.extNum, r.MessageResponse)
}
}
}
func testFileContainingExtensionError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []struct {
typeName string
extNum int32
}{
{"grpc.testing.ToBExtened", 17},
{"grpc.testing.ToBeExtened", 15},
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{
FileContainingExtension: &rpb.ExtensionRequest{
ContainingType: test.typeName,
ExtensionNumber: test.extNum,
},
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("FileContainingExtension(%v, %v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.typeName, test.extNum, r.MessageResponse)
}
}
}
@ -330,10 +421,35 @@ func testAllExtensionNumbersOfType(t *testing.T, stream rpb.ServerReflection_Ser
sort.Sort(intArray(extNum))
if r.GetAllExtensionNumbersResponse().BaseTypeName != test.typeName ||
!reflect.DeepEqual(extNum, test.want) {
t.Fatalf("AllExtensionNumbersOfType\nreceived: %v,\nwant: {%q %v}", r.GetAllExtensionNumbersResponse(), test.typeName, test.want)
t.Errorf("AllExtensionNumbersOfType(%v)\nreceived: %v,\nwant: {%q %v}", r.GetAllExtensionNumbersResponse(), test.typeName, test.typeName, test.want)
}
default:
t.Fatalf("AllExtensionNumbersOfType = %v, want type <ServerReflectionResponse_AllExtensionNumbersResponse>", r.MessageResponse)
t.Errorf("AllExtensionNumbersOfType(%v) = %v, want type <ServerReflectionResponse_AllExtensionNumbersResponse>", test.typeName, r.MessageResponse)
}
}
}
func testAllExtensionNumbersOfTypeError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []string{
"grpc.testing.ToBeExtenedE",
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{
AllExtensionNumbersOfType: test,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("AllExtensionNumbersOfType(%v) = %v, want type <ServerReflectionResponse_ErrorResponse>", test, r.MessageResponse)
}
}
}
@ -356,7 +472,7 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection
want := []string{"grpc.testing.SearchService", "grpc.reflection.v1alpha.ServerReflection"}
// Compare service names in response with want.
if len(services) != len(want) {
t.Fatalf("= %v, want service names: %v", services, want)
t.Errorf("= %v, want service names: %v", services, want)
}
m := make(map[string]int)
for _, e := range services {
@ -367,9 +483,9 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection
m[e]--
continue
}
t.Fatalf("ListService\nreceived: %v,\nwant: %q", services, want)
t.Errorf("ListService\nreceived: %v,\nwant: %q", services, want)
}
default:
t.Fatalf("ListServices = %v, want type <ServerReflectionResponse_ListServicesResponse>", r.MessageResponse)
t.Errorf("ListServices = %v, want type <ServerReflectionResponse_ListServicesResponse>", r.MessageResponse)
}
}

View File

@ -245,32 +245,29 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) {
s.m[sd.ServiceName] = srv
}
// ServiceMetadata returns the metadata for a service or method.
// service should be the full service name with package, in the form of <package>.<service>.
// method should be the method name only.
// If only service is important, method should be an empty string.
func (s *Server) ServiceMetadata(service, method string) interface{} {
// Check if service is registered.
if srv, ok := s.m[service]; ok {
if method == "" {
return srv.meta
}
// Check if method is part of service.
if _, ok := srv.md[method]; ok {
return srv.meta
}
if _, ok := srv.sd[method]; ok {
return srv.meta
}
}
return nil
// ServiceInfo contains method names and metadata for a service.
type ServiceInfo struct {
Methods []string
Metadata interface{}
}
// AllServiceNames returns all the registered service names.
func (s *Server) AllServiceNames() []string {
ret := make([]string, 0, len(s.m))
for k := range s.m {
ret = append(ret, k)
// GetServiceInfo returns a map from service name to ServiceInfo.
// Service name includes the package name, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
ret := make(map[string]*ServiceInfo)
for n, srv := range s.m {
methods := make([]string, 0, len(srv.md)+len(srv.sd))
for m := range srv.md {
methods = append(methods, m)
}
for m := range srv.sd {
methods = append(methods, m)
}
ret[n] = &ServiceInfo{
Methods: methods,
Metadata: srv.meta,
}
}
return ret
}

View File

@ -44,29 +44,6 @@ type emptyServiceServer interface{}
type testServer struct{}
var (
testSd = ServiceDesc{
ServiceName: "grpc.testing.EmptyService",
HandlerType: (*emptyServiceServer)(nil),
Methods: []MethodDesc{
{
MethodName: "EmptyCall",
Handler: nil,
},
},
Streams: []StreamDesc{
{
StreamName: "EmptyStream",
Handler: nil,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: testFd,
}
testFd = []byte{0, 1, 2, 3}
)
func TestStopBeforeServe(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
@ -88,73 +65,42 @@ func TestStopBeforeServe(t *testing.T) {
}
}
func TestServiceMetadata(t *testing.T) {
server := NewServer()
server.RegisterService(&testSd, &testServer{})
for _, test := range []struct {
service string
method string
want []byte
}{
{"grpc.testing.EmptyService", "", testFd},
{"grpc.testing.EmptyService", "EmptyCall", testFd},
{"grpc.testing.EmptyService", "EmptyStream", testFd},
} {
meta := server.ServiceMetadata(test.service, test.method)
var (
fd []byte
ok bool
)
if fd, ok = meta.([]byte); !ok {
t.Errorf("ServiceMetadata(%q, %q) = %v, want %v", test.service, test.method, meta, test.want)
}
if !reflect.DeepEqual(fd, test.want) {
t.Errorf("ServiceMetadata(%q, %q) = %v, want %v", test.service, test.method, fd, test.want)
}
}
}
func TestServiceMetadataNotFound(t *testing.T) {
server := NewServer()
server.RegisterService(&testSd, &testServer{})
for _, test := range []struct {
service string
method string
}{
{"", "EmptyCall"},
{"grpc.EmptyService", ""},
{"grpc.EmptyService", "EmptyCall"},
{"grpc.testing.EmptyService", "EmptyCallWrong"},
{"grpc.testing.EmptyService", "EmptyStreamWrong"},
} {
meta := server.ServiceMetadata(test.service, test.method)
if meta != nil {
t.Errorf("ServiceMetadata(%q, %q) = %v, want <nil>", test.service, test.method, meta)
}
}
}
func TestAllServiceNames(t *testing.T) {
server := NewServer()
server.RegisterService(&testSd, &testServer{})
server.RegisterService(&ServiceDesc{
ServiceName: "another.EmptyService",
func TestGetServiceInfo(t *testing.T) {
testSd := ServiceDesc{
ServiceName: "grpc.testing.EmptyService",
HandlerType: (*emptyServiceServer)(nil),
}, &testServer{})
services := server.AllServiceNames()
want := []string{"grpc.testing.EmptyService", "another.EmptyService"}
// Compare string slices.
m := make(map[string]int)
for _, s := range services {
m[s]++
Methods: []MethodDesc{
{
MethodName: "EmptyCall",
Handler: nil,
},
},
Streams: []StreamDesc{
{
StreamName: "EmptyStream",
Handler: nil,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: []int{0, 2, 1, 3},
}
for _, s := range want {
if m[s] > 0 {
m[s]--
continue
}
t.Fatalf("AllServiceNames() = %q, want: %q", services, want)
server := NewServer()
server.RegisterService(&testSd, &testServer{})
info := server.GetServiceInfo()
want := map[string]*ServiceInfo{
"grpc.testing.EmptyService": &ServiceInfo{
Methods: []string{
"EmptyCall",
"EmptyStream",
},
Metadata: []int{0, 2, 1, 3},
},
}
if !reflect.DeepEqual(info, want) {
t.Errorf("GetServiceInfo() = %q, want %q", info, want)
}
}