mirror of https://github.com/grpc/grpc-go.git
credentials/ALTS: Ensure ALTS record protocol names are consistent (#4754)
This commit is contained in:
parent
16cf65612e
commit
4e07a14b4e
|
|
@ -40,11 +40,15 @@ func Test(t *testing.T) {
|
||||||
grpctest.RunSubTests(t, s{})
|
grpctest.RunSubTests(t, s{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
rekeyRecordProtocol = "ALTSRP_GCM_AES128_REKEY"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
nextProtocols = []string{"ALTSRP_GCM_AES128"}
|
recordProtocols = []string{rekeyRecordProtocol}
|
||||||
altsRecordFuncs = map[string]ALTSRecordFunc{
|
altsRecordFuncs = map[string]ALTSRecordFunc{
|
||||||
// ALTS handshaker protocols.
|
// ALTS handshaker protocols.
|
||||||
"ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
|
rekeyRecordProtocol: func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
|
||||||
return NewAES128GCM(s, keyData)
|
return NewAES128GCM(s, keyData)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -77,7 +81,7 @@ func (c *testConn) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, protected []byte) *conn {
|
func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, rp string, protected []byte) *conn {
|
||||||
key := []byte{
|
key := []byte{
|
||||||
// 16 arbitrary bytes.
|
// 16 arbitrary bytes.
|
||||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
|
||||||
|
|
@ -85,23 +89,23 @@ func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, pro
|
||||||
in: in,
|
in: in,
|
||||||
out: out,
|
out: out,
|
||||||
}
|
}
|
||||||
c, err := NewConn(&tc, side, np, key, protected)
|
c, err := NewConn(&tc, side, rp, key, protected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
|
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
|
||||||
}
|
}
|
||||||
return c.(*conn)
|
return c.(*conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnPair(np string, clientProtected []byte, serverProtected []byte) (client, server *conn) {
|
func newConnPair(rp string, clientProtected []byte, serverProtected []byte) (client, server *conn) {
|
||||||
clientBuf := new(bytes.Buffer)
|
clientBuf := new(bytes.Buffer)
|
||||||
serverBuf := new(bytes.Buffer)
|
serverBuf := new(bytes.Buffer)
|
||||||
clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np, clientProtected)
|
clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, rp, clientProtected)
|
||||||
serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np, serverProtected)
|
serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, rp, serverProtected)
|
||||||
return clientConn, serverConn
|
return clientConn, serverConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPingPong(t *testing.T, np string) {
|
func testPingPong(t *testing.T, rp string) {
|
||||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
clientConn, serverConn := newConnPair(rp, nil, nil)
|
||||||
clientMsg := []byte("Client Message")
|
clientMsg := []byte("Client Message")
|
||||||
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
|
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
|
||||||
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
|
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
|
||||||
|
|
@ -128,13 +132,13 @@ func testPingPong(t *testing.T, np string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestPingPong(t *testing.T) {
|
func (s) TestPingPong(t *testing.T) {
|
||||||
for _, np := range nextProtocols {
|
for _, rp := range recordProtocols {
|
||||||
testPingPong(t, np)
|
testPingPong(t, rp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSmallReadBuffer(t *testing.T, np string) {
|
func testSmallReadBuffer(t *testing.T, rp string) {
|
||||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
clientConn, serverConn := newConnPair(rp, nil, nil)
|
||||||
msg := []byte("Very Important Message")
|
msg := []byte("Very Important Message")
|
||||||
if n, err := clientConn.Write(msg); err != nil {
|
if n, err := clientConn.Write(msg); err != nil {
|
||||||
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
|
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
|
||||||
|
|
@ -155,13 +159,13 @@ func testSmallReadBuffer(t *testing.T, np string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestSmallReadBuffer(t *testing.T) {
|
func (s) TestSmallReadBuffer(t *testing.T) {
|
||||||
for _, np := range nextProtocols {
|
for _, rp := range recordProtocols {
|
||||||
testSmallReadBuffer(t, np)
|
testSmallReadBuffer(t, rp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testLargeMsg(t *testing.T, np string) {
|
func testLargeMsg(t *testing.T, rp string) {
|
||||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
clientConn, serverConn := newConnPair(rp, nil, nil)
|
||||||
// msgLen is such that the length in the framing is larger than the
|
// msgLen is such that the length in the framing is larger than the
|
||||||
// default size of one frame.
|
// default size of one frame.
|
||||||
msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
|
msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
|
||||||
|
|
@ -179,12 +183,12 @@ func testLargeMsg(t *testing.T, np string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestLargeMsg(t *testing.T) {
|
func (s) TestLargeMsg(t *testing.T) {
|
||||||
for _, np := range nextProtocols {
|
for _, rp := range recordProtocols {
|
||||||
testLargeMsg(t, np)
|
testLargeMsg(t, rp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testIncorrectMsgType(t *testing.T, np string) {
|
func testIncorrectMsgType(t *testing.T, rp string) {
|
||||||
// framedMsg is an empty ciphertext with correct framing but wrong
|
// framedMsg is an empty ciphertext with correct framing but wrong
|
||||||
// message type.
|
// message type.
|
||||||
framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize)
|
framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize)
|
||||||
|
|
@ -193,7 +197,7 @@ func testIncorrectMsgType(t *testing.T, np string) {
|
||||||
binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
|
binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
|
||||||
|
|
||||||
in := bytes.NewBuffer(framedMsg)
|
in := bytes.NewBuffer(framedMsg)
|
||||||
c := newTestALTSRecordConn(in, nil, core.ClientSide, np, nil)
|
c := newTestALTSRecordConn(in, nil, core.ClientSide, rp, nil)
|
||||||
b := make([]byte, 1)
|
b := make([]byte, 1)
|
||||||
if n, err := c.Read(b); n != 0 || err == nil {
|
if n, err := c.Read(b); n != 0 || err == nil {
|
||||||
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
|
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
|
||||||
|
|
@ -201,15 +205,15 @@ func testIncorrectMsgType(t *testing.T, np string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestIncorrectMsgType(t *testing.T) {
|
func (s) TestIncorrectMsgType(t *testing.T) {
|
||||||
for _, np := range nextProtocols {
|
for _, rp := range recordProtocols {
|
||||||
testIncorrectMsgType(t, np)
|
testIncorrectMsgType(t, rp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testFrameTooLarge(t *testing.T, np string) {
|
func testFrameTooLarge(t *testing.T, rp string) {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np, nil)
|
clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, rp, nil)
|
||||||
serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np, nil)
|
serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, rp, nil)
|
||||||
// payloadLen is such that the length in the framing is larger than
|
// payloadLen is such that the length in the framing is larger than
|
||||||
// allowed in one frame.
|
// allowed in one frame.
|
||||||
payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
|
payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
|
||||||
|
|
@ -234,15 +238,15 @@ func testFrameTooLarge(t *testing.T, np string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestFrameTooLarge(t *testing.T) {
|
func (s) TestFrameTooLarge(t *testing.T) {
|
||||||
for _, np := range nextProtocols {
|
for _, rp := range recordProtocols {
|
||||||
testFrameTooLarge(t, np)
|
testFrameTooLarge(t, rp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testWriteLargeData(t *testing.T, np string) {
|
func testWriteLargeData(t *testing.T, rp string) {
|
||||||
// Test sending and receiving messages larger than the maximum write
|
// Test sending and receiving messages larger than the maximum write
|
||||||
// buffer size.
|
// buffer size.
|
||||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
clientConn, serverConn := newConnPair(rp, nil, nil)
|
||||||
// Message size is intentionally chosen to not be multiple of
|
// Message size is intentionally chosen to not be multiple of
|
||||||
// payloadLengthLimtit.
|
// payloadLengthLimtit.
|
||||||
msgSize := altsWriteBufferMaxSize + (100 * 1024)
|
msgSize := altsWriteBufferMaxSize + (100 * 1024)
|
||||||
|
|
@ -277,25 +281,25 @@ func testWriteLargeData(t *testing.T, np string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestWriteLargeData(t *testing.T) {
|
func (s) TestWriteLargeData(t *testing.T) {
|
||||||
for _, np := range nextProtocols {
|
for _, rp := range recordProtocols {
|
||||||
testWriteLargeData(t, np)
|
testWriteLargeData(t, rp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testProtectedBuffer(t *testing.T, np string) {
|
func testProtectedBuffer(t *testing.T, rp string) {
|
||||||
key := []byte{
|
key := []byte{
|
||||||
// 16 arbitrary bytes.
|
// 16 arbitrary bytes.
|
||||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
|
||||||
|
|
||||||
// Encrypt a message to be passed to NewConn as a client-side protected
|
// Encrypt a message to be passed to NewConn as a client-side protected
|
||||||
// buffer.
|
// buffer.
|
||||||
newCrypto := protocols[np]
|
newCrypto := protocols[rp]
|
||||||
if newCrypto == nil {
|
if newCrypto == nil {
|
||||||
t.Fatalf("Unknown next protocol %q", np)
|
t.Fatalf("Unknown record protocol %q", rp)
|
||||||
}
|
}
|
||||||
crypto, err := newCrypto(core.ClientSide, key)
|
crypto, err := newCrypto(core.ClientSide, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create a crypter for protocol %q: %v", np, err)
|
t.Fatalf("Failed to create a crypter for protocol %q: %v", rp, err)
|
||||||
}
|
}
|
||||||
msg := []byte("Client Protected Message")
|
msg := []byte("Client Protected Message")
|
||||||
encryptedMsg, err := crypto.Encrypt(nil, msg)
|
encryptedMsg, err := crypto.Encrypt(nil, msg)
|
||||||
|
|
@ -307,7 +311,7 @@ func testProtectedBuffer(t *testing.T, np string) {
|
||||||
binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType)
|
binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType)
|
||||||
protectedMsg = append(protectedMsg, encryptedMsg...)
|
protectedMsg = append(protectedMsg, encryptedMsg...)
|
||||||
|
|
||||||
_, serverConn := newConnPair(np, nil, protectedMsg)
|
_, serverConn := newConnPair(rp, nil, protectedMsg)
|
||||||
rcvClientMsg := make([]byte, len(msg))
|
rcvClientMsg := make([]byte, len(msg))
|
||||||
if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
|
if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
|
||||||
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
|
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
|
||||||
|
|
@ -318,7 +322,7 @@ func testProtectedBuffer(t *testing.T, np string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestProtectedBuffer(t *testing.T) {
|
func (s) TestProtectedBuffer(t *testing.T) {
|
||||||
for _, np := range nextProtocols {
|
for _, rp := range recordProtocols {
|
||||||
testProtectedBuffer(t, np)
|
testProtectedBuffer(t, rp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue