go-sdk/client/crypto_test.go

289 lines
6.7 KiB
Go

/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package client
import (
"context"
"errors"
"fmt"
"io"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
commonv1 "github.com/dapr/dapr/pkg/proto/common/v1"
runtimev1pb "github.com/dapr/dapr/pkg/proto/runtime/v1"
)
func TestEncrypt(t *testing.T) {
ctx := context.Background()
t.Run("missing ComponentName", func(t *testing.T) {
out, err := testClient.Encrypt(ctx,
strings.NewReader("hello world"),
EncryptOptions{
// ComponentName: "mycomponent",
KeyName: "key",
KeyWrapAlgorithm: "algorithm",
},
)
require.Error(t, err)
require.ErrorContains(t, err, "ComponentName")
require.Nil(t, out)
})
t.Run("missing Key", func(t *testing.T) {
out, err := testClient.Encrypt(ctx,
strings.NewReader("hello world"),
EncryptOptions{
ComponentName: "mycomponent",
// KeyName: "key",
KeyWrapAlgorithm: "algorithm",
},
)
require.Error(t, err)
require.ErrorContains(t, err, "Key")
require.Nil(t, out)
})
t.Run("missing Algorithm", func(t *testing.T) {
out, err := testClient.Encrypt(ctx,
strings.NewReader("hello world"),
EncryptOptions{
ComponentName: "mycomponent",
KeyName: "key",
// Algorithm: "algorithm",
},
)
require.Error(t, err)
require.ErrorContains(t, err, "Algorithm")
require.Nil(t, out)
})
t.Run("receiving back data sent", func(t *testing.T) {
// The test server doesn't actually encrypt data
out, err := testClient.Encrypt(ctx,
strings.NewReader("hello world"),
EncryptOptions{
ComponentName: "mycomponent",
KeyName: "key",
KeyWrapAlgorithm: "algorithm",
},
)
require.NoError(t, err)
require.NotNil(t, out)
read, err := io.ReadAll(out)
require.NoError(t, err)
require.Equal(t, "hello world", string(read))
})
t.Run("error in input stream", func(t *testing.T) {
out, err := testClient.Encrypt(ctx,
&failingReader{
data: strings.NewReader("hello world"),
},
EncryptOptions{
ComponentName: "mycomponent",
KeyName: "key",
KeyWrapAlgorithm: "algorithm",
},
)
require.NoError(t, err)
require.NotNil(t, out)
_, err = io.ReadAll(out)
require.Error(t, err)
require.ErrorContains(t, err, "simulated")
})
t.Run("context canceled", func(t *testing.T) {
failingCtx, failingCancel := context.WithTimeout(ctx, time.Second)
defer failingCancel()
out, err := testClient.Encrypt(failingCtx,
&slowReader{
// Should take a lot longer than 1s
//nolint:dupword
data: strings.NewReader("soft kitty, warm kitty, little ball of fur, happy kitty, sleepy kitty, purr purr purr"),
delay: time.Second,
},
EncryptOptions{
ComponentName: "mycomponent",
KeyName: "key",
KeyWrapAlgorithm: "algorithm",
},
)
require.NoError(t, err)
require.NotNil(t, out)
_, err = io.ReadAll(out)
require.Error(t, err)
require.ErrorContains(t, err, "context deadline exceeded")
})
}
func TestDecrypt(t *testing.T) {
ctx := context.Background()
t.Run("missing ComponentName", func(t *testing.T) {
out, err := testClient.Decrypt(ctx,
strings.NewReader("hello world"),
DecryptOptions{
// ComponentName: "mycomponent",
},
)
require.Error(t, err)
require.ErrorContains(t, err, "ComponentName")
require.Nil(t, out)
})
t.Run("receiving back data sent", func(t *testing.T) {
// The test server doesn't actually decrypt data
out, err := testClient.Decrypt(ctx,
strings.NewReader("hello world"),
DecryptOptions{
ComponentName: "mycomponent",
},
)
require.NoError(t, err)
require.NotNil(t, out)
read, err := io.ReadAll(out)
require.NoError(t, err)
require.Equal(t, "hello world", string(read))
})
t.Run("error in input stream", func(t *testing.T) {
out, err := testClient.Decrypt(ctx,
&failingReader{
data: strings.NewReader("hello world"),
},
DecryptOptions{
ComponentName: "mycomponent",
},
)
require.NoError(t, err)
require.NotNil(t, out)
_, err = io.ReadAll(out)
require.Error(t, err)
require.ErrorContains(t, err, "simulated")
})
}
/* --- Server methods --- */
func (s *testDaprServer) EncryptAlpha1(stream runtimev1pb.Dapr_EncryptAlpha1Server) error {
return s.performCryptoOperation(
stream,
&runtimev1pb.EncryptRequest{},
&runtimev1pb.EncryptResponse{},
)
}
func (s *testDaprServer) DecryptAlpha1(stream runtimev1pb.Dapr_DecryptAlpha1Server) error {
return s.performCryptoOperation(
stream,
&runtimev1pb.DecryptRequest{},
&runtimev1pb.DecryptResponse{},
)
}
func (s *testDaprServer) performCryptoOperation(stream grpc.ServerStream, reqProto runtimev1pb.CryptoRequests, resProto runtimev1pb.CryptoResponses) error {
// This doesn't really encrypt or decrypt the data and just sends back whatever it receives
pr, pw := io.Pipe()
go func() {
var (
done bool
err error
expectSeq uint64
)
first := true
for !done && stream.Context().Err() == nil {
reqProto.Reset()
err = stream.RecvMsg(reqProto)
if errors.Is(err, io.EOF) {
done = true
} else if err != nil {
pw.CloseWithError(err)
return
}
if first && !reqProto.HasOptions() {
pw.CloseWithError(errors.New("first message must have options"))
return
} else if !first && reqProto.HasOptions() {
pw.CloseWithError(errors.New("messages after first must not have options"))
return
}
first = false
payload := reqProto.GetPayload()
if payload != nil {
if payload.GetSeq() != expectSeq {
pw.CloseWithError(fmt.Errorf("invalid sequence number: %d (expected: %d)", payload.GetSeq(), expectSeq))
return
}
expectSeq++
_, err = pw.Write(payload.GetData())
if err != nil {
pw.CloseWithError(err)
return
}
}
}
pw.Close()
}()
var (
done bool
n int
err error
seq uint64
)
buf := make([]byte, 2<<10)
for !done && stream.Context().Err() == nil {
resProto.Reset()
n, err = pr.Read(buf)
if errors.Is(err, io.EOF) {
done = true
} else if err != nil {
return err
}
if n > 0 {
resProto.SetPayload(&commonv1.StreamPayload{
Seq: seq,
Data: buf[:n],
})
seq++
err = stream.SendMsg(resProto)
if err != nil {
return err
}
}
}
return nil
}