diff --git a/reflection/README.md b/reflection/README.md index 519a0240c..04b6371af 100644 --- a/reflection/README.md +++ b/reflection/README.md @@ -4,15 +4,15 @@ Package reflection implements server reflection service. The service implemented is defined in: https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto. -To install server reflection on a gRPC server: +To register server reflection on a gRPC server: ```go import "google.golang.org/grpc/reflection" s := grpc.NewServer() pb.RegisterYourOwnServer(s, &server{}) -// Install reflection service on gRPC server. -reflection.InstallOnServer(s) +// Register reflection service on gRPC server. +reflection.Register(s) s.Serve(lis) ``` diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 02853a7e8..c2d613f80 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -37,14 +37,14 @@ Package reflection implements server reflection service. The service implemented is defined in: https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto. -To install server reflection on a gRPC server: +To register server reflection on a gRPC server: import "google.golang.org/grpc/reflection" s := grpc.NewServer() pb.RegisterYourOwnServer(s, &server{}) - // Install reflection service on gRPC server. - reflection.InstallOnServer(s) + // Register reflection service on gRPC server. + reflection.Register(s) s.Serve(lis) @@ -71,14 +71,17 @@ type serverReflectionServer struct { // TODO add cache if necessary } -// InstallOnServer installs server reflection service on the given gRPC server. -func InstallOnServer(s *grpc.Server) { +// Register registers the server reflection service on the given gRPC server. +func Register(s *grpc.Server) { rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ s: s, }) } -// protoMessage is the interface representing objects with function Descriptor(). +// protoMessage is used for type assertion on proto messages. +// Generated proto message implements function Descriptor(), but Descriptor() +// is not part of interface proto.Message. This interface is needed to +// call Descriptor(). type protoMessage interface { Descriptor() ([]byte, []int) } @@ -92,19 +95,15 @@ func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDesc } enc, _ := m.Descriptor() - fd, err := s.decodeFileDesc(enc) - if err != nil { - return nil, err - } - return fd, nil + return s.decodeFileDesc(enc) } // decodeFileDesc does decompression and unmarshalling on the given // file descriptor byte slice. func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { - raw := decompress(enc) - if raw == nil { - return nil, fmt.Errorf("failed to decompress enc") + raw, err := decompress(enc) + if err != nil { + return nil, fmt.Errorf("failed to decompress enc: %v", err) } fd := new(dpb.FileDescriptorProto) @@ -115,18 +114,16 @@ func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptor } // decompress does gzip decompression. -func decompress(b []byte) []byte { +func decompress(b []byte) ([]byte, error) { r, err := gzip.NewReader(bytes.NewReader(b)) if err != nil { - fmt.Printf("bad gzipped descriptor: %v\n", err) - return nil + return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err) } out, err := ioutil.ReadAll(r) if err != nil { - fmt.Printf("bad gzipped descriptor: %v\n", err) - return nil + return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err) } - return out + return out, nil } func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) { @@ -159,11 +156,7 @@ func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ex extT := reflect.TypeOf(extDesc.ExtensionType).Elem() - fd, err := s.fileDescForType(extT) - if err != nil { - return nil, err - } - return fd, nil + return s.fileDescForType(extT) } func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { @@ -173,20 +166,16 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([] } exts := proto.RegisteredExtensions(m) - out := make([]int32, len(exts)) - i := 0 + out := make([]int32, 0, len(exts)) for id := range exts { - out[i] = id - i++ + out = append(out, id) } return out, nil } -// Following are helper functions for reflection service handler. - -// fileDescWireFormatByFilename finds the file descriptor for given filename, +// fileDescEncodingByFilename finds the file descriptor for given filename, // does marshalling on it and returns the marshalled result. -func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]byte, error) { +func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) { enc := proto.FileDescriptor(name) if enc == nil { return nil, fmt.Errorf("unknown file: %v", name) @@ -202,10 +191,10 @@ func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]by return b, nil } -// fileDescWireFormatContainingSymbol finds the file descriptor containing the given symbol, +// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, // does marshalling on it and returns the marshalled result. // The given symbol can be a type, a service or a method. -func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string) ([]byte, error) { +func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { var ( fd *dpb.FileDescriptorProto ) @@ -239,9 +228,9 @@ func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string) return nil, fmt.Errorf("unknown symbol: %v", name) } -// fileDescWireFormatContainingExtension finds the file descriptor containing given extension, +// fileDescEncodingContainingExtension finds the file descriptor containing given extension, // does marshalling on it and returns the marshalled result. -func (s *serverReflectionServer) fileDescWireFormatContainingExtension(typeName string, extNum int32) ([]byte, error) { +func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { st, err := s.typeForName(typeName) if err != nil { return nil, err @@ -288,7 +277,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } switch req := in.MessageRequest.(type) { case *rpb.ServerReflectionRequest_FileByFilename: - b, err := s.fileDescWireFormatByFilename(req.FileByFilename) + b, err := s.fileDescEncodingByFilename(req.FileByFilename) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -302,7 +291,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } case *rpb.ServerReflectionRequest_FileContainingSymbol: - b, err := s.fileDescWireFormatContainingSymbol(req.FileContainingSymbol) + b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ @@ -318,7 +307,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio case *rpb.ServerReflectionRequest_FileContainingExtension: typeName := req.FileContainingExtension.ContainingType extNum := req.FileContainingExtension.ExtensionNumber - b, err := s.fileDescWireFormatContainingExtension(typeName, extNum) + b, err := s.fileDescEncodingContainingExtension(typeName, extNum) if err != nil { out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{ diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index b06c4d52a..7b66911f3 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -128,10 +128,6 @@ func TestAllExtensionNumbersForType(t *testing.T) { // Do end2end tests. -var ( - port = ":35764" -) - type server struct{} func (s *server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.SearchResponse, error) { @@ -144,18 +140,18 @@ func (s *server) StreamingSearch(stream pb.SearchService_StreamingSearchServer) func TestEnd2end(t *testing.T) { // Start server. - lis, err := net.Listen("tcp", port) + lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } s := grpc.NewServer() pb.RegisterSearchServiceServer(s, &server{}) - // Install reflection service on s. - InstallOnServer(s) + // Register reflection service on s. + Register(s) go s.Serve(lis) // Create client. - conn, err := grpc.Dial("localhost"+port, grpc.WithInsecure()) + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) if err != nil { t.Fatalf("cannot connect to server: %v", err) }