Refactor async go routine, validate() func, add unit tests

This commit is contained in:
Andrey Ermolov 2023-09-29 03:25:01 +00:00
parent 1feaae3b5c
commit a9a84f1251
2 changed files with 133 additions and 18 deletions

View File

@ -19,6 +19,7 @@
package advancedtls
import (
"context"
"crypto/x509"
"fmt"
"os"
@ -63,23 +64,26 @@ type Options struct {
// NewFileWatcherCRLProvider creates a new FileWatcherCRLProvider.
type FileWatcherCRLProvider struct {
crls map[string]*CRL
opts Options
done chan bool
crls map[string]*CRL
opts Options
cancel context.CancelFunc
}
func NewFileWatcherCRLProvider(o Options) (*FileWatcherCRLProvider, error) {
func MakeFileWatcherCRLProvider(o Options) (*FileWatcherCRLProvider, error) {
if err := o.validate(); err != nil {
return nil, err
}
return &FileWatcherCRLProvider{
ctx, cancel := context.WithCancel(context.Background())
provider := &FileWatcherCRLProvider{
crls: make(map[string]*CRL),
opts: o,
done: make(chan bool),
}, nil
}
provider.cancel = cancel
go provider.run(ctx)
return provider, nil
}
func (o Options) validate() error {
func (o *Options) validate() error {
// Checks relates to CRLDirectory.
if o.CRLDirectory == "" {
return fmt.Errorf("advancedtls: CRLDirectory needs to be specified")
@ -105,32 +109,32 @@ func (o Options) validate() error {
}
// Checks related to RefreshDuration.
if o.RefreshDuration <= 0 || o.RefreshDuration < time.Second {
o.RefreshDuration = defaultCRLRefreshDuration
grpclogLogger.Warningf("RefreshDuration must larger then 1 second: provided value %v, default value will be used %v", o.RefreshDuration, defaultCRLRefreshDuration)
}
return nil
}
// Start starts watching the directory for CRL files and updates the provider accordingly.
func (p *FileWatcherCRLProvider) Start() {
func (p *FileWatcherCRLProvider) run(ctx context.Context) {
ticker := time.NewTicker(p.opts.RefreshDuration)
defer ticker.Stop()
// Initial CRL load
p.scanCRLDirectory()
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
p.scanCRLDirectory()
case <-p.done:
return
}
}
}
// Stop stops the CRL provider and releases resources.
func (p *FileWatcherCRLProvider) Stop() {
close(p.done)
func (p *FileWatcherCRLProvider) Close() {
p.cancel()
}
func (p *FileWatcherCRLProvider) scanCRLDirectory() {
@ -157,7 +161,7 @@ func (p *FileWatcherCRLProvider) scanCRLDirectory() {
}
successCounter++
}
grpclogLogger.Infof("Scan of CRLDirectory %v completed, tried %v files, added %v CRLs, %v files failed", len(files), successCounter, failCounter)
grpclogLogger.Infof("Scan of CRLDirectory %v completed, %v files tried, %v CRLs added, %v files failed", len(files), successCounter, failCounter)
}
func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
@ -171,7 +175,7 @@ func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
}
var certList *CRL
if certList, err = parseCRLExtensions(crl); err != nil {
return fmt.Errorf("addCRL: unsupported crl %v: %v", filePath, err)
return fmt.Errorf("addCRL: unsupported CRL %v: %v", filePath, err)
}
rawCRLIssuer, err := extractCRLIssuer(crlBytes)
if err != nil {
@ -186,7 +190,6 @@ func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
// CRL retrieves the CRL associated with the given certificate's issuer DN.
func (p *FileWatcherCRLProvider) CRL(cert *x509.Certificate) (*CRL, error) {
// TODO handle no CRL found
key := cert.Issuer.ToRDNSequence().String()
return p.crls[key], nil
}

View File

@ -22,6 +22,7 @@ import (
"crypto/x509"
"fmt"
"testing"
"time"
"google.golang.org/grpc/security/advancedtls/testdata"
)
@ -71,3 +72,114 @@ func TestStaticCRLProvider(t *testing.T) {
})
}
}
func TestFileWatcherCRLProviderConfig(t *testing.T) {
if _, err := MakeFileWatcherCRLProvider(Options{}); err == nil {
t.Fatalf("Empty Options should not be allowed")
}
if _, err := MakeFileWatcherCRLProvider(Options{CRLDirectory: "I_do_not_exist"}); err == nil {
t.Fatalf("CRLDirectory must exist")
}
if defaultProvider, err := MakeFileWatcherCRLProvider(Options{CRLDirectory: testdata.Path("crl/provider")}); err == nil {
if defaultProvider.opts.RefreshDuration != defaultCRLRefreshDuration {
t.Fatalf("RefreshDuration is not properly updated by validate() func")
}
defaultProvider.Close()
} else {
t.Fatal("Unexpected error:", err)
}
regularProvider, err := MakeFileWatcherCRLProvider(Options{
CRLDirectory: testdata.Path("crl"),
RefreshDuration: 5 * time.Second,
})
if err != nil {
t.Fatal("Unexpected error while creating regular FileWatcherCRLProvider:", err)
}
regularProvider.scanCRLDirectory()
tests := []struct {
desc string
certs []*x509.Certificate
expectNoCRL bool
}{
{
desc: "Unrevoked chain",
certs: makeChain(t, testdata.Path("crl/unrevoked.pem")),
},
{
desc: "Revoked Intermediate chain",
certs: makeChain(t, testdata.Path("crl/revokedInt.pem")),
},
{
desc: "Revoked leaf chain",
certs: makeChain(t, testdata.Path("crl/revokedLeaf.pem")),
},
{
desc: "Chain with no CRL for issuer",
certs: makeChain(t, testdata.Path("client_cert_1.pem")),
expectNoCRL: true,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
for _, c := range tt.certs {
crl, err := regularProvider.CRL(c)
if err != nil {
t.Fatalf("Expected error fetch from provider: %v", err)
}
if crl == nil && !tt.expectNoCRL {
t.Fatalf("CRL is unexpectedly nil")
}
}
})
}
regularProvider.Close()
}
func TestFileWatcherCRLProvider(t *testing.T) {
p := MakeStaticCRLProvider()
for i := 1; i <= 6; i++ {
crl := loadCRL(t, testdata.Path(fmt.Sprintf("crl/%d.crl", i)))
p.AddCRL(crl)
}
tests := []struct {
desc string
certs []*x509.Certificate
expectNoCRL bool
}{
{
desc: "Unrevoked chain",
certs: makeChain(t, testdata.Path("crl/unrevoked.pem")),
},
{
desc: "Revoked Intermediate chain",
certs: makeChain(t, testdata.Path("crl/revokedInt.pem")),
},
{
desc: "Revoked leaf chain",
certs: makeChain(t, testdata.Path("crl/revokedLeaf.pem")),
},
{
desc: "Chain with no CRL for issuer",
certs: makeChain(t, testdata.Path("client_cert_1.pem")),
expectNoCRL: true,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
for _, c := range tt.certs {
crl, err := p.CRL(c)
if err != nil {
t.Fatalf("Expected error fetch from provider: %v", err)
}
if crl == nil && !tt.expectNoCRL {
t.Fatalf("CRL is unexpectedly nil")
}
}
})
}
}