From 080f9563df8b6a1eb4230cfcd500bc9d7930ecf8 Mon Sep 17 00:00:00 2001 From: eshitachandwani <59800922+eshitachandwani@users.noreply.github.com> Date: Wed, 30 Apr 2025 14:41:28 +0530 Subject: [PATCH] credentials, transport, grpc : add a call option to override the :authority header on a per-RPC basis (#8068) --- credentials/credentials.go | 14 ++ credentials/credentials_ext_test.go | 332 ++++++++++++++++++++++++++++ credentials/insecure/insecure.go | 6 + credentials/tls.go | 16 ++ credentials/tls_ext_test.go | 1 + internal/transport/http2_client.go | 19 ++ internal/transport/transport.go | 5 + rpc_util.go | 31 +++ stream.go | 1 + 9 files changed, 425 insertions(+) create mode 100644 credentials/credentials_ext_test.go diff --git a/credentials/credentials.go b/credentials/credentials.go index 6bba20f3b..a63ab606e 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -120,6 +120,20 @@ type AuthInfo interface { AuthType() string } +// AuthorityValidator validates the authority used to override the `:authority` +// header. This is an optional interface that implementations of AuthInfo can +// implement if they support per-RPC authority overrides. It is invoked when the +// application attempts to override the HTTP/2 `:authority` header using the +// CallAuthority call option. +type AuthorityValidator interface { + // ValidateAuthority checks the authority value used to override the + // `:authority` header. The authority parameter is the override value + // provided by the application via the CallAuthority option. This value + // typically corresponds to the server hostname or endpoint the RPC is + // targeting. It returns non-nil error if the validation fails. + ValidateAuthority(authority string) error +} + // ErrConnDispatched indicates that rawConn has been dispatched out of gRPC // and the caller should not close rawConn. var ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC") diff --git a/credentials/credentials_ext_test.go b/credentials/credentials_ext_test.go new file mode 100644 index 000000000..27da3a265 --- /dev/null +++ b/credentials/credentials_ext_test.go @@ -0,0 +1,332 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package credentials_test + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/grpc/testdata" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +func authorityChecker(ctx context.Context, wantAuthority string) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.InvalidArgument, "failed to parse metadata") + } + auths, ok := md[":authority"] + if !ok { + return status.Error(codes.InvalidArgument, "no authority header") + } + if len(auths) != 1 { + return status.Errorf(codes.InvalidArgument, "expected exactly one authority header, got %v", auths) + } + if auths[0] != wantAuthority { + return status.Errorf(codes.InvalidArgument, "invalid authority header %q, want %q", auths[0], wantAuthority) + } + return nil +} + +// Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies +// that the provided authority is correctly propagated to the server when a +// correct authority is used. +func (s) TestCorrectAuthorityWithTLSCreds(t *testing.T) { + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + t.Fatalf("Failed to load key pair: %s", err) + } + creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") + if err != nil { + t.Fatalf("Failed to create credentials %v", err) + } + const authority = "auth.test.example.com" + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + if err := authorityChecker(ctx, authority); err != nil { + return nil, err + } + return &testpb.Empty{}, nil + }, + } + if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds)) + if err != nil { + t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.OK { + t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK) + } + +} + +// Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies +// that the RPC fails with `UNAVAILABLE` status code and doesn't reach the server +// when an incorrect authority is used. +func (s) TestIncorrectAuthorityWithTLS(t *testing.T) { + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + t.Fatalf("Failed to load key pair: %s", err) + } + creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com") + if err != nil { + t.Fatalf("Failed to create credentials %v", err) + } + + serverCalled := make(chan struct{}) + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + close(serverCalled) + return nil, nil + }, + } + if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds)) + if err != nil { + t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + const authority = "auth.example.com" + if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.Unavailable { + t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable) + } + select { + case <-serverCalled: + t.Fatalf("Server handler should not have been called") + case <-time.After(defaultTestShortTimeout): + } +} + +// Tests the scenario where the `grpc.CallAuthority` call option is used with +// insecure transport credentials. The test verifies that the specified +// authority is correctly propagated to the server. +func (s) TestAuthorityCallOptionWithInsecureCreds(t *testing.T) { + const authority = "test.server.name" + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + if err := authorityChecker(ctx, authority); err != nil { + return nil, err + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); err != nil { + t.Fatalf("EmptyCall() rpc failed: %v", err) + } +} + +// testAuthInfoNoValidator implements only credentials.AuthInfo and not +// credentials.AuthorityValidator. +type testAuthInfoNoValidator struct{} + +// AuthType returns the authentication type. +func (testAuthInfoNoValidator) AuthType() string { + return "test" +} + +// testAuthInfoWithValidator implements both credentials.AuthInfo and +// credentials.AuthorityValidator. +type testAuthInfoWithValidator struct { + validAuthority string +} + +// AuthType returns the authentication type. +func (testAuthInfoWithValidator) AuthType() string { + return "test" +} + +// ValidateAuthority implements credentials.AuthorityValidator. +func (v testAuthInfoWithValidator) ValidateAuthority(authority string) error { + if authority == v.validAuthority { + return nil + } + return fmt.Errorf("invalid authority %q, want %q", authority, v.validAuthority) +} + +// testCreds is a test TransportCredentials that can optionally support +// authority validation. +type testCreds struct { + authority string +} + +// ClientHandshake performs the client-side handshake. +func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if c.authority != "" { + return rawConn, testAuthInfoWithValidator{validAuthority: c.authority}, nil + } + return rawConn, testAuthInfoNoValidator{}, nil +} + +// ServerHandshake performs the server-side handshake. +func (c *testCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if c.authority != "" { + return rawConn, testAuthInfoWithValidator{validAuthority: c.authority}, nil + } + return rawConn, testAuthInfoNoValidator{}, nil +} + +// Clone creates a copy of testCreds. +func (c *testCreds) Clone() credentials.TransportCredentials { + return &testCreds{authority: c.authority} +} + +// Info provides protocol information. +func (c *testCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} + +// OverrideServerName overrides the server name used for verification. +func (c *testCreds) OverrideServerName(serverName string) error { + return nil +} + +// TestAuthorityValidationFailureWithCustomCreds tests the `grpc.CallAuthority` +// call option using custom credentials. It covers two failure scenarios: +// - The credentials implement AuthorityValidator but authority used to override +// is not valid. +// - The credentials do not implement AuthorityValidator, but an authority +// override is specified. +// In both cases, the RPC is expected to fail with an `UNAVAILABLE` status code. +func (s) TestAuthorityValidationFailureWithCustomCreds(t *testing.T) { + tests := []struct { + name string + creds credentials.TransportCredentials + authority string + }{ + { + name: "IncorrectAuthorityWithFakeCreds", + authority: "auth.example.com", + creds: &testCreds{authority: "auth.test.example.com"}, + }, + { + name: "FakeCredsWithNoAuthValidator", + creds: &testCreds{}, + authority: "auth.test.example.com", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + serverCalled := make(chan struct{}) + ss := stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + close(serverCalled) + return nil, nil + }, + } + if err := ss.StartServer(); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(tt.creds)) + if err != nil { + t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.authority)); status.Code(err) != codes.Unavailable { + t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.Unavailable) + } + select { + case <-serverCalled: + t.Fatalf("Server should not have been called") + case <-time.After(defaultTestShortTimeout): + } + }) + } + +} + +// TestCorrectAuthorityWithCustomCreds tests the `grpc.CallAuthority` call +// option using custom credentials. It verifies that the provided authority is +// correctly propagated to the server when a correct authority is used. +func (s) TestCorrectAuthorityWithCustomCreds(t *testing.T) { + const authority = "auth.test.example.com" + creds := &testCreds{authority: "auth.test.example.com"} + ss := stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + if err := authorityChecker(ctx, authority); err != nil { + return nil, err + } + return &testpb.Empty{}, nil + }, + } + if err := ss.StartServer(); err != nil { + t.Fatalf("Failed to start stub server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds)) + if err != nil { + t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.OK { + t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK) + } +} diff --git a/credentials/insecure/insecure.go b/credentials/insecure/insecure.go index 4c805c644..f45c04f7c 100644 --- a/credentials/insecure/insecure.go +++ b/credentials/insecure/insecure.go @@ -71,6 +71,12 @@ func (info) AuthType() string { return "insecure" } +// ValidateAuthority allows any value to be overridden for the :authority +// header. +func (info) ValidateAuthority(string) error { + return nil +} + // insecureBundle implements an insecure bundle. // An insecure bundle provides a thin wrapper around insecureTC to support // the credentials.Bundle interface. diff --git a/credentials/tls.go b/credentials/tls.go index bd5fe22b6..20f65f7bd 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "net" "net/url" @@ -50,6 +51,21 @@ func (t TLSInfo) AuthType() string { return "tls" } +// ValidateAuthority validates the provided authority being used to override the +// :authority header by verifying it against the peer certificates. It returns a +// non-nil error if the validation fails. +func (t TLSInfo) ValidateAuthority(authority string) error { + var errs []error + for _, cert := range t.State.PeerCertificates { + var err error + if err = cert.VerifyHostname(authority); err == nil { + return nil + } + errs = append(errs, err) + } + return fmt.Errorf("credentials: invalid authority %q: %v", authority, errors.Join(errs...)) +} + // cipherSuiteLookup returns the string version of a TLS cipher suite ID. func cipherSuiteLookup(cipherSuiteID uint16) string { for _, s := range tls.CipherSuites() { diff --git a/credentials/tls_ext_test.go b/credentials/tls_ext_test.go index 22881a6f4..ceb810a4a 100644 --- a/credentials/tls_ext_test.go +++ b/credentials/tls_ext_test.go @@ -43,6 +43,7 @@ import ( ) const defaultTestTimeout = 10 * time.Second +const defaultTestShortTimeout = 10 * time.Millisecond type s struct { grpctest.Tester diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 4c2d30574..32047128c 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -749,6 +749,25 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS callHdr = &newCallHdr } + // The authority specified via the `CallAuthority` CallOption takes the + // highest precedence when determining the `:authority` header. It overrides + // any value present in the Host field of CallHdr. Before applying this + // override, the authority string is validated. If the credentials do not + // implement the AuthorityValidator interface, or if validation fails, the + // RPC is failed with a status code of `UNAVAILABLE`. + if callHdr.Authority != "" { + auth, ok := t.authInfo.(credentials.AuthorityValidator) + if !ok { + return nil, &NewStreamError{Err: status.Errorf(codes.Unavailable, "credentials type %q does not implement the AuthorityValidator interface, but authority override specified with CallAuthority call option", t.authInfo.AuthType())} + } + if err := auth.ValidateAuthority(callHdr.Authority); err != nil { + return nil, &NewStreamError{Err: status.Errorf(codes.Unavailable, "failed to validate authority %q : %v", callHdr.Authority, err)} + } + newCallHdr := *callHdr + newCallHdr.Host = callHdr.Authority + callHdr = &newCallHdr + } + headerFields, err := t.createHeaderFields(ctx, callHdr) if err != nil { return nil, &NewStreamError{Err: err, AllowTransparentRetry: false} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index af4a4aeab..1730a639f 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -540,6 +540,11 @@ type CallHdr struct { PreviousAttempts int // value of grpc-previous-rpc-attempts header to set DoneFunc func() // called when the stream is finished + + // Authority is used to explicitly override the `:authority` header. If set, + // this value takes precedence over the Host field and will be used as the + // value for the `:authority` header. + Authority string } // ClientTransport is the common interface for all gRPC client-side transport diff --git a/rpc_util.go b/rpc_util.go index ad20e9dff..47ea09f5c 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -160,6 +160,7 @@ type callInfo struct { codec baseCodec maxRetryRPCBufferSize int onFinish []func(err error) + authority string } func defaultCallInfo() *callInfo { @@ -365,6 +366,36 @@ func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error { } func (o MaxRecvMsgSizeCallOption) after(*callInfo, *csAttempt) {} +// CallAuthority returns a CallOption that sets the HTTP/2 :authority header of +// an RPC to the specified value. When using CallAuthority, the credentials in +// use must implement the AuthorityValidator interface. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +func CallAuthority(authority string) CallOption { + return AuthorityOverrideCallOption{Authority: authority} +} + +// AuthorityOverrideCallOption is a CallOption that indicates the HTTP/2 +// :authority header value to use for the call. +// +// # Experimental +// +// Notice: This type is EXPERIMENTAL and may be changed or removed in a later +// release. +type AuthorityOverrideCallOption struct { + Authority string +} + +func (o AuthorityOverrideCallOption) before(c *callInfo) error { + c.authority = o.Authority + return nil +} + +func (o AuthorityOverrideCallOption) after(*callInfo, *csAttempt) {} + // MaxCallSendMsgSize returns a CallOption which sets the maximum message size // in bytes the client can send. If this is not set, gRPC uses the default // `math.MaxInt32`. diff --git a/stream.go b/stream.go index 01e66c1ed..bbbcbc724 100644 --- a/stream.go +++ b/stream.go @@ -297,6 +297,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client Method: method, ContentSubtype: callInfo.contentSubtype, DoneFunc: doneFunc, + Authority: callInfo.authority, } // Set our outgoing compression according to the UseCompressor CallOption, if