mirror of https://github.com/grpc/grpc-go.git
xdsrouting: add fake headers (#3748)
This commit is contained in:
parent
6dc7938fe8
commit
266c7b6f82
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpcutil
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const maxTimeoutValue int64 = 100000000 - 1
|
||||
|
||||
// div does integer division and round-up the result. Note that this is
|
||||
// equivalent to (d+r-1)/r but has less chance to overflow.
|
||||
func div(d, r time.Duration) int64 {
|
||||
if d%r > 0 {
|
||||
return int64(d/r + 1)
|
||||
}
|
||||
return int64(d / r)
|
||||
}
|
||||
|
||||
// EncodeDuration encodes the duration to the format grpc-timeout header
|
||||
// accepts.
|
||||
//
|
||||
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
|
||||
func EncodeDuration(t time.Duration) string {
|
||||
// TODO: This is simplistic and not bandwidth efficient. Improve it.
|
||||
if t <= 0 {
|
||||
return "0n"
|
||||
}
|
||||
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "n"
|
||||
}
|
||||
if d := div(t, time.Microsecond); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "u"
|
||||
}
|
||||
if d := div(t, time.Millisecond); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "m"
|
||||
}
|
||||
if d := div(t, time.Second); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "S"
|
||||
}
|
||||
if d := div(t, time.Minute); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "M"
|
||||
}
|
||||
// Note that maxTimeoutValue * time.Hour > MaxInt64.
|
||||
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpcutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEncodeDuration(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"12345678ns", "12345678n"},
|
||||
{"123456789ns", "123457u"},
|
||||
{"12345678us", "12345678u"},
|
||||
{"123456789us", "123457m"},
|
||||
{"12345678ms", "12345678m"},
|
||||
{"123456789ms", "123457S"},
|
||||
{"12345678s", "12345678S"},
|
||||
{"123456789s", "2057614M"},
|
||||
{"12345678m", "12345678M"},
|
||||
{"123456789m", "2057614H"},
|
||||
} {
|
||||
d, err := time.ParseDuration(test.in)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
|
||||
}
|
||||
out := EncodeDuration(d)
|
||||
if out != test.out {
|
||||
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpcutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type mdExtraKey struct{}
|
||||
|
||||
// WithExtraMetadata creates a new context with incoming md attached.
|
||||
func WithExtraMetadata(ctx context.Context, md metadata.MD) context.Context {
|
||||
return context.WithValue(ctx, mdExtraKey{}, md)
|
||||
}
|
||||
|
||||
// ExtraMetadata returns the incoming metadata in ctx if it exists. The
|
||||
// returned MD should not be modified. Writing to it may cause races.
|
||||
// Modification should be made to copies of the returned MD.
|
||||
func ExtraMetadata(ctx context.Context) (md metadata.MD, ok bool) {
|
||||
md, ok = ctx.Value(mdExtraKey{}).(metadata.MD)
|
||||
return
|
||||
}
|
|
@ -38,3 +38,47 @@ func ParseMethod(methodName string) (service, method string, _ error) {
|
|||
}
|
||||
return methodName[:pos], methodName[pos+1:], nil
|
||||
}
|
||||
|
||||
const baseContentType = "application/grpc"
|
||||
|
||||
// ContentSubtype returns the content-subtype for the given content-type. The
|
||||
// given content-type must be a valid content-type that starts with
|
||||
// "application/grpc". A content-subtype will follow "application/grpc" after a
|
||||
// "+" or ";". See
|
||||
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
|
||||
// more details.
|
||||
//
|
||||
// If contentType is not a valid content-type for gRPC, the boolean
|
||||
// will be false, otherwise true. If content-type == "application/grpc",
|
||||
// "application/grpc+", or "application/grpc;", the boolean will be true,
|
||||
// but no content-subtype will be returned.
|
||||
//
|
||||
// contentType is assumed to be lowercase already.
|
||||
func ContentSubtype(contentType string) (string, bool) {
|
||||
if contentType == baseContentType {
|
||||
return "", true
|
||||
}
|
||||
if !strings.HasPrefix(contentType, baseContentType) {
|
||||
return "", false
|
||||
}
|
||||
// guaranteed since != baseContentType and has baseContentType prefix
|
||||
switch contentType[len(baseContentType)] {
|
||||
case '+', ';':
|
||||
// this will return true for "application/grpc+" or "application/grpc;"
|
||||
// which the previous validContentType function tested to be valid, so we
|
||||
// just say that no content-subtype is specified in this case
|
||||
return contentType[len(baseContentType)+1:], true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// ContentType builds full content type with the given sub-type.
|
||||
//
|
||||
// contentSubtype is assumed to be lowercase
|
||||
func ContentType(contentSubtype string) string {
|
||||
if contentSubtype == "" {
|
||||
return baseContentType
|
||||
}
|
||||
return baseContentType + "+" + contentSubtype
|
||||
}
|
||||
|
|
|
@ -44,3 +44,26 @@ func TestParseMethod(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentSubtype(t *testing.T) {
|
||||
tests := []struct {
|
||||
contentType string
|
||||
want string
|
||||
wantValid bool
|
||||
}{
|
||||
{"application/grpc", "", true},
|
||||
{"application/grpc+", "", true},
|
||||
{"application/grpc+blah", "blah", true},
|
||||
{"application/grpc;", "", true},
|
||||
{"application/grpc;blah", "blah", true},
|
||||
{"application/grpcd", "", false},
|
||||
{"application/grpd", "", false},
|
||||
{"application/grp", "", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, gotValid := ContentSubtype(tt.contentType)
|
||||
if got != tt.want || gotValid != tt.wantValid {
|
||||
t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ import (
|
|||
"golang.org/x/net/http2"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/stats"
|
||||
|
@ -57,7 +58,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats sta
|
|||
}
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
// TODO: do we assume contentType is lowercase? we did before
|
||||
contentSubtype, validContentType := contentSubtype(contentType)
|
||||
contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
|
||||
if !validContentType {
|
||||
return nil, errors.New("invalid gRPC request content-type")
|
||||
}
|
||||
|
|
|
@ -32,6 +32,7 @@ import (
|
|||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
@ -436,7 +437,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
|
|||
headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(callHdr.ContentSubtype)})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
|
||||
if callHdr.PreviousAttempts > 0 {
|
||||
|
@ -451,7 +452,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
|
|||
// Send out timeout regardless its value. The server can detect timeout context by itself.
|
||||
// TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire.
|
||||
timeout := time.Until(dl)
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: grpcutil.EncodeDuration(timeout)})
|
||||
}
|
||||
for k, v := range authData {
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
@ -804,7 +805,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
|
|||
// first and create a slice of that exact size.
|
||||
headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)})
|
||||
if s.sendCompress != "" {
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
|
||||
}
|
||||
|
@ -854,7 +855,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
|
|||
}
|
||||
} else { // Send a trailer only response.
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)})
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)})
|
||||
}
|
||||
}
|
||||
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
|
||||
|
|
|
@ -38,6 +38,7 @@ import (
|
|||
spb "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
|
@ -51,7 +52,7 @@ const (
|
|||
// "proto" as a suffix after "+" or ";". See
|
||||
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
|
||||
// for more details.
|
||||
baseContentType = "application/grpc"
|
||||
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -184,46 +185,6 @@ func isWhitelistedHeader(hdr string) bool {
|
|||
}
|
||||
}
|
||||
|
||||
// contentSubtype returns the content-subtype for the given content-type. The
|
||||
// given content-type must be a valid content-type that starts with
|
||||
// "application/grpc". A content-subtype will follow "application/grpc" after a
|
||||
// "+" or ";". See
|
||||
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
|
||||
// more details.
|
||||
//
|
||||
// If contentType is not a valid content-type for gRPC, the boolean
|
||||
// will be false, otherwise true. If content-type == "application/grpc",
|
||||
// "application/grpc+", or "application/grpc;", the boolean will be true,
|
||||
// but no content-subtype will be returned.
|
||||
//
|
||||
// contentType is assumed to be lowercase already.
|
||||
func contentSubtype(contentType string) (string, bool) {
|
||||
if contentType == baseContentType {
|
||||
return "", true
|
||||
}
|
||||
if !strings.HasPrefix(contentType, baseContentType) {
|
||||
return "", false
|
||||
}
|
||||
// guaranteed since != baseContentType and has baseContentType prefix
|
||||
switch contentType[len(baseContentType)] {
|
||||
case '+', ';':
|
||||
// this will return true for "application/grpc+" or "application/grpc;"
|
||||
// which the previous validContentType function tested to be valid, so we
|
||||
// just say that no content-subtype is specified in this case
|
||||
return contentType[len(baseContentType)+1:], true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// contentSubtype is assumed to be lowercase
|
||||
func contentType(contentSubtype string) string {
|
||||
if contentSubtype == "" {
|
||||
return baseContentType
|
||||
}
|
||||
return baseContentType + "+" + contentSubtype
|
||||
}
|
||||
|
||||
func (d *decodeState) status() *status.Status {
|
||||
if d.data.statusGen == nil {
|
||||
// No status-details were provided; generate status using code/msg.
|
||||
|
@ -342,7 +303,7 @@ func (d *decodeState) addMetadata(k, v string) {
|
|||
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
|
||||
switch f.Name {
|
||||
case "content-type":
|
||||
contentSubtype, validContentType := contentSubtype(f.Value)
|
||||
contentSubtype, validContentType := grpcutil.ContentSubtype(f.Value)
|
||||
if !validContentType {
|
||||
d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value)
|
||||
return
|
||||
|
@ -453,41 +414,6 @@ func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
|
|||
return
|
||||
}
|
||||
|
||||
const maxTimeoutValue int64 = 100000000 - 1
|
||||
|
||||
// div does integer division and round-up the result. Note that this is
|
||||
// equivalent to (d+r-1)/r but has less chance to overflow.
|
||||
func div(d, r time.Duration) int64 {
|
||||
if m := d % r; m > 0 {
|
||||
return int64(d/r + 1)
|
||||
}
|
||||
return int64(d / r)
|
||||
}
|
||||
|
||||
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
|
||||
func encodeTimeout(t time.Duration) string {
|
||||
if t <= 0 {
|
||||
return "0n"
|
||||
}
|
||||
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "n"
|
||||
}
|
||||
if d := div(t, time.Microsecond); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "u"
|
||||
}
|
||||
if d := div(t, time.Millisecond); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "m"
|
||||
}
|
||||
if d := div(t, time.Second); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "S"
|
||||
}
|
||||
if d := div(t, time.Minute); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "M"
|
||||
}
|
||||
// Note that maxTimeoutValue * time.Hour > MaxInt64.
|
||||
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
|
||||
}
|
||||
|
||||
func decodeTimeout(s string) (time.Duration, error) {
|
||||
size := len(s)
|
||||
if size < 2 {
|
||||
|
|
|
@ -25,33 +25,6 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func (s) TestTimeoutEncode(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"12345678ns", "12345678n"},
|
||||
{"123456789ns", "123457u"},
|
||||
{"12345678us", "12345678u"},
|
||||
{"123456789us", "123457m"},
|
||||
{"12345678ms", "12345678m"},
|
||||
{"123456789ms", "123457S"},
|
||||
{"12345678s", "12345678S"},
|
||||
{"123456789s", "2057614M"},
|
||||
{"12345678m", "12345678M"},
|
||||
{"123456789m", "2057614H"},
|
||||
} {
|
||||
d, err := time.ParseDuration(test.in)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
|
||||
}
|
||||
out := encodeTimeout(d)
|
||||
if out != test.out {
|
||||
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestTimeoutDecode(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
// input
|
||||
|
@ -72,29 +45,6 @@ func (s) TestTimeoutDecode(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s) TestContentSubtype(t *testing.T) {
|
||||
tests := []struct {
|
||||
contentType string
|
||||
want string
|
||||
wantValid bool
|
||||
}{
|
||||
{"application/grpc", "", true},
|
||||
{"application/grpc+", "", true},
|
||||
{"application/grpc+blah", "blah", true},
|
||||
{"application/grpc;", "", true},
|
||||
{"application/grpc;blah", "blah", true},
|
||||
{"application/grpcd", "", false},
|
||||
{"application/grpd", "", false},
|
||||
{"application/grp", "", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, gotValid := contentSubtype(tt.contentType)
|
||||
if got != tt.want || gotValid != tt.wantValid {
|
||||
t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestEncodeGrpcMessage(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
input string
|
||||
|
|
12
stream.go
12
stream.go
|
@ -35,6 +35,7 @@ import (
|
|||
"google.golang.org/grpc/internal/binarylog"
|
||||
"google.golang.org/grpc/internal/channelz"
|
||||
"google.golang.org/grpc/internal/grpcrand"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/internal/transport"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
|
@ -346,7 +347,16 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (r
|
|||
if err := cs.ctx.Err(); err != nil {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
t, done, err := cs.cc.getTransport(cs.ctx, cs.callInfo.failFast, cs.callHdr.Method)
|
||||
|
||||
ctx := cs.ctx
|
||||
if cs.cc.parsedTarget.Scheme == "xds" {
|
||||
// Add extra metadata (metadata that will be added by transport) to context
|
||||
// so the balancer can see them.
|
||||
ctx = grpcutil.WithExtraMetadata(cs.ctx, metadata.Pairs(
|
||||
"content-type", grpcutil.ContentType(cs.callHdr.ContentSubtype),
|
||||
))
|
||||
}
|
||||
t, done, err := cs.cc.getTransport(ctx, cs.callInfo.failFast, cs.callHdr.Method)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ import (
|
|||
"google.golang.org/grpc/internal/balancer/stub"
|
||||
"google.golang.org/grpc/internal/balancerload"
|
||||
"google.golang.org/grpc/internal/grpcsync"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/resolver"
|
||||
|
@ -60,6 +61,7 @@ type testBalancer struct {
|
|||
|
||||
newSubConnOptions balancer.NewSubConnOptions
|
||||
pickInfos []balancer.PickInfo
|
||||
pickExtraMDs []metadata.MD
|
||||
doneInfo []balancer.DoneInfo
|
||||
}
|
||||
|
||||
|
@ -124,8 +126,10 @@ func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
|
|||
if p.err != nil {
|
||||
return balancer.PickResult{}, p.err
|
||||
}
|
||||
extraMD, _ := grpcutil.ExtraMetadata(info.Ctx)
|
||||
info.Ctx = nil // Do not validate context.
|
||||
p.bal.pickInfos = append(p.bal.pickInfos, info)
|
||||
p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD)
|
||||
return balancer.PickResult{SubConn: p.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil
|
||||
}
|
||||
|
||||
|
@ -157,6 +161,60 @@ func (s) TestCredsBundleFromBalancer(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s) TestPickExtraMetadata(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testPickExtraMetadata(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPickExtraMetadata(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
b := &testBalancer{}
|
||||
balancer.Register(b)
|
||||
const (
|
||||
testUserAgent = "test-user-agent"
|
||||
testSubContentType = "proto"
|
||||
)
|
||||
|
||||
te.customDialOptions = []grpc.DialOption{
|
||||
grpc.WithBalancerName(testBalancerName),
|
||||
grpc.WithUserAgent(testUserAgent),
|
||||
}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
// Set resolver to xds to trigger the extra metadata code path.
|
||||
r := manual.NewBuilderWithScheme("xds")
|
||||
resolver.Register(r)
|
||||
defer func() {
|
||||
resolver.UnregisterForTesting("xds")
|
||||
}()
|
||||
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: te.srvAddr}}})
|
||||
te.resolverScheme = "xds"
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
|
||||
// The RPCs will fail, but we don't care. We just need the pick to happen.
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel1()
|
||||
tc.EmptyCall(ctx1, &testpb.Empty{})
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel2()
|
||||
tc.EmptyCall(ctx2, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType))
|
||||
|
||||
want := []metadata.MD{
|
||||
// First RPC doesn't have sub-content-type.
|
||||
{"content-type": []string{"application/grpc"}},
|
||||
// Second RPC has sub-content-type "proto".
|
||||
{"content-type": []string{"application/grpc+proto"}},
|
||||
}
|
||||
|
||||
if !cmp.Equal(b.pickExtraMDs, want) {
|
||||
t.Fatalf("%s", cmp.Diff(b.pickExtraMDs, want))
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestDoneInfo(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testDoneInfo(t, e)
|
||||
|
|
|
@ -20,9 +20,11 @@ package xdsrouting
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/internal/grpcrand"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
|
@ -47,6 +49,16 @@ func (a *compositeMatcher) match(info balancer.PickInfo) bool {
|
|||
var md metadata.MD
|
||||
if info.Ctx != nil {
|
||||
md, _ = metadata.FromOutgoingContext(info.Ctx)
|
||||
if extraMD, ok := grpcutil.ExtraMetadata(info.Ctx); ok {
|
||||
md = metadata.Join(md, extraMD)
|
||||
// Remove all binary headers. They are hard to match with. May need
|
||||
// to add back if asked by users.
|
||||
for k := range md {
|
||||
if strings.HasSuffix(k, "-bin") {
|
||||
delete(md, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, m := range a.hms {
|
||||
if !m.match(md) {
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/internal/grpcrand"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
|
@ -65,6 +66,32 @@ func TestAndMatcherMatch(t *testing.T) {
|
|||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "fake header",
|
||||
pm: newPathPrefixMatcher("/"),
|
||||
hm: newHeaderExactMatcher("content-type", "fake"),
|
||||
info: balancer.PickInfo{
|
||||
FullMethodName: "/a/b",
|
||||
Ctx: grpcutil.WithExtraMetadata(context.Background(), metadata.Pairs(
|
||||
"content-type", "fake",
|
||||
)),
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "binary header",
|
||||
pm: newPathPrefixMatcher("/"),
|
||||
hm: newHeaderPresentMatcher("t-bin", true),
|
||||
info: balancer.PickInfo{
|
||||
FullMethodName: "/a/b",
|
||||
Ctx: grpcutil.WithExtraMetadata(
|
||||
metadata.NewOutgoingContext(context.Background(), metadata.Pairs("t-bin", "123")), metadata.Pairs(
|
||||
"content-type", "fake",
|
||||
)),
|
||||
},
|
||||
// Shouldn't match binary header, even though it's in metadata.
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -100,9 +100,6 @@ func (jc lbConfigJSON) toLBConfig() *lbConfig {
|
|||
tempR.regex = *r.Regex
|
||||
}
|
||||
for _, h := range r.Headers {
|
||||
if h.RangeMatch != nil {
|
||||
fmt.Println("range not nil", *h.RangeMatch)
|
||||
}
|
||||
var tempHeader headerMatcher
|
||||
switch {
|
||||
case h.ExactMatch != nil:
|
||||
|
|
Loading…
Reference in New Issue