mirror of https://github.com/grpc/grpc-go.git
tls: append h2 to tlsconfig.NextProtos (#2744)
This commit is contained in:
parent
b03f6fd5e3
commit
4abb3622b0
|
|
@ -36,9 +36,6 @@ import (
|
||||||
"google.golang.org/grpc/credentials/internal"
|
"google.golang.org/grpc/credentials/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// alpnProtoStr are the specified application level protocols for gRPC.
|
|
||||||
var alpnProtoStr = []string{"h2"}
|
|
||||||
|
|
||||||
// PerRPCCredentials defines the common interface for the credentials which need to
|
// PerRPCCredentials defines the common interface for the credentials which need to
|
||||||
// attach security information to every RPC (e.g., oauth2).
|
// attach security information to every RPC (e.g., oauth2).
|
||||||
type PerRPCCredentials interface {
|
type PerRPCCredentials interface {
|
||||||
|
|
@ -208,10 +205,23 @@ func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const alpnProtoStrH2 = "h2"
|
||||||
|
|
||||||
|
func appendH2ToNextProtos(ps []string) []string {
|
||||||
|
for _, p := range ps {
|
||||||
|
if p == alpnProtoStrH2 {
|
||||||
|
return ps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret := make([]string, 0, len(ps)+1)
|
||||||
|
ret = append(ret, ps...)
|
||||||
|
return append(ret, alpnProtoStrH2)
|
||||||
|
}
|
||||||
|
|
||||||
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
||||||
func NewTLS(c *tls.Config) TransportCredentials {
|
func NewTLS(c *tls.Config) TransportCredentials {
|
||||||
tc := &tlsCreds{cloneTLSConfig(c)}
|
tc := &tlsCreds{cloneTLSConfig(c)}
|
||||||
tc.config.NextProtos = alpnProtoStr
|
tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
|
||||||
return tc
|
return tc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"google.golang.org/grpc/testdata"
|
"google.golang.org/grpc/testdata"
|
||||||
|
|
@ -204,3 +205,39 @@ func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
|
||||||
}
|
}
|
||||||
return TLSInfo{State: clientConn.ConnectionState()}, nil
|
return TLSInfo{State: clientConn.ConnectionState()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAppendH2ToNextProtos(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ps []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
ps: nil,
|
||||||
|
want: []string{"h2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only h2",
|
||||||
|
ps: []string{"h2"},
|
||||||
|
want: []string{"h2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with h2",
|
||||||
|
ps: []string{"alpn", "h2"},
|
||||||
|
want: []string{"alpn", "h2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no h2",
|
||||||
|
ps: []string{"alpn"},
|
||||||
|
want: []string{"alpn", "h2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := appendH2ToNextProtos(tt.ps); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue