mirror of https://github.com/dapr/kit.git
426 lines
14 KiB
Go
426 lines
14 KiB
Go
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package v1
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
)
|
|
|
|
const (
	// SchemeName is the name of the encryption scheme.
	SchemeName = "dapr.io/enc/v1"

	// SegmentSize is the size of each segment in the encrypted message.
	// Each segment is exactly 64KB in length, except the last one which could be shorter.
	SegmentSize = 64 << 10

	// SegmentOverhead is the overhead of each segment in bytes.
	// This is equivalent to the size of the authentication tag for AES-GCM and ChaCha20-Poly1305.
	SegmentOverhead = 16

	// NoncePrefixLength is the length of the nonce prefix.
	NoncePrefixLength = 7
)
|
|
|
|
var (
	// ErrDecryptionKeyMissing is returned when trying to decrypt a document whose manifest does not contain a key name, and the caller did not provide an explicit key name.
	ErrDecryptionKeyMissing = errors.New("document's manifest does not contain a key name, and no key name was provided")

	// ErrDecryptionSignature is returned when the signature of the document could not be validated.
	ErrDecryptionSignature = errors.New("failed to validate the document's signature")

	// ErrDecryptionFailed is returned when the decryption fails.
	// Most commonly this happens when a segment has been tampered with.
	ErrDecryptionFailed = errors.New("failed to decrypt segment")
)
|
|
|
|
type (
	// WrapKeyFn is the signature of the method that wraps keys.
	// This does not accept a context, which needs to be provided by the caller of the Encrypt method inside the lambda.
	WrapKeyFn = func(plaintextKey []byte, algorithm string, keyName string, nonce []byte) (wrappedKey []byte, tag []byte, err error)

	// UnwrapKeyFn is the signature of the method that unwraps keys.
	// This does not accept a context, which needs to be provided by the caller of the Decrypt method inside the lambda.
	UnwrapKeyFn = func(wrappedKey []byte, algorithm string, keyName string, nonce []byte, tag []byte) (plaintextKey []byte, err error)
)
|
|
|
|
// EncryptOptions contains the options passed to the Encrypt method
|
|
type EncryptOptions struct {
|
|
// Function that is invoked to wrap the key
|
|
WrapKeyFn WrapKeyFn
|
|
// Algorithm used to wrap the file key
|
|
// This must be one of the supported KeyAlgorithm constants, and must be usable by the kind of key provided
|
|
Algorithm KeyAlgorithm
|
|
// Name of the key to use
|
|
KeyName string
|
|
// Name of the key to include as decryption key
|
|
// If empty, uses KeyName
|
|
DecryptionKeyName string
|
|
// If true, does not include the key name in the manifest
|
|
OmitKeyName bool
|
|
// Cipher used to encrypt the data
|
|
// If nil, defaults to AES-GCM
|
|
Cipher *Cipher
|
|
}
|
|
|
|
// DecryptOptions contains the options passed to the Decrypt method
|
|
type DecryptOptions struct {
|
|
// Function that is invoked to unwrap the key
|
|
UnwrapKeyFn UnwrapKeyFn
|
|
// If set, uses this value as key name rather than the one included in the manifest
|
|
KeyName string
|
|
}
|
|
|
|
// BufPool is a sync.Pool that returns buffers of SegmentSize+SegmentOverhead, plus one extra byte
|
|
var BufPool = sync.Pool{
|
|
New: func() any {
|
|
const bufSize = SegmentSize + SegmentOverhead + 1
|
|
// Return a pointer here
|
|
// See https://github.com/dominikh/go-tools/issues/1336 for explanation
|
|
b := make([]byte, bufSize)
|
|
return &b
|
|
},
|
|
}
|
|
|
|
// Encrypt a document using the `dapr.io/enc/v1` scheme.
|
|
// The plaintext is read from the `in` stream and written to the returned stream.
|
|
func Encrypt(in io.Reader, opts EncryptOptions) (io.Reader, error) {
|
|
// Validate the request options
|
|
if in == nil {
|
|
return nil, errors.New("in stream is nil")
|
|
}
|
|
if opts.WrapKeyFn == nil {
|
|
return nil, errors.New("option WrapKeyFn is required")
|
|
}
|
|
if opts.KeyName == "" {
|
|
return nil, errors.New("option KeyName is required")
|
|
}
|
|
if opts.Algorithm == "" {
|
|
return nil, errors.New("option Algorithm is required")
|
|
}
|
|
keyWrapAlgorithm, err := opts.Algorithm.Validate()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("option Algorithm is not valid: %w", err)
|
|
}
|
|
cipher := CipherAESGCM
|
|
if opts.Cipher != nil {
|
|
cipher, err = opts.Cipher.Validate()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("option Cipher is not valid: %w", err)
|
|
}
|
|
}
|
|
|
|
// Start by generating a random file key
|
|
fk, err := newFileKey(cipher)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Wrap the file key
|
|
// Note: we're skipping the nonce and ignoring the tag parameter at the moment because none of the supported ciphers use them
|
|
wrappedFileKey, _, err := opts.WrapKeyFn(fk.GetFileKey(), string(keyWrapAlgorithm), opts.KeyName, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to wrap the file key: %w", err)
|
|
}
|
|
|
|
// Create the manifest and sign it
|
|
keyName := opts.DecryptionKeyName
|
|
if opts.OmitKeyName {
|
|
keyName = ""
|
|
} else if keyName == "" {
|
|
keyName = opts.KeyName
|
|
}
|
|
manifest, err := json.Marshal(&Manifest{
|
|
KeyName: keyName,
|
|
KeyWrappingAlgorithm: keyWrapAlgorithm,
|
|
WFK: wrappedFileKey,
|
|
Cipher: cipher,
|
|
NoncePrefix: fk.GetNoncePrefix(),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode JSON manifest: %w", err)
|
|
}
|
|
header, err := fk.SignHeader(manifest)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to sign header: %w", err)
|
|
}
|
|
|
|
// Start a background goroutine to perform the encryption, and return the stream to the caller
|
|
// From now on, errors are returned as errors on the stream
|
|
outR, outW := io.Pipe()
|
|
go func() {
|
|
// Write the header
|
|
if !writeOrClosePipe(outW, header) {
|
|
return
|
|
}
|
|
|
|
// Proceed with processing all segments
|
|
processSegments(in, outW, fk.EncryptSegment, SegmentSize)
|
|
}()
|
|
|
|
return outR, nil
|
|
}
|
|
|
|
// Decrypt a document using the `dapr.io/enc/v1` scheme
|
|
// The ciphertext is read from the `in` stream and written to the returned stream
|
|
func Decrypt(in io.Reader, opts DecryptOptions) (io.Reader, error) {
|
|
// Validate the request options
|
|
if in == nil {
|
|
return nil, errors.New("in stream is nil")
|
|
}
|
|
if opts.UnwrapKeyFn == nil {
|
|
return nil, errors.New("option UnwrapKeyFn is required")
|
|
}
|
|
|
|
// Read the header
|
|
manifest, mac, err := readHeader(&in)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid header: %w", err)
|
|
}
|
|
|
|
// Parse the manifest to get the key name and validate it
|
|
var manifestObj Manifest
|
|
err = json.Unmarshal(manifest, &manifestObj)
|
|
if err != nil || manifestObj.Validate() != nil {
|
|
// Do not return the exact error to avoid disclosing too much information
|
|
return nil, errors.New("invalid header: invalid manifest")
|
|
}
|
|
|
|
// Get the name of the key, and check if we need to override it
|
|
keyName := opts.KeyName
|
|
if keyName == "" {
|
|
keyName = manifestObj.KeyName
|
|
if keyName == "" {
|
|
return nil, ErrDecryptionKeyMissing
|
|
}
|
|
}
|
|
|
|
// Unwrap the file key
|
|
// Note: we're skipping the nonce and tag parameters at the moment because none of the supported ciphers use them
|
|
fileKeyBytes, _ := opts.UnwrapKeyFn(manifestObj.WFK, string(manifestObj.KeyWrappingAlgorithm), keyName, nil, nil)
|
|
if len(fileKeyBytes) != 32 {
|
|
// This is where things get a bit tricky.
|
|
// If the UnwrapKeyFn returned an error, we want to ignore that for now, and instead continue validating the MAC using an empty fileKey (which will fail).
|
|
// This is because otherwise we may be making it easier to disclose certain information such as whether a key exists or not in the vault via timing attacks.
|
|
// What we're doing here doesn't remove timing attacks entirely, starting from the fact that we're putting an `if` block. Also, the underlying components may respond faster if the key isn't available… but at least we can try not making the situation worse!
|
|
// Also, this takes some time as we're allocating memory, but in the case of err==nil the operation there takes some time too.
|
|
fileKeyBytes = make([]byte, 32)
|
|
}
|
|
|
|
// Import the file key
|
|
fk, err := importFileKey(fileKeyBytes, manifestObj.NoncePrefix, manifestObj.Cipher)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Now validate the MAC of the header
|
|
err = fk.VerifyHeaderSignature(manifest, mac)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Start a background goroutine to perform the encryption, and return the stream to the caller
|
|
// From now on, errors are returned as errors on the stream
|
|
outR, outW := io.Pipe()
|
|
go processSegments(in, outW, fk.DecryptSegment, SegmentSize+SegmentOverhead)
|
|
|
|
return outR, nil
|
|
}
|
|
|
|
// Reads all segment from the input stream, either plaintext or ciphertext, and process them (encrypt or decrypt them)
|
|
func processSegments(in io.Reader, out *io.PipeWriter, processFn processSegmentFn, segmentSize int) {
|
|
// Get a buffer from the pool
|
|
buf := BufPool.Get().(*[]byte)
|
|
defer func() {
|
|
BufPool.Put(buf)
|
|
}()
|
|
|
|
// Read from the input stream till the end, one segment at a time
|
|
var (
|
|
err error
|
|
segment uint32
|
|
done bool
|
|
hasCarryover bool
|
|
carryover byte
|
|
n, nn int
|
|
)
|
|
for !done {
|
|
n = 0
|
|
|
|
// Add the carryover byte if we have one
|
|
if hasCarryover {
|
|
(*buf)[0] = carryover
|
|
n = 1
|
|
hasCarryover = false
|
|
}
|
|
|
|
// Read a segment from the buffer till we have a full segment + 1 byte or an error (could be EOF).
|
|
// We are reading an extra byte because we need to understand if we've reached the end of the file.
|
|
// Otherwise, if the input stream's data were exactly multiples of segmentSize, we wouldn't have a way to know.
|
|
// Note that the underlying buffer may be larger, so we may not fill it up ever, and that's ok (i.e. if segmentSize == SegmentSize, we are reading an extra 1 byte rather than 17)
|
|
for n < (segmentSize+1) && err == nil {
|
|
nn, err = in.Read((*buf)[n:(segmentSize + 1)])
|
|
n += nn
|
|
}
|
|
|
|
// Ignore EOF errors, which mean that the input stream is done
|
|
// We will still need to continue processing whatever data we have
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
// In case of any other error, close the out stream with the error
|
|
_ = out.CloseWithError(err)
|
|
return
|
|
}
|
|
|
|
// If we read an extra byte, set that as carryover
|
|
// Otherwise, this means that we have the last segment
|
|
if n > segmentSize {
|
|
carryover = (*buf)[n-1]
|
|
hasCarryover = true
|
|
n--
|
|
} else {
|
|
done = true
|
|
}
|
|
|
|
// It's ok if we got less than a full segment, as long as this was the last segment (i.e. the stream is done)
|
|
// Realistically, this should never happen, because in this case we would have had an error returned by in.Read.
|
|
if n < segmentSize && !done {
|
|
_ = out.CloseWithError(io.ErrUnexpectedEOF)
|
|
return
|
|
}
|
|
|
|
// A completely empty segment is ok only if this is the first segment (i.e. the input was empty)
|
|
// Note that here, we've already checked and made sure that the input stream is done
|
|
if n == 0 {
|
|
if segment != 0 {
|
|
// Realistically, it should be impossible for us to get to this point as well, as there would have been a carryover from the previous iteration.
|
|
_ = out.CloseWithError(io.ErrUnexpectedEOF)
|
|
return
|
|
}
|
|
break
|
|
}
|
|
|
|
// We can now process the segment
|
|
err = processFn(out, (*buf)[:n], segment, done)
|
|
if err != nil {
|
|
_ = out.CloseWithError(fmt.Errorf("error processing segment %d: %w", segment, err))
|
|
return
|
|
}
|
|
|
|
// Proceed to the next segment if not done
|
|
if !done && segment == 1<<32-1 {
|
|
// We're about to overflow
|
|
_ = out.CloseWithError(errors.New("input stream is too large"))
|
|
return
|
|
}
|
|
segment++
|
|
}
|
|
|
|
// Close the out stream as done
|
|
_ = out.Close()
|
|
}
|
|
|
|
func readHeader(in *io.Reader) (manifest []byte, mac []byte, err error) {
|
|
// Get a buffer from the pool
|
|
buf := BufPool.Get().(*[]byte)
|
|
defer func() {
|
|
BufPool.Put(buf)
|
|
}()
|
|
|
|
// Read the first segment to get the header
|
|
// We know that the header (including the MAC) aren't larger than a single segment
|
|
// Keep reading from the buffer until we get at least 3 newline characters (or an error)
|
|
var (
|
|
n, nn, i, ul int
|
|
newlines int
|
|
lastNewline int
|
|
line []byte
|
|
)
|
|
for newlines < 3 && err == nil {
|
|
// Even though the maximum size for the header is 1 segment (64KB + 16 bytes), read 512 bytes at a time at most, since most headers are much smaller than that
|
|
ul = n + 512
|
|
if ul > SegmentSize {
|
|
ul = SegmentSize
|
|
}
|
|
if n == ul {
|
|
break
|
|
}
|
|
nn, err = (*in).Read((*buf)[n:SegmentSize])
|
|
if nn <= 0 {
|
|
continue
|
|
}
|
|
|
|
for i = n; i < (n+nn) && newlines < 3; i++ {
|
|
if (*buf)[i] != '\n' {
|
|
continue
|
|
}
|
|
|
|
if i <= lastNewline {
|
|
return nil, nil, errors.New("invalid format")
|
|
}
|
|
line = (*buf)[lastNewline:i]
|
|
switch newlines {
|
|
case 0:
|
|
// First line must be the scheme name
|
|
if string(line) != SchemeName {
|
|
return nil, nil, errors.New("unsupported scheme")
|
|
}
|
|
case 1:
|
|
// Second line is the manifest
|
|
manifest = line
|
|
case 2:
|
|
// Third line is the MAC
|
|
mac = line
|
|
}
|
|
newlines++
|
|
lastNewline = i + 1
|
|
}
|
|
n += nn
|
|
}
|
|
|
|
// Ensure we have a manifest and MAC
|
|
if newlines < 1 {
|
|
return nil, nil, errors.New("scheme name not found")
|
|
}
|
|
if len(manifest) == 0 {
|
|
return nil, nil, errors.New("manifest not found")
|
|
}
|
|
if len(mac) == 0 {
|
|
return nil, nil, errors.New("message authentication code not found")
|
|
}
|
|
|
|
// Whatever data we read extra, add it back to the beginning of the stream
|
|
if n > lastNewline {
|
|
// We need to copy the data because the buffer will be given back
|
|
extraBytes := make([]byte, n-lastNewline)
|
|
copy(extraBytes, (*buf)[(lastNewline):n])
|
|
*in = io.MultiReader(bytes.NewReader(extraBytes), *in)
|
|
}
|
|
|
|
return manifest, mac, nil
|
|
}
|
|
|
|
// writeOrClosePipe writes b to the pipe w.
// It returns true if the write succeeded. If the write failed, it closes the pipe with an error that wraps
// the write error, and returns false: callers should stop writing to the pipe in that case.
func writeOrClosePipe(w *io.PipeWriter, b []byte) bool {
	_, err := w.Write(b)
	if err != nil {
		_ = w.CloseWithError(fmt.Errorf("failed to write to the stream: %w", err))
		return false
	}
	return true
}
|