mirror of https://github.com/grpc/grpc-go.git
rls: service field in RLS request must not contain slashes (#5168)
This commit is contained in:
parent
e2fc510d57
commit
593ff8d017
|
|
@ -120,6 +120,10 @@ type KeyMap struct {
|
||||||
// RLSKey builds the RLS keys to be used for the given request, identified by
|
// RLSKey builds the RLS keys to be used for the given request, identified by
|
||||||
// the request path and the request headers stored in metadata.
|
// the request path and the request headers stored in metadata.
|
||||||
func (bm BuilderMap) RLSKey(md metadata.MD, host, path string) KeyMap {
|
func (bm BuilderMap) RLSKey(md metadata.MD, host, path string) KeyMap {
|
||||||
|
// The path passed in is of the form "/service/method". The keyBuilderMap is
|
||||||
|
// indexed with keys of the form "/service/" or "/service/method". The service
|
||||||
|
// that we set in the keyMap (to be sent out in the RLS request) should not
|
||||||
|
// include any slashes though.
|
||||||
i := strings.LastIndex(path, "/")
|
i := strings.LastIndex(path, "/")
|
||||||
service, method := path[:i+1], path[i+1:]
|
service, method := path[:i+1], path[i+1:]
|
||||||
b, ok := bm[path]
|
b, ok := bm[path]
|
||||||
|
|
@ -135,7 +139,7 @@ func (bm BuilderMap) RLSKey(md metadata.MD, host, path string) KeyMap {
|
||||||
kvMap[b.hostKey] = host
|
kvMap[b.hostKey] = host
|
||||||
}
|
}
|
||||||
if b.serviceKey != "" {
|
if b.serviceKey != "" {
|
||||||
kvMap[b.serviceKey] = service
|
kvMap[b.serviceKey] = strings.Trim(service, "/")
|
||||||
}
|
}
|
||||||
if b.methodKey != "" {
|
if b.methodKey != "" {
|
||||||
kvMap[b.methodKey] = method
|
kvMap[b.methodKey] = method
|
||||||
|
|
|
||||||
|
|
@ -341,12 +341,12 @@ func TestRLSKey(t *testing.T) {
|
||||||
"const-key-1": "const-val-1",
|
"const-key-1": "const-val-1",
|
||||||
"const-key-2": "const-val-2",
|
"const-key-2": "const-val-2",
|
||||||
"host": "dummy-host",
|
"host": "dummy-host",
|
||||||
"service": "/gFoo/",
|
"service": "gFoo",
|
||||||
"method": "method1",
|
"method": "method1",
|
||||||
"k1": "v1",
|
"k1": "v1",
|
||||||
"k2": "v1",
|
"k2": "v1",
|
||||||
},
|
},
|
||||||
Str: "const-key-1=const-val-1,const-key-2=const-val-2,host=dummy-host,k1=v1,k2=v1,method=method1,service=/gFoo/",
|
Str: "const-key-1=const-val-1,const-key-2=const-val-2,host=dummy-host,k1=v1,k2=v1,method=method1,service=gFoo",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,8 @@ package rls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -65,8 +67,17 @@ type rlsPicker struct {
|
||||||
logger *internalgrpclog.PrefixLogger
|
logger *internalgrpclog.PrefixLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isFullMethodNameValid return true if name is of the form `/service/method`.
|
||||||
|
func isFullMethodNameValid(name string) bool {
|
||||||
|
return strings.HasPrefix(name, "/") && strings.Count(name, "/") == 2
|
||||||
|
}
|
||||||
|
|
||||||
// Pick makes the routing decision for every outbound RPC.
|
// Pick makes the routing decision for every outbound RPC.
|
||||||
func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
|
func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
|
||||||
|
if name := info.FullMethodName; !isFullMethodNameValid(name) {
|
||||||
|
return balancer.PickResult{}, fmt.Errorf("rls: method name %q is not of the form '/service/method", name)
|
||||||
|
}
|
||||||
|
|
||||||
// Build the request's keys using the key builders from LB config.
|
// Build the request's keys using the key builders from LB config.
|
||||||
md, _ := metadata.FromOutgoingContext(info.Ctx)
|
md, _ := metadata.FromOutgoingContext(info.Ctx)
|
||||||
reqKeys := p.kbm.RLSKey(md, p.origEndpoint, info.FullMethodName)
|
reqKeys := p.kbm.RLSKey(md, p.origEndpoint, info.FullMethodName)
|
||||||
|
|
|
||||||
|
|
@ -720,3 +720,40 @@ func (s) TestPick_DataCacheHit_PendingEntryExists_ExpiredEntry(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsFullMethodNameValid(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
methodName string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "does not start with a slash",
|
||||||
|
methodName: "service/method",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "does not contain a method",
|
||||||
|
methodName: "/service",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "path has more elements",
|
||||||
|
methodName: "/service/path/to/method",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "valid",
|
||||||
|
methodName: "/service/method",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
if got := isFullMethodNameValid(test.methodName); got != test.want {
|
||||||
|
t.Fatalf("isFullMethodNameValid(%q) = %v, want %v", test.methodName, got, test.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue