mirror of https://github.com/dapr/kit.git
575 lines
14 KiB
Go
575 lines
14 KiB
Go
/*
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package file
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/dapr/kit/crypto/test"
|
|
"github.com/dapr/kit/logger"
|
|
)
|
|
|
|
func TestFile_Run(t *testing.T) {
|
|
t.Run("if Run multiple times, expect error", func(t *testing.T) {
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(ctx)
|
|
}()
|
|
go func() {
|
|
errCh <- f.Run(ctx)
|
|
}()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.Error(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "Expected error")
|
|
}
|
|
|
|
select {
|
|
case <-f.closeCh:
|
|
assert.Fail(t, "closeCh should not be closed")
|
|
default:
|
|
}
|
|
|
|
cancel()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "First Run should have returned and returned no error ")
|
|
}
|
|
})
|
|
|
|
t.Run("if file is not found and context cancelled, should return ctx.Err", func(t *testing.T) {
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(ctx)
|
|
}()
|
|
|
|
cancel()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "First Run should have returned and returned no error ")
|
|
}
|
|
})
|
|
|
|
t.Run("if file found but is empty, should return error", func(t *testing.T) {
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, nil, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.Error(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error")
|
|
}
|
|
})
|
|
|
|
t.Run("if file found but is only garbage data, expect error", func(t *testing.T) {
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.Error(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error")
|
|
}
|
|
})
|
|
|
|
t.Run("if file found but is only garbage data in root, expect error", func(t *testing.T) {
|
|
pki := test.GenPKI(t, test.PKIOptions{})
|
|
root := pki.RootCertPEM[10:]
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.Error(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error")
|
|
}
|
|
})
|
|
|
|
t.Run("single root should be correctly parsed from file", func(t *testing.T) {
|
|
pki := test.GenPKI(t, test.PKIOptions{})
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case <-f.readyCh:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected to be ready in time")
|
|
}
|
|
|
|
b, err := f.CurrentTrustAnchors(context.Background())
|
|
require.NoError(t, err)
|
|
assert.Equal(t, pki.RootCertPEM, b)
|
|
})
|
|
|
|
t.Run("garbage data outside of root should be ignored", func(t *testing.T) {
|
|
pki := test.GenPKI(t, test.PKIOptions{})
|
|
//nolint:gocritic
|
|
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case <-f.readyCh:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected to be ready in time")
|
|
}
|
|
|
|
b, err := f.CurrentTrustAnchors(context.Background())
|
|
require.NoError(t, err)
|
|
assert.Equal(t, root, b)
|
|
})
|
|
|
|
t.Run("multiple roots should be parsed", func(t *testing.T) {
|
|
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
|
//nolint:gocritic
|
|
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case <-f.readyCh:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected to be ready in time")
|
|
}
|
|
|
|
b, err := f.CurrentTrustAnchors(context.Background())
|
|
require.NoError(t, err)
|
|
assert.Equal(t, roots, b)
|
|
})
|
|
|
|
t.Run("writing a bad root PEM file should make Run return error", func(t *testing.T) {
|
|
pki1, pki2 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
|
//nolint:gocritic
|
|
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
f.fsWatcherInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(context.Background())
|
|
}()
|
|
|
|
select {
|
|
case <-f.readyCh:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected to be ready in time")
|
|
}
|
|
|
|
require.NoError(t, os.WriteFile(tmp, []byte("garbage data"), 0o600))
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.Error(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error to be returned from Run")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestFile_GetX509BundleForTrustDomain(t *testing.T) {
|
|
t.Run("Should return full PEM regardless given trust domain", func(t *testing.T) {
|
|
pki := test.GenPKI(t, test.PKIOptions{})
|
|
//nolint:gocritic
|
|
root := append(pki.RootCertPEM, []byte("garbage data")...)
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, root, 0o600))
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
|
|
errCh := make(chan error)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() {
|
|
errCh <- ta.Run(ctx)
|
|
}()
|
|
t.Cleanup(func() {
|
|
cancel()
|
|
select {
|
|
case err := <-errCh:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected Run to return")
|
|
}
|
|
})
|
|
|
|
trustDomain1, err := spiffeid.TrustDomainFromString("example.com")
|
|
require.NoError(t, err)
|
|
bundle, err := f.GetX509BundleForTrustDomain(trustDomain1)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, f.x509Bundle, bundle)
|
|
b1, err := bundle.Marshal()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, pki.RootCertPEM, b1)
|
|
|
|
trustDomain2, err := spiffeid.TrustDomainFromString("another-example.org")
|
|
require.NoError(t, err)
|
|
bundle, err = f.GetX509BundleForTrustDomain(trustDomain2)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, f.x509Bundle, bundle)
|
|
b2, err := bundle.Marshal()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, pki.RootCertPEM, b2)
|
|
})
|
|
}
|
|
|
|
func TestFile_Watch(t *testing.T) {
|
|
t.Run("should return when Run context has been cancelled", func(t *testing.T) {
|
|
pki := test.GenPKI(t, test.PKIOptions{})
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() {
|
|
errCh <- f.Run(ctx)
|
|
}()
|
|
|
|
watchDone := make(chan struct{})
|
|
go func() {
|
|
ta.Watch(context.Background(), make(chan []byte))
|
|
close(watchDone)
|
|
}()
|
|
|
|
cancel()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error to be returned from Run")
|
|
}
|
|
|
|
select {
|
|
case <-watchDone:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected Watch to have returned")
|
|
}
|
|
})
|
|
|
|
t.Run("should return when given context has been cancelled", func(t *testing.T) {
|
|
pki := test.GenPKI(t, test.PKIOptions{})
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, pki.RootCertPEM, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
ctx1, cancel1 := context.WithCancel(context.Background())
|
|
go func() {
|
|
errCh <- f.Run(ctx1)
|
|
}()
|
|
|
|
watchDone := make(chan struct{})
|
|
ctx2, cancel2 := context.WithCancel(context.Background())
|
|
go func() {
|
|
ta.Watch(ctx2, make(chan []byte))
|
|
close(watchDone)
|
|
}()
|
|
|
|
cancel2()
|
|
|
|
select {
|
|
case <-watchDone:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected Watch to have returned")
|
|
}
|
|
|
|
cancel1()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error to be returned from Run")
|
|
}
|
|
})
|
|
|
|
t.Run("should update Watch subscribers when root PEM has been changed", func(t *testing.T) {
|
|
pki1 := test.GenPKI(t, test.PKIOptions{})
|
|
pki2 := test.GenPKI(t, test.PKIOptions{})
|
|
pki3 := test.GenPKI(t, test.PKIOptions{})
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
f.fsWatcherInterval = time.Millisecond
|
|
|
|
errCh := make(chan error)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
go func() {
|
|
errCh <- f.Run(ctx)
|
|
}()
|
|
|
|
select {
|
|
case <-f.readyCh:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected to be ready in time")
|
|
}
|
|
|
|
watchDone1, watchDone2 := make(chan struct{}), make(chan struct{})
|
|
tCh1, tCh2 := make(chan []byte), make(chan []byte)
|
|
go func() {
|
|
ta.Watch(context.Background(), tCh1)
|
|
close(watchDone1)
|
|
}()
|
|
go func() {
|
|
ta.Watch(context.Background(), tCh2)
|
|
close(watchDone2)
|
|
}()
|
|
|
|
//nolint:gocritic
|
|
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
|
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
|
|
|
for _, ch := range []chan []byte{tCh1, tCh2} {
|
|
select {
|
|
case b := <-ch:
|
|
assert.Equal(t, string(roots), string(b))
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "failed to get subscribed file watch in time")
|
|
}
|
|
}
|
|
|
|
//nolint:gocritic
|
|
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
|
|
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
|
|
|
for _, ch := range []chan []byte{tCh1, tCh2} {
|
|
select {
|
|
case b := <-ch:
|
|
assert.Equal(t, string(roots), string(b))
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "failed to get subscribed file watch in time")
|
|
}
|
|
}
|
|
|
|
cancel()
|
|
|
|
for _, ch := range []chan struct{}{watchDone1, watchDone2} {
|
|
select {
|
|
case <-ch:
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected Watch to have returned")
|
|
}
|
|
}
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error to be returned from Run")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestFile_CurrentTrustAnchors(t *testing.T) {
|
|
t.Run("returns trust anchors as they change", func(t *testing.T) {
|
|
pki1, pki2, pki3 := test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{}), test.GenPKI(t, test.PKIOptions{})
|
|
tmp := filepath.Join(t.TempDir(), "ca.crt")
|
|
require.NoError(t, os.WriteFile(tmp, pki1.RootCertPEM, 0o600))
|
|
|
|
ta := From(Options{
|
|
Log: logger.NewLogger("test"),
|
|
CAPath: tmp,
|
|
})
|
|
f, ok := ta.(*file)
|
|
require.True(t, ok)
|
|
f.initFileWatchInterval = time.Millisecond
|
|
f.fsWatcherInterval = time.Millisecond
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
errCh := make(chan error)
|
|
go func() {
|
|
errCh <- f.Run(ctx)
|
|
}()
|
|
|
|
//nolint:gocritic
|
|
roots := append(pki1.RootCertPEM, pki2.RootCertPEM...)
|
|
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
|
pem, err := ta.CurrentTrustAnchors(context.Background())
|
|
require.NoError(t, err)
|
|
assert.Equal(c, roots, pem)
|
|
}, time.Second, time.Millisecond)
|
|
|
|
//nolint:gocritic
|
|
roots = append(pki1.RootCertPEM, append(pki2.RootCertPEM, pki3.RootCertPEM...)...)
|
|
require.NoError(t, os.WriteFile(tmp, roots, 0o600))
|
|
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
|
pem, err := ta.CurrentTrustAnchors(context.Background())
|
|
require.NoError(t, err)
|
|
assert.Equal(c, roots, pem)
|
|
}, time.Second, time.Millisecond)
|
|
|
|
cancel()
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
require.NoError(t, err)
|
|
case <-time.After(time.Second):
|
|
assert.Fail(t, "expected error to be returned from Run")
|
|
}
|
|
})
|
|
}
|