Add allExtensionNumbersForTypeName and handle all_extension_numbers_response

This commit is contained in:
Menghan Li 2016-06-07 13:53:51 -07:00
parent 54fd6c1ea3
commit f28f4aa4bd
1 changed files with 37 additions and 9 deletions

View File

@ -203,6 +203,19 @@ func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ex
return fd, nil
}
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
}
var out []int32
for id := range proto.RegisteredExtensions(m) {
out = append(out, id)
}
return out, nil
}
// TODO filenameContainingExtension
// fd := fileDescContainingExtension()
// return fd.GetName()
@ -271,17 +284,16 @@ func (s *serverReflectionServer) fileDescWireFormatContainingExtension(typeName
return b, nil
}
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to create message from type: %v", st)
func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
st, err := s.typeForName(name)
if err != nil {
return nil, err
}
var out []int32
for id := range proto.RegisteredExtensions(m) {
out = append(out, id)
extNums, err := s.allExtensionNumbersForType(st)
if err != nil {
return nil, err
}
return out, nil
return extNums, nil
}
func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
@ -344,6 +356,22 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
}
case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: err.Error(),
},
}
} else {
out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
BaseTypeName: req.AllExtensionNumbersOfType,
ExtensionNumber: extNums,
},
}
}
case *rpb.ServerReflectionRequest_ListServices:
default:
return grpc.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)