tls: append h2 to tlsconfig.NextProtos (#2744)

This commit is contained in:
Menghan Li 2019-04-08 09:56:02 -07:00 committed by GitHub
parent b03f6fd5e3
commit 4abb3622b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 4 deletions

View File

@ -36,9 +36,6 @@ import (
"google.golang.org/grpc/credentials/internal"
)
// alpnProtoStr are the specified application level protocols for gRPC.
var alpnProtoStr = []string{"h2"}
// PerRPCCredentials defines the common interface for the credentials which need to
// attach security information to every RPC (e.g., oauth2).
type PerRPCCredentials interface {
@ -208,10 +205,23 @@ func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
return nil
}
const alpnProtoStrH2 = "h2"
func appendH2ToNextProtos(ps []string) []string {
for _, p := range ps {
if p == alpnProtoStrH2 {
return ps
}
}
ret := make([]string, 0, len(ps)+1)
ret = append(ret, ps...)
return append(ret, alpnProtoStrH2)
}
// NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr
tc.config.NextProtos = appendH2ToNextProtos(tc.config.NextProtos)
return tc
}

View File

@ -22,6 +22,7 @@ import (
"context"
"crypto/tls"
"net"
"reflect"
"testing"
"google.golang.org/grpc/testdata"
@ -204,3 +205,39 @@ func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
}
return TLSInfo{State: clientConn.ConnectionState()}, nil
}
func TestAppendH2ToNextProtos(t *testing.T) {
tests := []struct {
name string
ps []string
want []string
}{
{
name: "empty",
ps: nil,
want: []string{"h2"},
},
{
name: "only h2",
ps: []string{"h2"},
want: []string{"h2"},
},
{
name: "with h2",
ps: []string{"alpn", "h2"},
want: []string{"alpn", "h2"},
},
{
name: "no h2",
ps: []string{"alpn"},
want: []string{"alpn", "h2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := appendH2ToNextProtos(tt.ps); !reflect.DeepEqual(got, tt.want) {
t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tt.want)
}
})
}
}