xdsrouting: add fake headers (#3748)

This commit is contained in:
Menghan Li 2020-07-20 13:40:03 -07:00 committed by GitHub
parent 6dc7938fe8
commit 266c7b6f82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 340 additions and 136 deletions

View File

@ -0,0 +1,63 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpcutil
import (
"strconv"
"time"
)
const maxTimeoutValue int64 = 100000000 - 1
// div does integer division and round-up the result. Note that this is
// equivalent to (d+r-1)/r but has less chance to overflow.
func div(d, r time.Duration) int64 {
if d%r > 0 {
return int64(d/r + 1)
}
return int64(d / r)
}
// EncodeDuration encodes the duration to the format grpc-timeout header
// accepts.
//
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
func EncodeDuration(t time.Duration) string {
// TODO: This is simplistic and not bandwidth efficient. Improve it.
if t <= 0 {
return "0n"
}
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n"
}
if d := div(t, time.Microsecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "u"
}
if d := div(t, time.Millisecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "m"
}
if d := div(t, time.Second); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "S"
}
if d := div(t, time.Minute); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "M"
}
// Note that maxTimeoutValue * time.Hour > MaxInt64.
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
}

View File

@ -0,0 +1,51 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpcutil
import (
"testing"
"time"
)
func TestEncodeDuration(t *testing.T) {
for _, test := range []struct {
in string
out string
}{
{"12345678ns", "12345678n"},
{"123456789ns", "123457u"},
{"12345678us", "12345678u"},
{"123456789us", "123457m"},
{"12345678ms", "12345678m"},
{"123456789ms", "123457S"},
{"12345678s", "12345678S"},
{"123456789s", "2057614M"},
{"12345678m", "12345678M"},
{"123456789m", "2057614H"},
} {
d, err := time.ParseDuration(test.in)
if err != nil {
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
}
out := EncodeDuration(d)
if out != test.out {
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
}
}
}

View File

@ -0,0 +1,40 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpcutil
import (
"context"
"google.golang.org/grpc/metadata"
)
type mdExtraKey struct{}
// WithExtraMetadata creates a new context with incoming md attached.
func WithExtraMetadata(ctx context.Context, md metadata.MD) context.Context {
return context.WithValue(ctx, mdExtraKey{}, md)
}
// ExtraMetadata returns the incoming metadata in ctx if it exists. The
// returned MD should not be modified. Writing to it may cause races.
// Modification should be made to copies of the returned MD.
func ExtraMetadata(ctx context.Context) (md metadata.MD, ok bool) {
md, ok = ctx.Value(mdExtraKey{}).(metadata.MD)
return
}

View File

@ -38,3 +38,47 @@ func ParseMethod(methodName string) (service, method string, _ error) {
}
return methodName[:pos], methodName[pos+1:], nil
}
const baseContentType = "application/grpc"
// ContentSubtype returns the content-subtype for the given content-type. The
// given content-type must be a valid content-type that starts with
// "application/grpc". A content-subtype will follow "application/grpc" after a
// "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If contentType is not a valid content-type for gRPC, the boolean
// will be false, otherwise true. If content-type == "application/grpc",
// "application/grpc+", or "application/grpc;", the boolean will be true,
// but no content-subtype will be returned.
//
// contentType is assumed to be lowercase already.
func ContentSubtype(contentType string) (string, bool) {
if contentType == baseContentType {
return "", true
}
if !strings.HasPrefix(contentType, baseContentType) {
return "", false
}
// guaranteed since != baseContentType and has baseContentType prefix
switch contentType[len(baseContentType)] {
case '+', ';':
// this will return true for "application/grpc+" or "application/grpc;"
// which the previous validContentType function tested to be valid, so we
// just say that no content-subtype is specified in this case
return contentType[len(baseContentType)+1:], true
default:
return "", false
}
}
// ContentType builds full content type with the given sub-type.
//
// contentSubtype is assumed to be lowercase
func ContentType(contentSubtype string) string {
if contentSubtype == "" {
return baseContentType
}
return baseContentType + "+" + contentSubtype
}

View File

@ -44,3 +44,26 @@ func TestParseMethod(t *testing.T) {
}
}
}
func TestContentSubtype(t *testing.T) {
tests := []struct {
contentType string
want string
wantValid bool
}{
{"application/grpc", "", true},
{"application/grpc+", "", true},
{"application/grpc+blah", "blah", true},
{"application/grpc;", "", true},
{"application/grpc;blah", "blah", true},
{"application/grpcd", "", false},
{"application/grpd", "", false},
{"application/grp", "", false},
}
for _, tt := range tests {
got, gotValid := ContentSubtype(tt.contentType)
if got != tt.want || gotValid != tt.wantValid {
t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid)
}
}
}

View File

@ -39,6 +39,7 @@ import (
"golang.org/x/net/http2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@ -57,7 +58,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats sta
}
contentType := r.Header.Get("Content-Type")
// TODO: do we assume contentType is lowercase? we did before
contentSubtype, validContentType := contentSubtype(contentType)
contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
if !validContentType {
return nil, errors.New("invalid gRPC request content-type")
}

View File

@ -32,6 +32,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@ -436,7 +437,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(callHdr.ContentSubtype)})
headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
if callHdr.PreviousAttempts > 0 {
@ -451,7 +452,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
// Send out timeout regardless its value. The server can detect timeout context by itself.
// TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire.
timeout := time.Until(dl)
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: grpcutil.EncodeDuration(timeout)})
}
for k, v := range authData {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})

View File

@ -34,6 +34,7 @@ import (
"github.com/golang/protobuf/proto"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@ -804,7 +805,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
// first and create a slice of that exact size.
headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)})
if s.sendCompress != "" {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
@ -854,7 +855,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
}
} else { // Send a trailer only response.
headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)})
}
}
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})

View File

@ -38,6 +38,7 @@ import (
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/status"
)
@ -51,7 +52,7 @@ const (
// "proto" as a suffix after "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
// for more details.
baseContentType = "application/grpc"
)
var (
@ -184,46 +185,6 @@ func isWhitelistedHeader(hdr string) bool {
}
}
// contentSubtype returns the content-subtype for the given content-type. The
// given content-type must be a valid content-type that starts with
// "application/grpc". A content-subtype will follow "application/grpc" after a
// "+" or ";". See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
//
// If contentType is not a valid content-type for gRPC, the boolean
// will be false, otherwise true. If content-type == "application/grpc",
// "application/grpc+", or "application/grpc;", the boolean will be true,
// but no content-subtype will be returned.
//
// contentType is assumed to be lowercase already.
func contentSubtype(contentType string) (string, bool) {
if contentType == baseContentType {
return "", true
}
if !strings.HasPrefix(contentType, baseContentType) {
return "", false
}
// guaranteed since != baseContentType and has baseContentType prefix
switch contentType[len(baseContentType)] {
case '+', ';':
// this will return true for "application/grpc+" or "application/grpc;"
// which the previous validContentType function tested to be valid, so we
// just say that no content-subtype is specified in this case
return contentType[len(baseContentType)+1:], true
default:
return "", false
}
}
// contentSubtype is assumed to be lowercase
func contentType(contentSubtype string) string {
if contentSubtype == "" {
return baseContentType
}
return baseContentType + "+" + contentSubtype
}
func (d *decodeState) status() *status.Status {
if d.data.statusGen == nil {
// No status-details were provided; generate status using code/msg.
@ -342,7 +303,7 @@ func (d *decodeState) addMetadata(k, v string) {
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
switch f.Name {
case "content-type":
contentSubtype, validContentType := contentSubtype(f.Value)
contentSubtype, validContentType := grpcutil.ContentSubtype(f.Value)
if !validContentType {
d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value)
return
@ -453,41 +414,6 @@ func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
return
}
const maxTimeoutValue int64 = 100000000 - 1
// div does integer division and round-up the result. Note that this is
// equivalent to (d+r-1)/r but has less chance to overflow.
func div(d, r time.Duration) int64 {
if m := d % r; m > 0 {
return int64(d/r + 1)
}
return int64(d / r)
}
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func encodeTimeout(t time.Duration) string {
if t <= 0 {
return "0n"
}
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n"
}
if d := div(t, time.Microsecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "u"
}
if d := div(t, time.Millisecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "m"
}
if d := div(t, time.Second); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "S"
}
if d := div(t, time.Minute); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "M"
}
// Note that maxTimeoutValue * time.Hour > MaxInt64.
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
}
func decodeTimeout(s string) (time.Duration, error) {
size := len(s)
if size < 2 {

View File

@ -25,33 +25,6 @@ import (
"time"
)
func (s) TestTimeoutEncode(t *testing.T) {
for _, test := range []struct {
in string
out string
}{
{"12345678ns", "12345678n"},
{"123456789ns", "123457u"},
{"12345678us", "12345678u"},
{"123456789us", "123457m"},
{"12345678ms", "12345678m"},
{"123456789ms", "123457S"},
{"12345678s", "12345678S"},
{"123456789s", "2057614M"},
{"12345678m", "12345678M"},
{"123456789m", "2057614H"},
} {
d, err := time.ParseDuration(test.in)
if err != nil {
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
}
out := encodeTimeout(d)
if out != test.out {
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
}
}
}
func (s) TestTimeoutDecode(t *testing.T) {
for _, test := range []struct {
// input
@ -72,29 +45,6 @@ func (s) TestTimeoutDecode(t *testing.T) {
}
}
func (s) TestContentSubtype(t *testing.T) {
tests := []struct {
contentType string
want string
wantValid bool
}{
{"application/grpc", "", true},
{"application/grpc+", "", true},
{"application/grpc+blah", "blah", true},
{"application/grpc;", "", true},
{"application/grpc;blah", "blah", true},
{"application/grpcd", "", false},
{"application/grpd", "", false},
{"application/grp", "", false},
}
for _, tt := range tests {
got, gotValid := contentSubtype(tt.contentType)
if got != tt.want || gotValid != tt.wantValid {
t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid)
}
}
}
func (s) TestEncodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string

View File

@ -35,6 +35,7 @@ import (
"google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@ -346,7 +347,16 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r
if err := cs.ctx.Err(); err != nil {
return toRPCErr(err)
}
t, done, err := cs.cc.getTransport(cs.ctx, cs.callInfo.failFast, cs.callHdr.Method)
ctx := cs.ctx
if cs.cc.parsedTarget.Scheme == "xds" {
// Add extra metadata (metadata that will be added by transport) to context
// so the balancer can see them.
ctx = grpcutil.WithExtraMetadata(cs.ctx, metadata.Pairs(
"content-type", grpcutil.ContentType(cs.callHdr.ContentSubtype),
))
}
t, done, err := cs.cc.getTransport(ctx, cs.callInfo.failFast, cs.callHdr.Method)
if err != nil {
return err
}

View File

@ -39,6 +39,7 @@ import (
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/balancerload"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
@ -60,6 +61,7 @@ type testBalancer struct {
newSubConnOptions balancer.NewSubConnOptions
pickInfos []balancer.PickInfo
pickExtraMDs []metadata.MD
doneInfo []balancer.DoneInfo
}
@ -124,8 +126,10 @@ func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
if p.err != nil {
return balancer.PickResult{}, p.err
}
extraMD, _ := grpcutil.ExtraMetadata(info.Ctx)
info.Ctx = nil // Do not validate context.
p.bal.pickInfos = append(p.bal.pickInfos, info)
p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD)
return balancer.PickResult{SubConn: p.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil
}
@ -157,6 +161,60 @@ func (s) TestCredsBundleFromBalancer(t *testing.T) {
}
}
func (s) TestPickExtraMetadata(t *testing.T) {
for _, e := range listTestEnv() {
testPickExtraMetadata(t, e)
}
}
func testPickExtraMetadata(t *testing.T, e env) {
te := newTest(t, e)
b := &testBalancer{}
balancer.Register(b)
const (
testUserAgent = "test-user-agent"
testSubContentType = "proto"
)
te.customDialOptions = []grpc.DialOption{
grpc.WithBalancerName(testBalancerName),
grpc.WithUserAgent(testUserAgent),
}
te.startServer(&testServer{security: e.security})
defer te.tearDown()
// Set resolver to xds to trigger the extra metadata code path.
r := manual.NewBuilderWithScheme("xds")
resolver.Register(r)
defer func() {
resolver.UnregisterForTesting("xds")
}()
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: te.srvAddr}}})
te.resolverScheme = "xds"
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
// The RPCs will fail, but we don't care. We just need the pick to happen.
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
tc.EmptyCall(ctx1, &testpb.Empty{})
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
tc.EmptyCall(ctx2, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType))
want := []metadata.MD{
// First RPC doesn't have sub-content-type.
{"content-type": []string{"application/grpc"}},
// Second RPC has sub-content-type "proto".
{"content-type": []string{"application/grpc+proto"}},
}
if !cmp.Equal(b.pickExtraMDs, want) {
t.Fatalf("%s", cmp.Diff(b.pickExtraMDs, want))
}
}
func (s) TestDoneInfo(t *testing.T) {
for _, e := range listTestEnv() {
testDoneInfo(t, e)

View File

@ -20,9 +20,11 @@ package xdsrouting
import (
"fmt"
"strings"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/metadata"
)
@ -47,6 +49,16 @@ func (a *compositeMatcher) match(info balancer.PickInfo) bool {
var md metadata.MD
if info.Ctx != nil {
md, _ = metadata.FromOutgoingContext(info.Ctx)
if extraMD, ok := grpcutil.ExtraMetadata(info.Ctx); ok {
md = metadata.Join(md, extraMD)
// Remove all binary headers. They are hard to match with. May need
// to add back if asked by users.
for k := range md {
if strings.HasSuffix(k, "-bin") {
delete(md, k)
}
}
}
}
for _, m := range a.hms {
if !m.match(md) {

View File

@ -24,6 +24,7 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/metadata"
)
@ -65,6 +66,32 @@ func TestAndMatcherMatch(t *testing.T) {
},
want: false,
},
{
name: "fake header",
pm: newPathPrefixMatcher("/"),
hm: newHeaderExactMatcher("content-type", "fake"),
info: balancer.PickInfo{
FullMethodName: "/a/b",
Ctx: grpcutil.WithExtraMetadata(context.Background(), metadata.Pairs(
"content-type", "fake",
)),
},
want: true,
},
{
name: "binary header",
pm: newPathPrefixMatcher("/"),
hm: newHeaderPresentMatcher("t-bin", true),
info: balancer.PickInfo{
FullMethodName: "/a/b",
Ctx: grpcutil.WithExtraMetadata(
metadata.NewOutgoingContext(context.Background(), metadata.Pairs("t-bin", "123")), metadata.Pairs(
"content-type", "fake",
)),
},
// Shouldn't match binary header, even though it's in metadata.
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@ -100,9 +100,6 @@ func (jc lbConfigJSON) toLBConfig() *lbConfig {
tempR.regex = *r.Regex
}
for _, h := range r.Headers {
if h.RangeMatch != nil {
fmt.Println("range not nil", *h.RangeMatch)
}
var tempHeader headerMatcher
switch {
case h.ExactMatch != nil: