mirror of https://github.com/grpc/grpc-go.git
gcp/observability: update method name validation (#5951)
This commit is contained in:
parent
4075ef07c5
commit
52a8392f37
|
@ -24,19 +24,14 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
gcplogging "cloud.google.com/go/logging"
|
||||
"golang.org/x/oauth2/google"
|
||||
"google.golang.org/grpc/internal/envconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
envProjectID = "GOOGLE_CLOUD_PROJECT"
|
||||
methodStringRegexpStr = `^([\w./]+)/((?:\w+)|[*])$`
|
||||
)
|
||||
|
||||
var methodStringRegexp = regexp.MustCompile(methodStringRegexpStr)
|
||||
const envProjectID = "GOOGLE_CLOUD_PROJECT"
|
||||
|
||||
// fetchDefaultProjectID fetches the default GCP project id from environment.
|
||||
func fetchDefaultProjectID(ctx context.Context) string {
|
||||
|
@ -59,6 +54,25 @@ func fetchDefaultProjectID(ctx context.Context) string {
|
|||
return credentials.ProjectID
|
||||
}
|
||||
|
||||
// validateMethodString validates whether the string passed in is a valid
|
||||
// pattern.
|
||||
func validateMethodString(method string) error {
|
||||
if strings.HasPrefix(method, "/") {
|
||||
return errors.New("cannot have a leading slash")
|
||||
}
|
||||
serviceMethod := strings.Split(method, "/")
|
||||
if len(serviceMethod) != 2 {
|
||||
return errors.New("/ must come in between service and method, only one /")
|
||||
}
|
||||
if serviceMethod[1] == "" {
|
||||
return errors.New("method name must be non empty")
|
||||
}
|
||||
if serviceMethod[0] == "*" {
|
||||
return errors.New("cannot have service wildcard * i.e. (*/m)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateLogEventMethod(methods []string, exclude bool) error {
|
||||
for _, method := range methods {
|
||||
if method == "*" {
|
||||
|
@ -67,9 +81,8 @@ func validateLogEventMethod(methods []string, exclude bool) error {
|
|||
}
|
||||
continue
|
||||
}
|
||||
match := methodStringRegexp.FindStringSubmatch(method)
|
||||
if match == nil {
|
||||
return fmt.Errorf("invalid method string: %v", method)
|
||||
if err := validateMethodString(method); err != nil {
|
||||
return fmt.Errorf("invalid method string: %v, err: %v", method, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -322,6 +323,7 @@ func (bml *binaryMethodLogger) Log(c iblog.LogEntryConfig) {
|
|||
}
|
||||
|
||||
type eventConfig struct {
|
||||
// ServiceMethod has /s/m syntax for fast matching.
|
||||
ServiceMethod map[string]bool
|
||||
Services map[string]bool
|
||||
MatchAll bool
|
||||
|
@ -364,6 +366,17 @@ func (bl *binaryLogger) GetMethodLogger(methodName string) iblog.MethodLogger {
|
|||
return nil
|
||||
}
|
||||
|
||||
// parseMethod splits service and method from the input. It expects format
|
||||
// "service/method".
|
||||
func parseMethod(method string) (string, string, error) {
|
||||
pos := strings.Index(method, "/")
|
||||
if pos < 0 {
|
||||
// Shouldn't happen, config already validated.
|
||||
return "", "", errors.New("invalid method name: no / found")
|
||||
}
|
||||
return method[:pos], method[pos+1:], nil
|
||||
}
|
||||
|
||||
func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter loggingExporter) {
|
||||
if len(clientRPCEvents) == 0 {
|
||||
return
|
||||
|
@ -382,7 +395,7 @@ func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter logging
|
|||
eventConfig.MatchAll = true
|
||||
continue
|
||||
}
|
||||
s, m, err := grpcutil.ParseMethod(method)
|
||||
s, m, err := parseMethod(method)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
@ -390,7 +403,7 @@ func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter logging
|
|||
eventConfig.Services[s] = true
|
||||
continue
|
||||
}
|
||||
eventConfig.ServiceMethod[method] = true
|
||||
eventConfig.ServiceMethod["/"+method] = true
|
||||
}
|
||||
eventConfigs = append(eventConfigs, eventConfig)
|
||||
}
|
||||
|
@ -419,15 +432,15 @@ func registerServerRPCEvents(serverRPCEvents []serverRPCEvents, exporter logging
|
|||
eventConfig.MatchAll = true
|
||||
continue
|
||||
}
|
||||
s, m, err := grpcutil.ParseMethod(method)
|
||||
if err != nil { // Shouldn't happen, already validated at this point.
|
||||
s, m, err := parseMethod(method)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if m == "*" {
|
||||
eventConfig.Services[s] = true
|
||||
continue
|
||||
}
|
||||
eventConfig.ServiceMethod[method] = true
|
||||
eventConfig.ServiceMethod["/"+method] = true
|
||||
}
|
||||
eventConfigs = append(eventConfigs, eventConfig)
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
|
@ -99,13 +100,14 @@ func setupObservabilitySystemWithConfig(cfg *config) (func(), error) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
err = Start(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error in Start: %v", err)
|
||||
}
|
||||
return func() {
|
||||
cleanup := func() {
|
||||
End()
|
||||
envconfig.ObservabilityConfig = oldObservabilityConfig
|
||||
}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return cleanup, fmt.Errorf("error in Start: %v", err)
|
||||
}
|
||||
return cleanup, nil
|
||||
}
|
||||
|
||||
// TestClientRPCEventsLogAll tests the observability system configured with a
|
||||
|
@ -777,18 +779,18 @@ func (s) TestPrecedenceOrderingInConfiguration(t *testing.T) {
|
|||
CloudLogging: &cloudLogging{
|
||||
ClientRPCEvents: []clientRPCEvents{
|
||||
{
|
||||
Methods: []string{"/grpc.testing.TestService/UnaryCall"},
|
||||
Methods: []string{"grpc.testing.TestService/UnaryCall"},
|
||||
MaxMetadataBytes: 30,
|
||||
MaxMessageBytes: 30,
|
||||
},
|
||||
{
|
||||
Methods: []string{"/grpc.testing.TestService/EmptyCall"},
|
||||
Methods: []string{"grpc.testing.TestService/EmptyCall"},
|
||||
Exclude: true,
|
||||
MaxMetadataBytes: 30,
|
||||
MaxMessageBytes: 30,
|
||||
},
|
||||
{
|
||||
Methods: []string{"/grpc.testing.TestService/*"},
|
||||
Methods: []string{"grpc.testing.TestService/*"},
|
||||
MaxMetadataBytes: 30,
|
||||
MaxMessageBytes: 30,
|
||||
},
|
||||
|
@ -1273,3 +1275,111 @@ func (s) TestMetadataTruncationAccountsKey(t *testing.T) {
|
|||
}
|
||||
fle.mu.Unlock()
|
||||
}
|
||||
|
||||
// TestMethodInConfiguration tests different method names with an expectation on
|
||||
// whether they should error or not.
|
||||
func (s) TestMethodInConfiguration(t *testing.T) {
|
||||
// To skip creating a stackdriver exporter.
|
||||
fle := &fakeLoggingExporter{
|
||||
t: t,
|
||||
}
|
||||
|
||||
defer func(ne func(ctx context.Context, config *config) (loggingExporter, error)) {
|
||||
newLoggingExporter = ne
|
||||
}(newLoggingExporter)
|
||||
|
||||
newLoggingExporter = func(ctx context.Context, config *config) (loggingExporter, error) {
|
||||
return fle, nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config *config
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "leading-slash",
|
||||
config: &config{
|
||||
ProjectID: "fake",
|
||||
CloudLogging: &cloudLogging{
|
||||
ClientRPCEvents: []clientRPCEvents{
|
||||
{
|
||||
Methods: []string{"/service/method"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: "cannot have a leading slash",
|
||||
},
|
||||
{
|
||||
name: "wildcard service/method",
|
||||
config: &config{
|
||||
ProjectID: "fake",
|
||||
CloudLogging: &cloudLogging{
|
||||
ClientRPCEvents: []clientRPCEvents{
|
||||
{
|
||||
Methods: []string{"*/method"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: "cannot have service wildcard *",
|
||||
},
|
||||
{
|
||||
name: "/ in service name",
|
||||
config: &config{
|
||||
ProjectID: "fake",
|
||||
CloudLogging: &cloudLogging{
|
||||
ClientRPCEvents: []clientRPCEvents{
|
||||
{
|
||||
Methods: []string{"ser/vice/method"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: "only one /",
|
||||
},
|
||||
{
|
||||
name: "empty method name",
|
||||
config: &config{
|
||||
ProjectID: "fake",
|
||||
CloudLogging: &cloudLogging{
|
||||
ClientRPCEvents: []clientRPCEvents{
|
||||
{
|
||||
Methods: []string{"service/"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: "method name must be non empty",
|
||||
},
|
||||
{
|
||||
name: "normal",
|
||||
config: &config{
|
||||
ProjectID: "fake",
|
||||
CloudLogging: &cloudLogging{
|
||||
ClientRPCEvents: []clientRPCEvents{
|
||||
{
|
||||
Methods: []string{"service/method"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cleanup, gotErr := setupObservabilitySystemWithConfig(test.config)
|
||||
if cleanup != nil {
|
||||
defer cleanup()
|
||||
}
|
||||
if gotErr != nil && !strings.Contains(gotErr.Error(), test.wantErr) {
|
||||
t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr)
|
||||
}
|
||||
if (gotErr != nil) != (test.wantErr != "") {
|
||||
t.Fatalf("Start(%v) = %v, wantErr %v", test.config, gotErr, test.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue