diff --git a/security/advancedtls/crl_provider.go b/security/advancedtls/crl_provider.go index 709726472..cc3200639 100644 --- a/security/advancedtls/crl_provider.go +++ b/security/advancedtls/crl_provider.go @@ -20,8 +20,13 @@ package advancedtls import ( "crypto/x509" + "fmt" + "os" + "time" ) +const defaultCRLRefreshDuration = 1 * time.Hour + type CRLProvider interface { // Callers are expected to use the returned value as read-only. CRL(cert *x509.Certificate) (*CRL, error) @@ -41,10 +46,147 @@ func MakeStaticCRLProvider() *StaticCRLProvider { } func (p *StaticCRLProvider) AddCRL(crl *CRL) { - p.crls[crl.CertList.Issuer.ToRDNSequence().String()] = crl + key := crl.CertList.Issuer.ToRDNSequence().String() + p.crls[key] = crl } func (p *StaticCRLProvider) CRL(cert *x509.Certificate) (*CRL, error) { // TODO handle no CRL found - return p.crls[cert.Issuer.ToRDNSequence().String()], nil + key := cert.Issuer.ToRDNSequence().String() + return p.crls[key], nil +} + +type Options struct { + CRLDirectory string + RefreshDuration time.Duration +} + +// NewFileWatcherCRLProvider creates a new FileWatcherCRLProvider. +type FileWatcherCRLProvider struct { + crls map[string]*CRL + opts Options + done chan bool +} + +func NewFileWatcherCRLProvider(o Options) (*FileWatcherCRLProvider, error) { + if err := o.validate(); err != nil { + return nil, err + } + return &FileWatcherCRLProvider{ + crls: make(map[string]*CRL), + opts: o, + done: make(chan bool), + }, nil +} + +func (o Options) validate() error { + // Checks relates to CRLDirectory. + if o.CRLDirectory == "" { + return fmt.Errorf("advancedtls: CRLDirectory needs to be specified") + } + fileInfo, err := os.Stat(o.CRLDirectory) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("advancedtls: CRLDirectory %v does not exist", o.CRLDirectory) + } else { + return err + } + } + if !fileInfo.IsDir() { + return fmt.Errorf("advancedtls: CRLDirectory %v is not a directory", o.CRLDirectory) + } + _, err = os.Open(o.CRLDirectory) + if err != nil { + if os.IsPermission(err) { + return fmt.Errorf("advancedtls: CRLDirectory %v is not readable:", o.CRLDirectory) + } else { + return err + } + } + // Checks related to RefreshDuration. + if o.RefreshDuration <= 0 || o.RefreshDuration < time.Second { + grpclogLogger.Warningf("RefreshDuration must larger then 1 second: provided value %v, default value will be used %v", o.RefreshDuration, defaultCRLRefreshDuration) + } + return nil +} + +// Start starts watching the directory for CRL files and updates the provider accordingly. +func (p *FileWatcherCRLProvider) Start() { + ticker := time.NewTicker(p.opts.RefreshDuration) + defer ticker.Stop() + + // Initial CRL load + p.scanCRLDirectory() + + for { + select { + case <-ticker.C: + p.scanCRLDirectory() + case <-p.done: + return + } + } +} + +// Stop stops the CRL provider and releases resources. +func (p *FileWatcherCRLProvider) Stop() { + close(p.done) +} + +func (p *FileWatcherCRLProvider) scanCRLDirectory() { + dir, err := os.Open(p.opts.CRLDirectory) + if err != nil { + grpclogLogger.Errorf("Can't open CRLDirectory %v", p.opts.CRLDirectory, err) + } + defer dir.Close() + + files, err := dir.ReadDir(0) + if err != nil { + grpclogLogger.Errorf("Can't access files under CRLDirectory %v", p.opts.CRLDirectory, err) + } + + successCounter := 0 + failCounter := 0 + for _, file := range files { + filePath := fmt.Sprintf("%s/%s", p.opts.CRLDirectory, file.Name()) + err := p.addCRL(filePath) + if err != nil { + failCounter++ + grpclogLogger.Warningf("Can't add CRL from file %v under CRLDirectory %v", filePath, p.opts.CRLDirectory, err) + continue + } + successCounter++ + } + grpclogLogger.Infof("Scan of CRLDirectory %v completed, tried %v files, added %v CRLs, %v files failed", len(files), successCounter, failCounter) +} + +func (p *FileWatcherCRLProvider) addCRL(filePath string) error { + crlBytes, err := os.ReadFile(filePath) + if err != nil { + return err + } + crl, err := parseRevocationList(crlBytes) + if err != nil { + return fmt.Errorf("addCRL: can't parse CRL from file %v: %v", filePath, err) + } + var certList *CRL + if certList, err = parseCRLExtensions(crl); err != nil { + return fmt.Errorf("addCRL: unsupported crl %v: %v", filePath, err) + } + rawCRLIssuer, err := extractCRLIssuer(crlBytes) + if err != nil { + return fmt.Errorf("addCRL: can't extract Issuer from CRL from file %v: %v", filePath, err) + } + certList.RawIssuer = rawCRLIssuer + key := certList.CertList.Issuer.ToRDNSequence().String() + p.crls[key] = certList + grpclogLogger.Infof("In-memory CRL storage of FileWatcherCRLProvider for key %v updated", key) + return nil +} + +// CRL retrieves the CRL associated with the given certificate's issuer DN. +func (p *FileWatcherCRLProvider) CRL(cert *x509.Certificate) (*CRL, error) { + // TODO handle no CRL found + key := cert.Issuer.ToRDNSequence().String() + return p.crls[key], nil }