mirror of https://github.com/grpc/grpc-go.git
Refactor async go routine, validate() func, add unit tests
This commit is contained in:
parent
1feaae3b5c
commit
a9a84f1251
|
|
@ -19,6 +19,7 @@
|
|||
package advancedtls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
|
|
@ -63,23 +64,26 @@ type Options struct {
|
|||
|
||||
// NewFileWatcherCRLProvider creates a new FileWatcherCRLProvider.
|
||||
type FileWatcherCRLProvider struct {
|
||||
crls map[string]*CRL
|
||||
opts Options
|
||||
done chan bool
|
||||
crls map[string]*CRL
|
||||
opts Options
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewFileWatcherCRLProvider(o Options) (*FileWatcherCRLProvider, error) {
|
||||
func MakeFileWatcherCRLProvider(o Options) (*FileWatcherCRLProvider, error) {
|
||||
if err := o.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FileWatcherCRLProvider{
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
provider := &FileWatcherCRLProvider{
|
||||
crls: make(map[string]*CRL),
|
||||
opts: o,
|
||||
done: make(chan bool),
|
||||
}, nil
|
||||
}
|
||||
provider.cancel = cancel
|
||||
go provider.run(ctx)
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (o Options) validate() error {
|
||||
func (o *Options) validate() error {
|
||||
// Checks relates to CRLDirectory.
|
||||
if o.CRLDirectory == "" {
|
||||
return fmt.Errorf("advancedtls: CRLDirectory needs to be specified")
|
||||
|
|
@ -105,32 +109,32 @@ func (o Options) validate() error {
|
|||
}
|
||||
// Checks related to RefreshDuration.
|
||||
if o.RefreshDuration <= 0 || o.RefreshDuration < time.Second {
|
||||
o.RefreshDuration = defaultCRLRefreshDuration
|
||||
grpclogLogger.Warningf("RefreshDuration must larger then 1 second: provided value %v, default value will be used %v", o.RefreshDuration, defaultCRLRefreshDuration)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start starts watching the directory for CRL files and updates the provider accordingly.
|
||||
func (p *FileWatcherCRLProvider) Start() {
|
||||
func (p *FileWatcherCRLProvider) run(ctx context.Context) {
|
||||
ticker := time.NewTicker(p.opts.RefreshDuration)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Initial CRL load
|
||||
p.scanCRLDirectory()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.scanCRLDirectory()
|
||||
case <-p.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the CRL provider and releases resources.
|
||||
func (p *FileWatcherCRLProvider) Stop() {
|
||||
close(p.done)
|
||||
func (p *FileWatcherCRLProvider) Close() {
|
||||
p.cancel()
|
||||
}
|
||||
|
||||
func (p *FileWatcherCRLProvider) scanCRLDirectory() {
|
||||
|
|
@ -157,7 +161,7 @@ func (p *FileWatcherCRLProvider) scanCRLDirectory() {
|
|||
}
|
||||
successCounter++
|
||||
}
|
||||
grpclogLogger.Infof("Scan of CRLDirectory %v completed, tried %v files, added %v CRLs, %v files failed", len(files), successCounter, failCounter)
|
||||
grpclogLogger.Infof("Scan of CRLDirectory %v completed, %v files tried, %v CRLs added, %v files failed", len(files), successCounter, failCounter)
|
||||
}
|
||||
|
||||
func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
|
||||
|
|
@ -171,7 +175,7 @@ func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
|
|||
}
|
||||
var certList *CRL
|
||||
if certList, err = parseCRLExtensions(crl); err != nil {
|
||||
return fmt.Errorf("addCRL: unsupported crl %v: %v", filePath, err)
|
||||
return fmt.Errorf("addCRL: unsupported CRL %v: %v", filePath, err)
|
||||
}
|
||||
rawCRLIssuer, err := extractCRLIssuer(crlBytes)
|
||||
if err != nil {
|
||||
|
|
@ -186,7 +190,6 @@ func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
|
|||
|
||||
// CRL retrieves the CRL associated with the given certificate's issuer DN.
|
||||
func (p *FileWatcherCRLProvider) CRL(cert *x509.Certificate) (*CRL, error) {
|
||||
// TODO handle no CRL found
|
||||
key := cert.Issuer.ToRDNSequence().String()
|
||||
return p.crls[key], nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import (
|
|||
"crypto/x509"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||
)
|
||||
|
|
@ -71,3 +72,114 @@ func TestStaticCRLProvider(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileWatcherCRLProviderConfig(t *testing.T) {
|
||||
if _, err := MakeFileWatcherCRLProvider(Options{}); err == nil {
|
||||
t.Fatalf("Empty Options should not be allowed")
|
||||
}
|
||||
if _, err := MakeFileWatcherCRLProvider(Options{CRLDirectory: "I_do_not_exist"}); err == nil {
|
||||
t.Fatalf("CRLDirectory must exist")
|
||||
}
|
||||
if defaultProvider, err := MakeFileWatcherCRLProvider(Options{CRLDirectory: testdata.Path("crl/provider")}); err == nil {
|
||||
if defaultProvider.opts.RefreshDuration != defaultCRLRefreshDuration {
|
||||
t.Fatalf("RefreshDuration is not properly updated by validate() func")
|
||||
}
|
||||
defaultProvider.Close()
|
||||
} else {
|
||||
t.Fatal("Unexpected error:", err)
|
||||
}
|
||||
|
||||
regularProvider, err := MakeFileWatcherCRLProvider(Options{
|
||||
CRLDirectory: testdata.Path("crl"),
|
||||
RefreshDuration: 5 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("Unexpected error while creating regular FileWatcherCRLProvider:", err)
|
||||
}
|
||||
|
||||
regularProvider.scanCRLDirectory()
|
||||
tests := []struct {
|
||||
desc string
|
||||
certs []*x509.Certificate
|
||||
expectNoCRL bool
|
||||
}{
|
||||
{
|
||||
desc: "Unrevoked chain",
|
||||
certs: makeChain(t, testdata.Path("crl/unrevoked.pem")),
|
||||
},
|
||||
{
|
||||
desc: "Revoked Intermediate chain",
|
||||
certs: makeChain(t, testdata.Path("crl/revokedInt.pem")),
|
||||
},
|
||||
{
|
||||
desc: "Revoked leaf chain",
|
||||
certs: makeChain(t, testdata.Path("crl/revokedLeaf.pem")),
|
||||
},
|
||||
{
|
||||
desc: "Chain with no CRL for issuer",
|
||||
certs: makeChain(t, testdata.Path("client_cert_1.pem")),
|
||||
expectNoCRL: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
for _, c := range tt.certs {
|
||||
crl, err := regularProvider.CRL(c)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected error fetch from provider: %v", err)
|
||||
}
|
||||
if crl == nil && !tt.expectNoCRL {
|
||||
t.Fatalf("CRL is unexpectedly nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
regularProvider.Close()
|
||||
}
|
||||
|
||||
func TestFileWatcherCRLProvider(t *testing.T) {
|
||||
p := MakeStaticCRLProvider()
|
||||
for i := 1; i <= 6; i++ {
|
||||
crl := loadCRL(t, testdata.Path(fmt.Sprintf("crl/%d.crl", i)))
|
||||
p.AddCRL(crl)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
certs []*x509.Certificate
|
||||
expectNoCRL bool
|
||||
}{
|
||||
{
|
||||
desc: "Unrevoked chain",
|
||||
certs: makeChain(t, testdata.Path("crl/unrevoked.pem")),
|
||||
},
|
||||
{
|
||||
desc: "Revoked Intermediate chain",
|
||||
certs: makeChain(t, testdata.Path("crl/revokedInt.pem")),
|
||||
},
|
||||
{
|
||||
desc: "Revoked leaf chain",
|
||||
certs: makeChain(t, testdata.Path("crl/revokedLeaf.pem")),
|
||||
},
|
||||
{
|
||||
desc: "Chain with no CRL for issuer",
|
||||
certs: makeChain(t, testdata.Path("client_cert_1.pem")),
|
||||
expectNoCRL: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
for _, c := range tt.certs {
|
||||
crl, err := p.CRL(c)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected error fetch from provider: %v", err)
|
||||
}
|
||||
if crl == nil && !tt.expectNoCRL {
|
||||
t.Fatalf("CRL is unexpectedly nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue