diff --git a/internal/grpcutil/encode_duration.go b/internal/grpcutil/encode_duration.go new file mode 100644 index 000000000..b25b0baec --- /dev/null +++ b/internal/grpcutil/encode_duration.go @@ -0,0 +1,63 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpcutil + +import ( + "strconv" + "time" +) + +const maxTimeoutValue int64 = 100000000 - 1 + +// div does integer division and round-up the result. Note that this is +// equivalent to (d+r-1)/r but has less chance to overflow. +func div(d, r time.Duration) int64 { + if d%r > 0 { + return int64(d/r + 1) + } + return int64(d / r) +} + +// EncodeDuration encodes the duration to the format grpc-timeout header +// accepts. +// +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests +func EncodeDuration(t time.Duration) string { + // TODO: This is simplistic and not bandwidth efficient. Improve it. + if t <= 0 { + return "0n" + } + if d := div(t, time.Nanosecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "n" + } + if d := div(t, time.Microsecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "u" + } + if d := div(t, time.Millisecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "m" + } + if d := div(t, time.Second); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "S" + } + if d := div(t, time.Minute); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "M" + } + // Note that maxTimeoutValue * time.Hour > MaxInt64. + return strconv.FormatInt(div(t, time.Hour), 10) + "H" +} diff --git a/internal/grpcutil/encode_duration_test.go b/internal/grpcutil/encode_duration_test.go new file mode 100644 index 000000000..eea49e2e7 --- /dev/null +++ b/internal/grpcutil/encode_duration_test.go @@ -0,0 +1,51 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpcutil + +import ( + "testing" + "time" +) + +func TestEncodeDuration(t *testing.T) { + for _, test := range []struct { + in string + out string + }{ + {"12345678ns", "12345678n"}, + {"123456789ns", "123457u"}, + {"12345678us", "12345678u"}, + {"123456789us", "123457m"}, + {"12345678ms", "12345678m"}, + {"123456789ms", "123457S"}, + {"12345678s", "12345678S"}, + {"123456789s", "2057614M"}, + {"12345678m", "12345678M"}, + {"123456789m", "2057614H"}, + } { + d, err := time.ParseDuration(test.in) + if err != nil { + t.Fatalf("failed to parse duration string %s: %v", test.in, err) + } + out := EncodeDuration(d) + if out != test.out { + t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out) + } + } +} diff --git a/internal/grpcutil/metadata.go b/internal/grpcutil/metadata.go new file mode 100644 index 000000000..6f22bd891 --- /dev/null +++ b/internal/grpcutil/metadata.go @@ -0,0 +1,40 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpcutil + +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +type mdExtraKey struct{} + +// WithExtraMetadata creates a new context with incoming md attached. +func WithExtraMetadata(ctx context.Context, md metadata.MD) context.Context { + return context.WithValue(ctx, mdExtraKey{}, md) +} + +// ExtraMetadata returns the incoming metadata in ctx if it exists. The +// returned MD should not be modified. Writing to it may cause races. +// Modification should be made to copies of the returned MD. +func ExtraMetadata(ctx context.Context) (md metadata.MD, ok bool) { + md, ok = ctx.Value(mdExtraKey{}).(metadata.MD) + return +} diff --git a/internal/grpcutil/method.go b/internal/grpcutil/method.go index 2c2ff7732..4e7475060 100644 --- a/internal/grpcutil/method.go +++ b/internal/grpcutil/method.go @@ -38,3 +38,47 @@ func ParseMethod(methodName string) (service, method string, _ error) { } return methodName[:pos], methodName[pos+1:], nil } + +const baseContentType = "application/grpc" + +// ContentSubtype returns the content-subtype for the given content-type. The +// given content-type must be a valid content-type that starts with +// "application/grpc". A content-subtype will follow "application/grpc" after a +// "+" or ";". See +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +// +// If contentType is not a valid content-type for gRPC, the boolean +// will be false, otherwise true. If content-type == "application/grpc", +// "application/grpc+", or "application/grpc;", the boolean will be true, +// but no content-subtype will be returned. +// +// contentType is assumed to be lowercase already. +func ContentSubtype(contentType string) (string, bool) { + if contentType == baseContentType { + return "", true + } + if !strings.HasPrefix(contentType, baseContentType) { + return "", false + } + // guaranteed since != baseContentType and has baseContentType prefix + switch contentType[len(baseContentType)] { + case '+', ';': + // this will return true for "application/grpc+" or "application/grpc;" + // which the previous validContentType function tested to be valid, so we + // just say that no content-subtype is specified in this case + return contentType[len(baseContentType)+1:], true + default: + return "", false + } +} + +// ContentType builds full content type with the given sub-type. +// +// contentSubtype is assumed to be lowercase +func ContentType(contentSubtype string) string { + if contentSubtype == "" { + return baseContentType + } + return baseContentType + "+" + contentSubtype +} diff --git a/internal/grpcutil/method_test.go b/internal/grpcutil/method_test.go index 9880e1c43..36c786cff 100644 --- a/internal/grpcutil/method_test.go +++ b/internal/grpcutil/method_test.go @@ -44,3 +44,26 @@ func TestParseMethod(t *testing.T) { } } } + +func TestContentSubtype(t *testing.T) { + tests := []struct { + contentType string + want string + wantValid bool + }{ + {"application/grpc", "", true}, + {"application/grpc+", "", true}, + {"application/grpc+blah", "blah", true}, + {"application/grpc;", "", true}, + {"application/grpc;blah", "blah", true}, + {"application/grpcd", "", false}, + {"application/grpd", "", false}, + {"application/grp", "", false}, + } + for _, tt := range tests { + got, gotValid := ContentSubtype(tt.contentType) + if got != tt.want || gotValid != tt.wantValid { + t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid) + } + } +} diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index fc44e9761..05d3871e6 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -39,6 +39,7 @@ import ( "golang.org/x/net/http2" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -57,7 +58,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats sta } contentType := r.Header.Get("Content-Type") // TODO: do we assume contentType is lowercase? we did before - contentSubtype, validContentType := contentSubtype(contentType) + contentSubtype, validContentType := grpcutil.ContentSubtype(contentType) if !validContentType { return nil, errors.New("invalid gRPC request content-type") } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 47af2ebea..e7f232113 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -32,6 +32,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -436,7 +437,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme}) headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method}) headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(callHdr.ContentSubtype)}) headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"}) if callHdr.PreviousAttempts > 0 { @@ -451,7 +452,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) // Send out timeout regardless its value. The server can detect timeout context by itself. // TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire. timeout := time.Until(dl) - headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: grpcutil.EncodeDuration(timeout)}) } for k, v := range authData { headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 788cf1e4d..04cbedf79 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -34,6 +34,7 @@ import ( "github.com/golang/protobuf/proto" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -804,7 +805,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { // first and create a slice of that exact size. headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)}) if s.sendCompress != "" { headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) } @@ -854,7 +855,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { } } else { // Send a trailer only response. headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)}) } } headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index e68cdcb6c..5e1e7a65d 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -38,6 +38,7 @@ import ( spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/status" ) @@ -51,7 +52,7 @@ const ( // "proto" as a suffix after "+" or ";". See // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests // for more details. - baseContentType = "application/grpc" + ) var ( @@ -184,46 +185,6 @@ func isWhitelistedHeader(hdr string) bool { } } -// contentSubtype returns the content-subtype for the given content-type. The -// given content-type must be a valid content-type that starts with -// "application/grpc". A content-subtype will follow "application/grpc" after a -// "+" or ";". See -// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for -// more details. -// -// If contentType is not a valid content-type for gRPC, the boolean -// will be false, otherwise true. If content-type == "application/grpc", -// "application/grpc+", or "application/grpc;", the boolean will be true, -// but no content-subtype will be returned. -// -// contentType is assumed to be lowercase already. -func contentSubtype(contentType string) (string, bool) { - if contentType == baseContentType { - return "", true - } - if !strings.HasPrefix(contentType, baseContentType) { - return "", false - } - // guaranteed since != baseContentType and has baseContentType prefix - switch contentType[len(baseContentType)] { - case '+', ';': - // this will return true for "application/grpc+" or "application/grpc;" - // which the previous validContentType function tested to be valid, so we - // just say that no content-subtype is specified in this case - return contentType[len(baseContentType)+1:], true - default: - return "", false - } -} - -// contentSubtype is assumed to be lowercase -func contentType(contentSubtype string) string { - if contentSubtype == "" { - return baseContentType - } - return baseContentType + "+" + contentSubtype -} - func (d *decodeState) status() *status.Status { if d.data.statusGen == nil { // No status-details were provided; generate status using code/msg. @@ -342,7 +303,7 @@ func (d *decodeState) addMetadata(k, v string) { func (d *decodeState) processHeaderField(f hpack.HeaderField) { switch f.Name { case "content-type": - contentSubtype, validContentType := contentSubtype(f.Value) + contentSubtype, validContentType := grpcutil.ContentSubtype(f.Value) if !validContentType { d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value) return @@ -453,41 +414,6 @@ func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) { return } -const maxTimeoutValue int64 = 100000000 - 1 - -// div does integer division and round-up the result. Note that this is -// equivalent to (d+r-1)/r but has less chance to overflow. -func div(d, r time.Duration) int64 { - if m := d % r; m > 0 { - return int64(d/r + 1) - } - return int64(d / r) -} - -// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. -func encodeTimeout(t time.Duration) string { - if t <= 0 { - return "0n" - } - if d := div(t, time.Nanosecond); d <= maxTimeoutValue { - return strconv.FormatInt(d, 10) + "n" - } - if d := div(t, time.Microsecond); d <= maxTimeoutValue { - return strconv.FormatInt(d, 10) + "u" - } - if d := div(t, time.Millisecond); d <= maxTimeoutValue { - return strconv.FormatInt(d, 10) + "m" - } - if d := div(t, time.Second); d <= maxTimeoutValue { - return strconv.FormatInt(d, 10) + "S" - } - if d := div(t, time.Minute); d <= maxTimeoutValue { - return strconv.FormatInt(d, 10) + "M" - } - // Note that maxTimeoutValue * time.Hour > MaxInt64. - return strconv.FormatInt(div(t, time.Hour), 10) + "H" -} - func decodeTimeout(s string) (time.Duration, error) { size := len(s) if size < 2 { diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go index dd13bcc2f..80b1c094a 100644 --- a/internal/transport/http_util_test.go +++ b/internal/transport/http_util_test.go @@ -25,33 +25,6 @@ import ( "time" ) -func (s) TestTimeoutEncode(t *testing.T) { - for _, test := range []struct { - in string - out string - }{ - {"12345678ns", "12345678n"}, - {"123456789ns", "123457u"}, - {"12345678us", "12345678u"}, - {"123456789us", "123457m"}, - {"12345678ms", "12345678m"}, - {"123456789ms", "123457S"}, - {"12345678s", "12345678S"}, - {"123456789s", "2057614M"}, - {"12345678m", "12345678M"}, - {"123456789m", "2057614H"}, - } { - d, err := time.ParseDuration(test.in) - if err != nil { - t.Fatalf("failed to parse duration string %s: %v", test.in, err) - } - out := encodeTimeout(d) - if out != test.out { - t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out) - } - } -} - func (s) TestTimeoutDecode(t *testing.T) { for _, test := range []struct { // input @@ -72,29 +45,6 @@ func (s) TestTimeoutDecode(t *testing.T) { } } -func (s) TestContentSubtype(t *testing.T) { - tests := []struct { - contentType string - want string - wantValid bool - }{ - {"application/grpc", "", true}, - {"application/grpc+", "", true}, - {"application/grpc+blah", "blah", true}, - {"application/grpc;", "", true}, - {"application/grpc;blah", "blah", true}, - {"application/grpcd", "", false}, - {"application/grpd", "", false}, - {"application/grp", "", false}, - } - for _, tt := range tests { - got, gotValid := contentSubtype(tt.contentType) - if got != tt.want || gotValid != tt.wantValid { - t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid) - } - } -} - func (s) TestEncodeGrpcMessage(t *testing.T) { for _, tt := range []struct { input string diff --git a/stream.go b/stream.go index 34e4010d0..fbc3fb11c 100644 --- a/stream.go +++ b/stream.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -346,7 +347,16 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r if err := cs.ctx.Err(); err != nil { return toRPCErr(err) } - t, done, err := cs.cc.getTransport(cs.ctx, cs.callInfo.failFast, cs.callHdr.Method) + + ctx := cs.ctx + if cs.cc.parsedTarget.Scheme == "xds" { + // Add extra metadata (metadata that will be added by transport) to context + // so the balancer can see them. + ctx = grpcutil.WithExtraMetadata(cs.ctx, metadata.Pairs( + "content-type", grpcutil.ContentType(cs.callHdr.ContentSubtype), + )) + } + t, done, err := cs.cc.getTransport(ctx, cs.callInfo.failFast, cs.callHdr.Method) if err != nil { return err } diff --git a/test/balancer_test.go b/test/balancer_test.go index 7fa96d868..9f8bc1a6e 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -39,6 +39,7 @@ import ( "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/balancerload" "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" @@ -60,6 +61,7 @@ type testBalancer struct { newSubConnOptions balancer.NewSubConnOptions pickInfos []balancer.PickInfo + pickExtraMDs []metadata.MD doneInfo []balancer.DoneInfo } @@ -124,8 +126,10 @@ func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { if p.err != nil { return balancer.PickResult{}, p.err } + extraMD, _ := grpcutil.ExtraMetadata(info.Ctx) info.Ctx = nil // Do not validate context. p.bal.pickInfos = append(p.bal.pickInfos, info) + p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD) return balancer.PickResult{SubConn: p.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil } @@ -157,6 +161,60 @@ func (s) TestCredsBundleFromBalancer(t *testing.T) { } } +func (s) TestPickExtraMetadata(t *testing.T) { + for _, e := range listTestEnv() { + testPickExtraMetadata(t, e) + } +} + +func testPickExtraMetadata(t *testing.T, e env) { + te := newTest(t, e) + b := &testBalancer{} + balancer.Register(b) + const ( + testUserAgent = "test-user-agent" + testSubContentType = "proto" + ) + + te.customDialOptions = []grpc.DialOption{ + grpc.WithBalancerName(testBalancerName), + grpc.WithUserAgent(testUserAgent), + } + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + // Set resolver to xds to trigger the extra metadata code path. + r := manual.NewBuilderWithScheme("xds") + resolver.Register(r) + defer func() { + resolver.UnregisterForTesting("xds") + }() + r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: te.srvAddr}}}) + te.resolverScheme = "xds" + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + + // The RPCs will fail, but we don't care. We just need the pick to happen. + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) + defer cancel1() + tc.EmptyCall(ctx1, &testpb.Empty{}) + + ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) + defer cancel2() + tc.EmptyCall(ctx2, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType)) + + want := []metadata.MD{ + // First RPC doesn't have sub-content-type. + {"content-type": []string{"application/grpc"}}, + // Second RPC has sub-content-type "proto". + {"content-type": []string{"application/grpc+proto"}}, + } + + if !cmp.Equal(b.pickExtraMDs, want) { + t.Fatalf("%s", cmp.Diff(b.pickExtraMDs, want)) + } +} + func (s) TestDoneInfo(t *testing.T) { for _, e := range listTestEnv() { testDoneInfo(t, e) diff --git a/xds/internal/balancer/xdsrouting/matcher.go b/xds/internal/balancer/xdsrouting/matcher.go index 79fab828a..196aefae5 100644 --- a/xds/internal/balancer/xdsrouting/matcher.go +++ b/xds/internal/balancer/xdsrouting/matcher.go @@ -20,9 +20,11 @@ package xdsrouting import ( "fmt" + "strings" "google.golang.org/grpc/balancer" "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/metadata" ) @@ -47,6 +49,16 @@ func (a *compositeMatcher) match(info balancer.PickInfo) bool { var md metadata.MD if info.Ctx != nil { md, _ = metadata.FromOutgoingContext(info.Ctx) + if extraMD, ok := grpcutil.ExtraMetadata(info.Ctx); ok { + md = metadata.Join(md, extraMD) + // Remove all binary headers. They are hard to match with. May need + // to add back if asked by users. + for k := range md { + if strings.HasSuffix(k, "-bin") { + delete(md, k) + } + } + } } for _, m := range a.hms { if !m.match(md) { diff --git a/xds/internal/balancer/xdsrouting/matcher_test.go b/xds/internal/balancer/xdsrouting/matcher_test.go index 8c8a4b52e..e7d76e274 100644 --- a/xds/internal/balancer/xdsrouting/matcher_test.go +++ b/xds/internal/balancer/xdsrouting/matcher_test.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/metadata" ) @@ -65,6 +66,32 @@ func TestAndMatcherMatch(t *testing.T) { }, want: false, }, + { + name: "fake header", + pm: newPathPrefixMatcher("/"), + hm: newHeaderExactMatcher("content-type", "fake"), + info: balancer.PickInfo{ + FullMethodName: "/a/b", + Ctx: grpcutil.WithExtraMetadata(context.Background(), metadata.Pairs( + "content-type", "fake", + )), + }, + want: true, + }, + { + name: "binary header", + pm: newPathPrefixMatcher("/"), + hm: newHeaderPresentMatcher("t-bin", true), + info: balancer.PickInfo{ + FullMethodName: "/a/b", + Ctx: grpcutil.WithExtraMetadata( + metadata.NewOutgoingContext(context.Background(), metadata.Pairs("t-bin", "123")), metadata.Pairs( + "content-type", "fake", + )), + }, + // Shouldn't match binary header, even though it's in metadata. + want: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/xds/internal/balancer/xdsrouting/routing_config.go b/xds/internal/balancer/xdsrouting/routing_config.go index 6daea224a..c8cc7a18c 100644 --- a/xds/internal/balancer/xdsrouting/routing_config.go +++ b/xds/internal/balancer/xdsrouting/routing_config.go @@ -100,9 +100,6 @@ func (jc lbConfigJSON) toLBConfig() *lbConfig { tempR.regex = *r.Regex } for _, h := range r.Headers { - if h.RangeMatch != nil { - fmt.Println("range not nil", *h.RangeMatch) - } var tempHeader headerMatcher switch { case h.ExactMatch != nil: