mirror of https://github.com/grpc/grpc-go.git
Refactor async go routine, validate() func, add unit tests
This commit is contained in:
parent
1feaae3b5c
commit
a9a84f1251
|
|
@ -19,6 +19,7 @@
|
||||||
package advancedtls
|
package advancedtls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
@ -65,21 +66,24 @@ type Options struct {
|
||||||
type FileWatcherCRLProvider struct {
|
type FileWatcherCRLProvider struct {
|
||||||
crls map[string]*CRL
|
crls map[string]*CRL
|
||||||
opts Options
|
opts Options
|
||||||
done chan bool
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFileWatcherCRLProvider(o Options) (*FileWatcherCRLProvider, error) {
|
func MakeFileWatcherCRLProvider(o Options) (*FileWatcherCRLProvider, error) {
|
||||||
if err := o.validate(); err != nil {
|
if err := o.validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &FileWatcherCRLProvider{
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
provider := &FileWatcherCRLProvider{
|
||||||
crls: make(map[string]*CRL),
|
crls: make(map[string]*CRL),
|
||||||
opts: o,
|
opts: o,
|
||||||
done: make(chan bool),
|
}
|
||||||
}, nil
|
provider.cancel = cancel
|
||||||
|
go provider.run(ctx)
|
||||||
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o Options) validate() error {
|
func (o *Options) validate() error {
|
||||||
// Checks relates to CRLDirectory.
|
// Checks relates to CRLDirectory.
|
||||||
if o.CRLDirectory == "" {
|
if o.CRLDirectory == "" {
|
||||||
return fmt.Errorf("advancedtls: CRLDirectory needs to be specified")
|
return fmt.Errorf("advancedtls: CRLDirectory needs to be specified")
|
||||||
|
|
@ -105,32 +109,32 @@ func (o Options) validate() error {
|
||||||
}
|
}
|
||||||
// Checks related to RefreshDuration.
|
// Checks related to RefreshDuration.
|
||||||
if o.RefreshDuration <= 0 || o.RefreshDuration < time.Second {
|
if o.RefreshDuration <= 0 || o.RefreshDuration < time.Second {
|
||||||
|
o.RefreshDuration = defaultCRLRefreshDuration
|
||||||
grpclogLogger.Warningf("RefreshDuration must larger then 1 second: provided value %v, default value will be used %v", o.RefreshDuration, defaultCRLRefreshDuration)
|
grpclogLogger.Warningf("RefreshDuration must larger then 1 second: provided value %v, default value will be used %v", o.RefreshDuration, defaultCRLRefreshDuration)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts watching the directory for CRL files and updates the provider accordingly.
|
// Start starts watching the directory for CRL files and updates the provider accordingly.
|
||||||
func (p *FileWatcherCRLProvider) Start() {
|
func (p *FileWatcherCRLProvider) run(ctx context.Context) {
|
||||||
ticker := time.NewTicker(p.opts.RefreshDuration)
|
ticker := time.NewTicker(p.opts.RefreshDuration)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
// Initial CRL load
|
|
||||||
p.scanCRLDirectory()
|
p.scanCRLDirectory()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
ticker.Stop()
|
||||||
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
p.scanCRLDirectory()
|
p.scanCRLDirectory()
|
||||||
case <-p.done:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the CRL provider and releases resources.
|
// Stop stops the CRL provider and releases resources.
|
||||||
func (p *FileWatcherCRLProvider) Stop() {
|
func (p *FileWatcherCRLProvider) Close() {
|
||||||
close(p.done)
|
p.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *FileWatcherCRLProvider) scanCRLDirectory() {
|
func (p *FileWatcherCRLProvider) scanCRLDirectory() {
|
||||||
|
|
@ -157,7 +161,7 @@ func (p *FileWatcherCRLProvider) scanCRLDirectory() {
|
||||||
}
|
}
|
||||||
successCounter++
|
successCounter++
|
||||||
}
|
}
|
||||||
grpclogLogger.Infof("Scan of CRLDirectory %v completed, tried %v files, added %v CRLs, %v files failed", len(files), successCounter, failCounter)
|
grpclogLogger.Infof("Scan of CRLDirectory %v completed, %v files tried, %v CRLs added, %v files failed", len(files), successCounter, failCounter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
|
func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
|
||||||
|
|
@ -171,7 +175,7 @@ func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
|
||||||
}
|
}
|
||||||
var certList *CRL
|
var certList *CRL
|
||||||
if certList, err = parseCRLExtensions(crl); err != nil {
|
if certList, err = parseCRLExtensions(crl); err != nil {
|
||||||
return fmt.Errorf("addCRL: unsupported crl %v: %v", filePath, err)
|
return fmt.Errorf("addCRL: unsupported CRL %v: %v", filePath, err)
|
||||||
}
|
}
|
||||||
rawCRLIssuer, err := extractCRLIssuer(crlBytes)
|
rawCRLIssuer, err := extractCRLIssuer(crlBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -186,7 +190,6 @@ func (p *FileWatcherCRLProvider) addCRL(filePath string) error {
|
||||||
|
|
||||||
// CRL retrieves the CRL associated with the given certificate's issuer DN.
|
// CRL retrieves the CRL associated with the given certificate's issuer DN.
|
||||||
func (p *FileWatcherCRLProvider) CRL(cert *x509.Certificate) (*CRL, error) {
|
func (p *FileWatcherCRLProvider) CRL(cert *x509.Certificate) (*CRL, error) {
|
||||||
// TODO handle no CRL found
|
|
||||||
key := cert.Issuer.ToRDNSequence().String()
|
key := cert.Issuer.ToRDNSequence().String()
|
||||||
return p.crls[key], nil
|
return p.crls[key], nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/security/advancedtls/testdata"
|
"google.golang.org/grpc/security/advancedtls/testdata"
|
||||||
)
|
)
|
||||||
|
|
@ -71,3 +72,114 @@ func TestStaticCRLProvider(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFileWatcherCRLProviderConfig(t *testing.T) {
|
||||||
|
if _, err := MakeFileWatcherCRLProvider(Options{}); err == nil {
|
||||||
|
t.Fatalf("Empty Options should not be allowed")
|
||||||
|
}
|
||||||
|
if _, err := MakeFileWatcherCRLProvider(Options{CRLDirectory: "I_do_not_exist"}); err == nil {
|
||||||
|
t.Fatalf("CRLDirectory must exist")
|
||||||
|
}
|
||||||
|
if defaultProvider, err := MakeFileWatcherCRLProvider(Options{CRLDirectory: testdata.Path("crl/provider")}); err == nil {
|
||||||
|
if defaultProvider.opts.RefreshDuration != defaultCRLRefreshDuration {
|
||||||
|
t.Fatalf("RefreshDuration is not properly updated by validate() func")
|
||||||
|
}
|
||||||
|
defaultProvider.Close()
|
||||||
|
} else {
|
||||||
|
t.Fatal("Unexpected error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
regularProvider, err := MakeFileWatcherCRLProvider(Options{
|
||||||
|
CRLDirectory: testdata.Path("crl"),
|
||||||
|
RefreshDuration: 5 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Unexpected error while creating regular FileWatcherCRLProvider:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
regularProvider.scanCRLDirectory()
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
certs []*x509.Certificate
|
||||||
|
expectNoCRL bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Unrevoked chain",
|
||||||
|
certs: makeChain(t, testdata.Path("crl/unrevoked.pem")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Revoked Intermediate chain",
|
||||||
|
certs: makeChain(t, testdata.Path("crl/revokedInt.pem")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Revoked leaf chain",
|
||||||
|
certs: makeChain(t, testdata.Path("crl/revokedLeaf.pem")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Chain with no CRL for issuer",
|
||||||
|
certs: makeChain(t, testdata.Path("client_cert_1.pem")),
|
||||||
|
expectNoCRL: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
for _, c := range tt.certs {
|
||||||
|
crl, err := regularProvider.CRL(c)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected error fetch from provider: %v", err)
|
||||||
|
}
|
||||||
|
if crl == nil && !tt.expectNoCRL {
|
||||||
|
t.Fatalf("CRL is unexpectedly nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
regularProvider.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileWatcherCRLProvider(t *testing.T) {
|
||||||
|
p := MakeStaticCRLProvider()
|
||||||
|
for i := 1; i <= 6; i++ {
|
||||||
|
crl := loadCRL(t, testdata.Path(fmt.Sprintf("crl/%d.crl", i)))
|
||||||
|
p.AddCRL(crl)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
certs []*x509.Certificate
|
||||||
|
expectNoCRL bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Unrevoked chain",
|
||||||
|
certs: makeChain(t, testdata.Path("crl/unrevoked.pem")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Revoked Intermediate chain",
|
||||||
|
certs: makeChain(t, testdata.Path("crl/revokedInt.pem")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Revoked leaf chain",
|
||||||
|
certs: makeChain(t, testdata.Path("crl/revokedLeaf.pem")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Chain with no CRL for issuer",
|
||||||
|
certs: makeChain(t, testdata.Path("client_cert_1.pem")),
|
||||||
|
expectNoCRL: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
for _, c := range tt.certs {
|
||||||
|
crl, err := p.CRL(c)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected error fetch from provider: %v", err)
|
||||||
|
}
|
||||||
|
if crl == nil && !tt.expectNoCRL {
|
||||||
|
t.Fatalf("CRL is unexpectedly nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue