credentials/local: implement ValidateAuthority (#8291)

This commit is contained in:
eshitachandwani 2025-05-09 02:24:49 +05:30 committed by GitHub
parent b3d63b180c
commit 4680429852
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 62 deletions

View File

@ -30,6 +30,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/credentials/local"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@ -57,45 +58,88 @@ func authorityChecker(ctx context.Context, wantAuthority string) error {
return nil
}
// Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies
// that the provided authority is correctly propagated to the server when a
// correct authority is used.
func (s) TestCorrectAuthorityWithTLSCreds(t *testing.T) {
func loadTLSCreds(t *testing.T) (grpc.ServerOption, grpc.DialOption) {
t.Helper()
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
t.Fatalf("Failed to load key pair: %s", err)
t.Fatalf("Failed to load key pair: %v", err)
return nil, nil
}
creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
serverCreds := grpc.Creds(credentials.NewServerTLSFromCert(&cert))
clientCreds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
t.Fatalf("Failed to create client credentials: %v", err)
}
return serverCreds, grpc.WithTransportCredentials(clientCreds)
}
// Tests the scenario where the `grpc.CallAuthority` call option is used with
// different transport credentials. The test verifies that the specified
// authority is correctly propagated to the serve when a correct authority is
// used.
func (s) TestCorrectAuthorityWithCreds(t *testing.T) {
const authority = "auth.test.example.com"
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
if err := authorityChecker(ctx, authority); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
tests := []struct {
name string
creds func(t *testing.T) (grpc.ServerOption, grpc.DialOption)
expectedAuth string
}{
{
name: "Insecure",
creds: func(t *testing.T) (grpc.ServerOption, grpc.DialOption) {
c := insecure.NewCredentials()
return grpc.Creds(c), grpc.WithTransportCredentials(c)
},
expectedAuth: authority,
},
{
name: "Local",
creds: func(t *testing.T) (grpc.ServerOption, grpc.DialOption) {
c := local.NewCredentials()
return grpc.Creds(c), grpc.WithTransportCredentials(c)
},
expectedAuth: authority,
},
{
name: "TLS",
creds: func(t *testing.T) (grpc.ServerOption, grpc.DialOption) {
return loadTLSCreds(t)
},
expectedAuth: authority,
},
}
if err := ss.StartServer(grpc.Creds(credentials.NewServerTLSFromCert(&cert))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
if err := authorityChecker(ctx, tt.expectedAuth); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
},
}
serverOpt, dialOpt := tt.creds(t)
if err := ss.StartServer(serverOpt); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
cc, err := grpc.NewClient(ss.Address, dialOpt)
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(tt.expectedAuth)); err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
})
}
defer ss.Stop()
cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(creds))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); status.Code(err) != codes.OK {
t.Fatalf("EmptyCall() returned status %v, want %v", status.Code(err), codes.OK)
}
}
// Tests the `grpc.CallAuthority` option with TLS credentials. This test verifies
@ -143,38 +187,6 @@ func (s) TestIncorrectAuthorityWithTLS(t *testing.T) {
}
}
// Tests the scenario where the `grpc.CallAuthority` call option is used with
// insecure transport credentials. The test verifies that the specified
// authority is correctly propagated to the server.
func (s) TestAuthorityCallOptionWithInsecureCreds(t *testing.T) {
const authority = "test.server.name"
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
if err := authorityChecker(ctx, authority); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%q) = %v", ss.Address, err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err = testgrpc.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(authority)); err != nil {
t.Fatalf("EmptyCall() rpc failed: %v", err)
}
}
// testAuthInfoNoValidator implements only credentials.AuthInfo and not
// credentials.AuthorityValidator.
type testAuthInfoNoValidator struct{}

View File

@ -49,6 +49,12 @@ func (info) AuthType() string {
return "local"
}
// ValidateAuthority allows any value to be overridden for the :authority
// header.
func (info) ValidateAuthority(string) error {
return nil
}
// localTC is the credentials required to establish a local connection.
type localTC struct {
info credentials.ProtocolInfo