mirror of https://github.com/dapr/kit.git
896 lines
26 KiB
Go
896 lines
26 KiB
Go
/*
|
||
Copyright 2023 The Dapr Authors
|
||
Licensed under the Apache License, Version 2.0 (the "License");
|
||
you may not use this file except in compliance with the License.
|
||
You may obtain a copy of the License at
|
||
http://www.apache.org/licenses/LICENSE-2.0
|
||
Unless required by applicable law or agreed to in writing, software
|
||
distributed under the License is distributed on an "AS IS" BASIS,
|
||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
See the License for the specific language governing permissions and
|
||
limitations under the License.
|
||
*/
|
||
|
||
package v1
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"io"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"sync"
|
||
"testing"
|
||
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
var (
	// errSimulated is returned by the fake key wrapping/unwrapping functions
	// used to exercise error paths in Encrypt and Decrypt.
	errSimulated = errors.New("simulated")

	// errSimulatedStream is returned by failingReader once its data (if any)
	// has been fully consumed.
	errSimulatedStream = errors.New("simulated stream error")
)
|
||
|
||
// TestScheme exercises Encrypt and Decrypt end-to-end: round trips with both
// AES-GCM and ChaCha20-Poly1305, decryption of the pre-generated files in
// testdata, the encryption/decryption options (cipher, key names), and the
// errors returned for invalid options and for malformed, truncated, or
// tampered input.
func TestScheme(t *testing.T) {
	// Fake wrapKeyFn and unwrapKeyFn, which just return the plaintext key
	// (identity functions, so the tests need no real key material)
	//nolint:stylecheck,revive
	var wrapKeyFn WrapKeyFn = func(plaintextKey []byte, algorithm, keyName string, nonce []byte) (wrappedKey []byte, tag []byte, err error) {
		return plaintextKey, nil, nil
	}
	//nolint:stylecheck,revive
	var unwrapKeyFn UnwrapKeyFn = func(wrappedKey []byte, algorithm, keyName string, nonce, tag []byte) (plaintextKey []byte, err error) {
		return wrappedKey, nil
	}

	// In all these tests, the key name and wrapping algorithms don't matter as we don't actually wrap/unwrap keys
	const keyName = "mykey"
	const algorithm = KeyAlgorithmAES

	// Plaintexts shared by the tests below; the ".enc" files in testdata are
	// expected to decrypt to the entry with the same name
	testData := map[string][]byte{
		// Data is short and fits in a single segment
		"single-segment": []byte("hello world"),
		// Data is larger than a single segment (120KB)
		"multi-segment": bytes.Repeat([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, 12<<10),
		// Data is exactly the size of a segment (64KB)
		"one-full-segment": bytes.Repeat([]byte{1, 2, 3, 4, 5, 6, 7, 8}, 8<<10),
		// Data is exactly the size of two segments (128KB)
		"two-full-segments": bytes.Repeat([]byte{1, 2, 3, 4, 5, 6, 7, 8}, 16<<10),
		// Large file (300KB)
		"large-file": bytes.Repeat([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0}, 30<<10),
		// Empty message - this should succeed
		"empty-message": {},
	}

	t.Run("encrypt and decrypt", func(t *testing.T) {
		// testFn encrypts message with the given cipher, checks the structure of
		// the header, then decrypts it and compares the result with message
		testFn := func(message []byte, cipher Cipher) func(t *testing.T) {
			return func(t *testing.T) {
				// Encrypt the message
				enc, err := Encrypt(
					bytes.NewReader(message),
					EncryptOptions{
						WrapKeyFn: wrapKeyFn,
						KeyName:   keyName,
						Algorithm: algorithm,
						Cipher:    &cipher,
					},
				)
				require.NoError(t, err)

				// Read the encrypted data
				encData, err := io.ReadAll(enc)
				require.NoError(t, err)
				require.NotEmpty(t, encData)

				// Sanity check of the header
				// First, ensure the scheme name is there
				idx := bytes.IndexByte(encData, '\n')
				require.Equal(t, 14, idx)
				require.Equal(t, SchemeName, string(encData[0:idx]))

				// Second, check that the JSON manifest is present and valid
				start := idx + 1
				idx = bytes.IndexByte(encData[start:], '\n')
				require.Greater(t, idx, 0)
				var manifest Manifest
				err = json.Unmarshal(encData[start:(start+idx)], &manifest)
				require.NoError(t, err)
				require.NoError(t, manifest.Validate())
				require.Equal(t, keyName, manifest.KeyName)
				require.Equal(t, algorithm.ID(), manifest.KeyWrappingAlgorithm.ID())
				require.Equal(t, cipher.ID(), manifest.Cipher.ID())
				require.Len(t, manifest.WFK, 32)
				require.Len(t, manifest.NoncePrefix, 7)

				// Third, check that we have the MAC
				// We are not validating the MAC here as the decryption code will do it; we'll just check it's present and 44-byte long (when encoded as base64)
				start += idx + 1
				idx = bytes.IndexByte(encData[start:], '\n')
				require.Greater(t, idx, 0)
				require.Len(t, encData[start:(start+idx)], 44)

				// Decrypt the encrypted data
				dec, err := Decrypt(
					bytes.NewReader(encData),
					DecryptOptions{
						UnwrapKeyFn: unwrapKeyFn,
					},
				)
				require.NoError(t, err)

				// The decrypted data should match the original message
				decData, err := io.ReadAll(dec)
				require.NoError(t, err)
				require.Equal(t, message, decData)
			}
		}

		// testFnAllCiphers runs testFn once for each supported cipher
		testFnAllCiphers := func(message []byte) func(t *testing.T) {
			return func(t *testing.T) {
				t.Run("with AES-GCM", testFn(message, CipherAESGCM))
				t.Run("with ChaCha20-Poly1305", testFn(message, CipherChaCha20Poly1305))
			}
		}

		t.Run("single-segment", testFnAllCiphers(testData["single-segment"]))
		t.Run("multi-segment", testFnAllCiphers(testData["multi-segment"]))
		t.Run("one-full-segment", testFnAllCiphers(testData["one-full-segment"]))
		t.Run("two-full-segments", testFnAllCiphers(testData["two-full-segments"]))
		t.Run("large-file", testFnAllCiphers(testData["large-file"]))
		t.Run("empty-message", testFnAllCiphers(testData["empty-message"]))
	})

	t.Run("decrypt test data", func(t *testing.T) {
		// testFn decrypts the file testdata/<fileName> and compares the result with expectData
		testFn := func(fileName string, expectData []byte) func(t *testing.T) {
			return func(t *testing.T) {
				enc, err := os.Open(filepath.Join("testdata", fileName))
				require.NoError(t, err)
				defer enc.Close()

				// Decrypt the encrypted data
				dec, err := Decrypt(
					enc,
					DecryptOptions{
						UnwrapKeyFn: unwrapKeyFn,
					},
				)
				require.NoError(t, err)

				// The decrypted data should match the expected plaintext
				decData, err := io.ReadAll(dec)
				require.NoError(t, err)
				require.Equal(t, expectData, decData)
			}
		}

		t.Run("single-segment", testFn("single-segment.enc", testData["single-segment"]))
		t.Run("multi-segment", testFn("multi-segment.enc", testData["multi-segment"]))
		t.Run("one-full-segment", testFn("one-full-segment.enc", testData["one-full-segment"]))
		t.Run("two-full-segments", testFn("two-full-segments.enc", testData["two-full-segments"]))
		t.Run("empty-message", testFn("empty-message.enc", testData["empty-message"]))
	})

	t.Run("default cipher in encryption is AES-GCM", func(t *testing.T) {
		// Encrypt the message
		enc, err := Encrypt(
			strings.NewReader("hello world"),
			EncryptOptions{
				WrapKeyFn: wrapKeyFn,
				KeyName:   keyName,
				Algorithm: algorithm,
				// Explicitly set to nil
				Cipher: nil,
			},
		)
		require.NoError(t, err)

		// Read the encrypted data
		encData, err := io.ReadAll(enc)
		require.NoError(t, err)
		require.NotEmpty(t, encData)

		// Get the JSON manifest
		// (it starts at the first '{' after the scheme name and ends at the next newline)
		start := bytes.IndexByte(encData, '{')
		require.Greater(t, start, 14)
		end := start + bytes.IndexByte(encData[start:], '\n')
		require.Greater(t, end, start)
		var manifest Manifest
		err = json.Unmarshal(encData[start:end], &manifest)
		require.NoError(t, err)
		require.NoError(t, manifest.Validate())
		require.Equal(t, CipherAESGCM.ID(), manifest.Cipher.ID())
		require.Len(t, manifest.WFK, 32)
		require.Len(t, manifest.NoncePrefix, 7)
	})

	t.Run("encryption option DecryptionKeyName", func(t *testing.T) {
		// Encrypt the message
		enc, err := Encrypt(
			bytes.NewReader(testData["single-segment"]),
			EncryptOptions{
				WrapKeyFn:         wrapKeyFn,
				KeyName:           keyName,
				Algorithm:         algorithm,
				DecryptionKeyName: "dec-key",
			},
		)
		require.NoError(t, err)

		// Read the encrypted data
		encData, err := io.ReadAll(enc)
		require.NoError(t, err)
		require.NotEmpty(t, encData)

		// Get the JSON manifest
		start := bytes.IndexByte(encData, '{')
		require.Greater(t, start, 14)
		end := start + bytes.IndexByte(encData[start:], '\n')
		require.Greater(t, end, start)
		var manifest Manifest
		err = json.Unmarshal(encData[start:end], &manifest)
		require.NoError(t, err)
		require.NoError(t, manifest.Validate())
		// The manifest stores DecryptionKeyName instead of KeyName
		require.Equal(t, "dec-key", manifest.KeyName)
	})

	t.Run("encryption option OmitKeyName", func(t *testing.T) {
		// Encrypt the message
		enc, err := Encrypt(
			bytes.NewReader(testData["single-segment"]),
			EncryptOptions{
				WrapKeyFn:         wrapKeyFn,
				KeyName:           keyName,
				Algorithm:         algorithm,
				DecryptionKeyName: "dec-key", // Should be ignored
				OmitKeyName:       true,
			},
		)
		require.NoError(t, err)

		// Read the encrypted data
		encData, err := io.ReadAll(enc)
		require.NoError(t, err)
		require.NotEmpty(t, encData)

		// Get the JSON manifest
		start := bytes.IndexByte(encData, '{')
		require.Greater(t, start, 14)
		end := start + bytes.IndexByte(encData[start:], '\n')
		require.Greater(t, end, start)
		var manifest Manifest
		err = json.Unmarshal(encData[start:end], &manifest)
		require.NoError(t, err)
		require.NoError(t, manifest.Validate())
		require.Empty(t, manifest.KeyName)
	})

	t.Run("decryption of a message created with OmitKeyName requires passing a key name", func(t *testing.T) {
		enc, err := os.Open(filepath.Join("testdata", "single-segment-no-key-name.enc"))
		require.NoError(t, err)
		defer enc.Close()

		// Decryption requires passing the key name
		dec, err := Decrypt(
			enc,
			DecryptOptions{
				KeyName:     "mykey",
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.NoError(t, err)

		// The decrypted data should match the original message
		decData, err := io.ReadAll(dec)
		require.NoError(t, err)
		require.Equal(t, testData["single-segment"], decData)
	})

	t.Run("decryption of a message created with OmitKeyName fails without a key name", func(t *testing.T) {
		enc, err := os.Open(filepath.Join("testdata", "single-segment-no-key-name.enc"))
		require.NoError(t, err)
		defer enc.Close()

		// Do not pass a key name
		dec, err := Decrypt(
			enc,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorIs(t, err, ErrDecryptionKeyMissing)
		require.Nil(t, dec)
	})

	t.Run("wrapKeyFn receives the key name and algorithm", func(t *testing.T) {
		// Values captured by the WrapKeyFn below
		var (
			gotKeyName   string
			gotAlgorithm string
		)
		_, err := Encrypt(
			strings.NewReader("hello world"),
			EncryptOptions{
				WrapKeyFn: func(plaintextKey []byte, algorithm, keyName string, nonce []byte) (wrappedKey []byte, tag []byte, err error) {
					gotAlgorithm = algorithm
					gotKeyName = keyName
					return wrapKeyFn(plaintextKey, algorithm, keyName, nonce)
				},
				// The actual values don't matter in this test
				KeyName:   "fakekey",
				Algorithm: KeyAlgorithmRSAOAEP256,
				// Explicitly set to nil
				Cipher: nil,
			},
		)
		require.NoError(t, err)

		require.Equal(t, "fakekey", gotKeyName)
		require.Equal(t, string(KeyAlgorithmRSAOAEP256), gotAlgorithm)
	})

	t.Run("override key name in decryption", func(t *testing.T) {
		enc, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer enc.Close()

		// Decrypt the encrypted data
		var gotKeyName string
		dec, err := Decrypt(
			enc,
			DecryptOptions{
				// Although we're passing a different value for keyName, we still return the same key so decryption will work
				KeyName: "anotherkey",
				UnwrapKeyFn: func(wrappedKey []byte, algorithm, keyName string, nonce, tag []byte) (plaintextKey []byte, err error) {
					gotKeyName = keyName
					return unwrapKeyFn(wrappedKey, algorithm, keyName, nonce, tag)
				},
			},
		)
		require.NoError(t, err)

		// The decrypted data should match the original message
		decData, err := io.ReadAll(dec)
		require.NoError(t, err)
		require.Equal(t, testData["single-segment"], decData)

		// The key name should be "anotherkey"
		require.Equal(t, "anotherkey", gotKeyName)
	})

	t.Run("encryption fails with input stream error", func(t *testing.T) {
		// A failingReader with no data fails on the very first read
		enc, err := Encrypt(
			&failingReader{},
			EncryptOptions{
				WrapKeyFn: wrapKeyFn,
				KeyName:   keyName,
				Algorithm: algorithm,
			},
		)
		require.NoError(t, err)

		// Read the encrypted data
		// The stream error surfaces only when the data is read, not when Encrypt is called
		_, err = io.ReadAll(enc)
		require.Error(t, err)
		require.ErrorIs(t, err, errSimulatedStream)
	})

	t.Run("wrapping key fails in Encrypt", func(t *testing.T) {
		enc, err := Encrypt(
			&bytes.Buffer{},
			EncryptOptions{
				WrapKeyFn: func(plaintextKey []byte, algorithm, keyName string, nonce []byte) (wrappedKey []byte, tag []byte, err error) {
					return nil, nil, errSimulated
				},
				KeyName:   keyName,
				Algorithm: algorithm,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "failed to wrap the file key")
		require.Nil(t, enc)
	})

	t.Run("unwrapping key fails in Decrypt", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// When the unwrapping function returns an error, that is swallowed and the user will only see "failed to validate the document's signature"
		// That's by design
		dec, err := Decrypt(
			f,
			DecryptOptions{
				UnwrapKeyFn: func(wrappedKey []byte, algorithm, keyName string, nonce, tag []byte) (plaintextKey []byte, err error) {
					return nil, errSimulated
				},
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "failed to validate the document's signature")
		require.Nil(t, dec)
	})

	t.Run("unwrapping key returns different key in Decrypt", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// Return a 32-byte key that differs from the one used to encrypt the file
		dec, err := Decrypt(
			f,
			DecryptOptions{
				UnwrapKeyFn: func(wrappedKey []byte, algorithm, keyName string, nonce, tag []byte) (plaintextKey []byte, err error) {
					return bytes.Repeat([]byte{1, 2, 3, 4, 5, 6, 7, 8}, 4), nil
				},
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "failed to validate the document's signature")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with scheme name not found", func(t *testing.T) {
		// The input contains no newline at all
		dec, err := Decrypt(
			strings.NewReader("foo"),
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: scheme name not found")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with scheme name not matching", func(t *testing.T) {
		dec, err := Decrypt(
			strings.NewReader("invalidscheme\nfoo"),
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: unsupported scheme")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with manifest not found", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// Only the first 20 bytes (the scheme name and the start of the manifest) are returned before the stream fails
		dec, err := Decrypt(
			&failingReader{data: io.LimitReader(f, 20)},
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: manifest not found")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with manifest zero bytes", func(t *testing.T) {
		// Valid scheme name followed by an empty manifest line
		dec, err := Decrypt(
			strings.NewReader("dapr.io/enc/v1\n\n"),
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: invalid format")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with manifest not valid JSON", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// This manifest will not unmarshal as JSON
		// (bytes 15-115 of single-segment.enc hold the manifest, between the scheme name and its terminating newline)
		rr := newReplaceReader(f, 15, 116, strings.NewReader("notjson"))

		dec, err := Decrypt(
			rr,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: invalid manifest")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with manifest not validating", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// This manifest will unmarshal into the Manifest struct, but will fail the Validate() method
		// We won't test all possible violations here because they're tested in the manifest_test.go file
		rr := newReplaceReader(f, 15, 116, strings.NewReader(`{"wk":"foo"}`))

		dec, err := Decrypt(
			rr,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: invalid manifest")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with MAC not found", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// The stream fails after the first 120 bytes, before the newline that terminates the MAC
		dec, err := Decrypt(
			&failingReader{data: io.LimitReader(f, 120)},
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: message authentication code not found")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with MAC zero bytes", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// Keep the scheme name and the manifest (including its newline), then add an empty MAC line
		in := io.MultiReader(
			io.LimitReader(f, 117),
			bytes.NewReader([]byte{'\n'}),
		)
		dec, err := Decrypt(
			&failingReader{data: in},
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: invalid format")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with MAC not valid Base64", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// Replace some bytes in the MAC
		// ('*' is not part of the base64 alphabet)
		rr := newReplaceReader(f, 120, 121, strings.NewReader("*"))

		dec, err := Decrypt(
			rr,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "failed to decode header's signature: illegal base64 data")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with header too long", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// The first 120 bytes contain the scheme name, the manifest, and the beginning of the MAC;
		// after them, add 120KB of zeros (with no newline), so the end of the MAC line is never found
		in := io.MultiReader(
			io.LimitReader(f, 120),
			bytes.NewReader(bytes.Repeat([]byte{0}, 120<<10)),
		)
		dec, err := Decrypt(
			in,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.Error(t, err)
		require.ErrorContains(t, err, "invalid header: message authentication code not found")
		require.Nil(t, dec)
	})

	t.Run("decryption fails with input stream error after header", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "single-segment.enc"))
		require.NoError(t, err)
		defer f.Close()

		// The whole file is returned, then the stream fails instead of returning io.EOF
		dec, err := Decrypt(
			&failingReader{data: f},
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.NoError(t, err)

		_, err = io.ReadAll(dec)
		require.Error(t, err)
		require.ErrorIs(t, err, errSimulatedStream)
	})

	t.Run("decryption fails when a byte is changed in the ciphertext", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "large-file.enc"))
		require.NoError(t, err)
		defer f.Close()

		// Replace a byte in the second segment (segment 1)
		rr := newReplaceReader(f, 100_000, 100_001, bytes.NewReader([]byte{'A'}))

		dec, err := Decrypt(
			rr,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.NoError(t, err)

		_, err = io.ReadAll(dec)
		require.Error(t, err)
		require.ErrorIs(t, err, ErrDecryptionFailed)
		require.ErrorContains(t, err, "error processing segment 1")
	})

	t.Run("decryption fails when a segment is removed from the ciphertext", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "large-file.enc"))
		require.NoError(t, err)
		defer f.Close()

		// Remove the third segment (segment 2)
		// NOTE(review): 162 is presumably the length of the header of large-file.enc; confirm if the test data is regenerated
		rr := newReplaceReader(f, 162+(SegmentSize+SegmentOverhead)*2, 162+(SegmentSize+SegmentOverhead)*3, &bytes.Buffer{})

		dec, err := Decrypt(
			rr,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.NoError(t, err)

		_, err = io.ReadAll(dec)
		require.Error(t, err)
		require.ErrorIs(t, err, ErrDecryptionFailed)
		require.ErrorContains(t, err, "error processing segment 2")
	})

	t.Run("decryption fails when the last segment is removed from the ciphertext", func(t *testing.T) {
		f, err := os.Open(filepath.Join("testdata", "large-file.enc"))
		require.NoError(t, err)
		defer f.Close()

		// Remove the last segment (segment 4)
		// This will fail on segment 3 because at that point it becomes the last
		rr := newReplaceReader(f, 162+(SegmentSize+SegmentOverhead)*4, -1, &bytes.Buffer{})

		dec, err := Decrypt(
			rr,
			DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			},
		)
		require.NoError(t, err)

		_, err = io.ReadAll(dec)
		require.Error(t, err)
		require.ErrorIs(t, err, ErrDecryptionFailed)
		require.ErrorContains(t, err, "error processing segment 3")
	})

	t.Run("init errors for Encrypt", func(t *testing.T) {
		t.Run("input stream is nil", func(t *testing.T) {
			out, err := Encrypt(nil, EncryptOptions{
				WrapKeyFn: wrapKeyFn,
				KeyName:   keyName,
				Algorithm: algorithm,
			})
			require.Error(t, err)
			require.ErrorContains(t, err, "in stream is nil")
			require.Nil(t, out)
		})

		t.Run("option WrapKeyFn is empty", func(t *testing.T) {
			out, err := Encrypt(&bytes.Buffer{}, EncryptOptions{
				KeyName:   keyName,
				Algorithm: algorithm,
			})
			require.Error(t, err)
			require.ErrorContains(t, err, "option WrapKeyFn is required")
			require.Nil(t, out)
		})

		t.Run("option KeyName is empty", func(t *testing.T) {
			out, err := Encrypt(&bytes.Buffer{}, EncryptOptions{
				WrapKeyFn: wrapKeyFn,
				Algorithm: algorithm,
			})
			require.Error(t, err)
			require.ErrorContains(t, err, "option KeyName is required")
			require.Nil(t, out)
		})

		t.Run("option Algorithm is empty", func(t *testing.T) {
			out, err := Encrypt(&bytes.Buffer{}, EncryptOptions{
				WrapKeyFn: wrapKeyFn,
				KeyName:   keyName,
			})
			require.Error(t, err)
			require.ErrorContains(t, err, "option Algorithm is required")
			require.Nil(t, out)
		})

		t.Run("option Algorithm is invalid", func(t *testing.T) {
			out, err := Encrypt(&bytes.Buffer{}, EncryptOptions{
				WrapKeyFn: wrapKeyFn,
				KeyName:   keyName,
				Algorithm: "invalid",
			})
			require.Error(t, err)
			require.ErrorContains(t, err, "option Algorithm is not valid")
			require.Nil(t, out)
		})

		t.Run("option Cipher is invalid", func(t *testing.T) {
			invalidCipher := Cipher("invalid")
			out, err := Encrypt(&bytes.Buffer{}, EncryptOptions{
				WrapKeyFn: wrapKeyFn,
				KeyName:   keyName,
				Algorithm: algorithm,
				Cipher:    &invalidCipher,
			})
			require.Error(t, err)
			require.ErrorContains(t, err, "option Cipher is not valid")
			require.Nil(t, out)
		})
	})

	t.Run("init errors for Decrypt", func(t *testing.T) {
		t.Run("input stream is nil", func(t *testing.T) {
			out, err := Decrypt(nil, DecryptOptions{
				UnwrapKeyFn: unwrapKeyFn,
			})
			require.Error(t, err)
			require.ErrorContains(t, err, "in stream is nil")
			require.Nil(t, out)
		})

		t.Run("option UnwrapKeyFn is empty", func(t *testing.T) {
			out, err := Decrypt(&bytes.Buffer{}, DecryptOptions{})
			require.Error(t, err)
			require.ErrorContains(t, err, "option UnwrapKeyFn is required")
			require.Nil(t, out)
		})
	})
}
|
||
|
||
func TestReplaceReader(t *testing.T) {
|
||
const message = "Ho sceso, dandoti il braccio, almeno un milione di scale e ora che non ci sei è il vuoto ad ogni gradino."
|
||
|
||
t.Run("replace bytes", func(t *testing.T) {
|
||
const replace = "✂️"
|
||
const expect = "Ho sceso, dandoti il braccio, almeno un milione di scale e ora✂️è il vuoto ad ogni gradino."
|
||
|
||
rr := newReplaceReader(strings.NewReader(message), 62, 78, strings.NewReader(replace))
|
||
read, err := io.ReadAll(rr)
|
||
require.NoError(t, err)
|
||
require.Equal(t, expect, string(read))
|
||
})
|
||
|
||
t.Run("remove bytes", func(t *testing.T) {
|
||
const expect = "Ho sceso, dandoti il braccio, almeno un milione di scale e ora è il vuoto ad ogni gradino."
|
||
|
||
rr := newReplaceReader(strings.NewReader(message), 63, 78, &bytes.Buffer{})
|
||
read, err := io.ReadAll(rr)
|
||
require.NoError(t, err)
|
||
require.Equal(t, expect, string(read))
|
||
})
|
||
|
||
t.Run("remove at the end", func(t *testing.T) {
|
||
const expect = "Ho sceso, dandoti il braccio, almeno un milione di scale e ora che non ci sei è il vuoto"
|
||
|
||
rr := newReplaceReader(strings.NewReader(message), 89, -1, &bytes.Buffer{})
|
||
read, err := io.ReadAll(rr)
|
||
require.NoError(t, err)
|
||
require.Equal(t, expect, string(read))
|
||
})
|
||
}
|
||
|
||
// replaceReader is an io.Reader that passes through an underlying stream but
// substitutes the bytes in the range [cutStart, cutEnd) with the contents of
// another reader. The tests use it to tamper with encrypted payloads at exact
// offsets (corrupting, removing, or truncating data).
type replaceReader struct {
	// stream is the underlying data source.
	stream io.Reader
	// cutStart is the offset of the first byte of stream that is removed.
	cutStart int
	// cutEnd is the offset (exclusive) where stream resumes after the cut.
	// If negative (conventionally -1), everything from cutStart until the end
	// of stream is removed.
	cutEnd int
	// replace provides the bytes emitted in place of the cut; it may be empty.
	replace io.Reader

	// Internal properties

	// read counts the bytes of stream returned before the cut.
	read int
	// replacing is set once the cut has been reached and skipped.
	replacing bool
	// l serializes calls to Read.
	l sync.Mutex
}

// newReplaceReader returns an io.Reader yielding stream with the bytes in
// [cutStart, cutEnd) replaced by the contents of replace. Pass cutEnd = -1 to
// drop everything from cutStart to the end of stream. The offsets are
// validated lazily: Read panics if they are not valid.
func newReplaceReader(stream io.Reader, cutStart, cutEnd int, replace io.Reader) io.Reader {
	return &replaceReader{
		stream:   stream,
		cutStart: cutStart,
		cutEnd:   cutEnd,
		replace:  replace,
	}
}

// Read implements io.Reader.
//
// Bytes before cutStart are passed through from the stream; then the bytes in
// [cutStart, cutEnd) are skipped and the contents of replace are returned in
// their place; finally the rest of the stream is passed through.
//
// It panics if the offsets are not valid: a negative cutStart, a cutEnd of 0,
// or a non-negative cutEnd smaller than cutStart.
func (r *replaceReader) Read(p []byte) (int, error) {
	if r.cutStart < 0 || r.cutEnd == 0 || (r.cutEnd > 0 && r.cutStart > r.cutEnd) {
		panic("cutStart and/or cutEnd are not valid")
	}

	r.l.Lock()
	defer r.l.Unlock()

	// If we've already replaced the data and there's no more data left, just read from the rest of the stream
	if r.replacing && r.replace == nil {
		return r.stream.Read(p)
	}

	// If we're in the replacement section, read from the replace stream
	if r.replacing {
		n, err := r.replace.Read(p)
		if errors.Is(err, io.EOF) {
			// The end of the replacement is not the end of the output:
			// the next call resumes reading the stream after the cut
			err = nil
			r.replace = nil
		}
		return n, err
	}

	// Before the cut: never read past cutStart, so the cut begins at the exact offset
	// (named "limit" rather than "max" to avoid shadowing the builtin)
	limit := len(p)
	if r.read+limit > r.cutStart {
		limit = r.cutStart - r.read
	}
	n, err := r.stream.Read(p[:limit])
	r.read += n

	if r.read >= r.cutStart {
		// Advance the stream till the cut end; errors are deliberately ignored
		// (e.g. a stream shorter than cutEnd simply ends early)
		if r.cutEnd < 0 {
			_, _ = io.Copy(io.Discard, r.stream)
		} else {
			_, _ = io.CopyN(io.Discard, r.stream, int64(r.cutEnd-r.cutStart))
		}
		r.replacing = true

		// An io.Reader may return io.EOF together with its last bytes. If that
		// happens right at the cut, the replacement has not been emitted yet,
		// so the EOF must not be reported now; readers conventionally keep
		// returning io.EOF, so the stream reports it again after the replacement
		if errors.Is(err, io.EOF) {
			err = nil
		}
	}

	return n, err
}
|
||
|
||
// Implements an io.Reader that simulates failures (after optionally reading from a stream in full)
|
||
type failingReader struct {
|
||
// Data to return before returning an error
|
||
data io.Reader
|
||
l sync.Mutex
|
||
}
|
||
|
||
func (f *failingReader) Read(p []byte) (n int, err error) {
|
||
f.l.Lock()
|
||
defer f.l.Unlock()
|
||
|
||
if f.data != nil {
|
||
n, err := f.data.Read(p)
|
||
switch {
|
||
case err == nil:
|
||
return n, nil
|
||
case errors.Is(err, io.EOF):
|
||
// Do not return io.EOF as error
|
||
// Instead, just delete the stream
|
||
// On the next call, we will return an error
|
||
f.data = nil
|
||
return n, nil
|
||
default:
|
||
// Should not happen
|
||
panic(err)
|
||
}
|
||
}
|
||
|
||
return 0, errSimulatedStream
|
||
}
|