grpc-go/reflection/serverreflection.go

304 lines
8.0 KiB
Go

package reflection
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"reflect"
"github.com/golang/protobuf/proto"
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
type serverReflectionServer struct {
s *grpc.Server
// TODO add cache if necessary
}
// InstallOnServer installs server reflection service on the given grpc server.
func InstallOnServer(s *grpc.Server) {
rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
s: s,
})
}
type protoMessage interface {
Descriptor() ([]byte, []int)
}
func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, []int, error) {
// Indexes list is not stored in cache.
// So this step is needed to get idxs.
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
if !ok {
return nil, nil, fmt.Errorf("failed to create message from type: %v", st)
}
enc, idxs := m.Descriptor()
// Cache missed, try to decode.
fd, err := s.decodeFileDesc(enc)
if err != nil {
return nil, nil, err
}
return fd, idxs, nil
}
func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
raw := decompress(enc)
if raw == nil {
return nil, fmt.Errorf("failed to decompress enc")
}
fd := new(dpb.FileDescriptorProto)
if err := proto.Unmarshal(raw, fd); err != nil {
return nil, fmt.Errorf("bad descriptor: %v", err)
}
return fd, nil
}
func decompress(b []byte) []byte {
r, err := gzip.NewReader(bytes.NewReader(b))
if err != nil {
fmt.Printf("bad gzipped descriptor: %v\n", err)
return nil
}
out, err := ioutil.ReadAll(r)
if err != nil {
fmt.Printf("bad gzipped descriptor: %v\n", err)
return nil
}
return out
}
func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) {
pt := proto.MessageType(name)
if pt == nil {
return nil, fmt.Errorf("unknown type: %q", name)
}
st := pt.Elem()
return st, nil
}
func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
var extDesc *proto.ExtensionDesc
for id, desc := range proto.RegisteredExtensions(m) {
if id == ext {
extDesc = desc
break
}
}
if extDesc == nil {
return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
}
extT := reflect.TypeOf(extDesc.ExtensionType).Elem()
fd, _, err := s.fileDescForType(extT)
if err != nil {
return nil, err
}
return fd, nil
}
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
var out []int32
for id := range proto.RegisteredExtensions(m) {
out = append(out, id)
}
return out, nil
}
func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]byte, error) {
enc := proto.FileDescriptor(name)
if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name)
}
fd, err := s.decodeFileDesc(enc)
if err != nil {
return nil, err
}
b, err := proto.Marshal(fd)
if err != nil {
return nil, err
}
return b, nil
}
func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string) ([]byte, error) {
var (
fd *dpb.FileDescriptorProto
)
// Check if it's a type name.
if st, err := s.typeForName(name); err == nil {
fd, _, err = s.fileDescForType(st)
if err != nil {
return nil, err
}
} else {
// Check if it's a service name or method name.
meta := s.s.Metadata(name)
if meta != nil {
if enc, ok := meta.([]byte); ok {
fd, err = s.decodeFileDesc(enc)
if err != nil {
return nil, err
}
}
}
}
// Marshal to wire format.
if fd != nil {
b, err := proto.Marshal(fd)
if err != nil {
return nil, err
}
return b, nil
}
return nil, fmt.Errorf("unknown symbol: %v", name)
}
func (s *serverReflectionServer) fileDescWireFormatContainingExtension(typeName string, extNum int32) ([]byte, error) {
st, err := s.typeForName(typeName)
if err != nil {
return nil, err
}
fd, err := s.fileDescContainingExtension(st, extNum)
if err != nil {
return nil, err
}
b, err := proto.Marshal(fd)
if err != nil {
return nil, err
}
return b, nil
}
func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
st, err := s.typeForName(name)
if err != nil {
return nil, err
}
extNums, err := s.allExtensionNumbersForType(st)
if err != nil {
return nil, err
}
return extNums, nil
}
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
for {
in, err := stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
out := &rpb.ServerReflectionResponse{
ValidHost: in.Host,
OriginalRequest: in,
}
switch req := in.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename:
b, err := s.fileDescWireFormatByFilename(req.FileByFilename)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
}
}
case *rpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescWireFormatContainingSymbol(req.FileContainingSymbol)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
}
}
case *rpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescWireFormatContainingExtension(typeName, extNum)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
}
}
case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
BaseTypeName: req.AllExtensionNumbersOfType,
ExtensionNumber: extNums,
},
}
}
case *rpb.ServerReflectionRequest_ListServices:
services := s.s.AllServiceNames()
serviceResponses := make([]*rpb.ServiceResponse, len(services))
for i, s := range services {
serviceResponses[i] = &rpb.ServiceResponse{
Name: s,
}
}
out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &rpb.ListServiceResponse{
Service: serviceResponses,
},
}
default:
return grpc.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
}
if err := stream.Send(out); err != nil {
return err
}
}
}