From 1aced67471c9da6ede5162a2adf155ae12b61408 Mon Sep 17 00:00:00 2001 From: Aaron Lehmann Date: Mon, 20 Jul 2015 12:19:38 -0700 Subject: [PATCH] Improvements to keystore caching * RemoveKey must purge the cache entry * Add mutexes to KeyFileStore and KeyMemoryStore so the cachedKeys map is protected in the case that keystore operations happen from multiple goroutines * Change GetKey to return the alias along with the key. Remove GetKeyAlias. This simplifies the code flows that retrieve the alias (since they usually get the key and alias together). * Fix tests affected by key caching Signed-off-by: Aaron Lehmann --- cryptoservice/crypto_service.go | 4 +- keystoremanager/import_export.go | 14 +---- keystoremanager/import_export_test.go | 8 +-- keystoremanager/keystoremanager.go | 2 +- trustmanager/keyfilestore.go | 89 +++++++++++++++++---------- trustmanager/keyfilestore_test.go | 56 ++++++++--------- 6 files changed, 91 insertions(+), 82 deletions(-) diff --git a/cryptoservice/crypto_service.go b/cryptoservice/crypto_service.go index 061c0541d4..3a61cb95da 100644 --- a/cryptoservice/crypto_service.go +++ b/cryptoservice/crypto_service.go @@ -68,7 +68,7 @@ func (ccs *CryptoService) Create(role string, algorithm data.KeyAlgorithm) (data // GetKey returns a key by ID func (ccs *CryptoService) GetKey(keyID string) data.PublicKey { - key, err := ccs.keyStore.GetKey(keyID) + key, _, err := ccs.keyStore.GetKey(keyID) if err != nil { return nil } @@ -92,7 +92,7 @@ func (ccs *CryptoService) Sign(keyIDs []string, payload []byte) ([]data.Signatur var privKey data.PrivateKey var err error - privKey, err = ccs.keyStore.GetKey(keyName) + privKey, _, err = ccs.keyStore.GetKey(keyName) if err != nil { // Note that GetKey always fails on InitRepo. // InitRepo gets a signer that doesn't have access to diff --git a/keystoremanager/import_export.go b/keystoremanager/import_export.go index 3c42ab8263..bd87244681 100644 --- a/keystoremanager/import_export.go +++ b/keystoremanager/import_export.go @@ -81,12 +81,7 @@ func (km *KeyStoreManager) ImportRootKey(source io.Reader, keyID string) error { func moveKeys(oldKeyStore, newKeyStore *trustmanager.KeyFileStore) error { // List all files but no symlinks for _, f := range oldKeyStore.ListKeys() { - pemBytes, err := oldKeyStore.GetKey(f) - if err != nil { - return err - } - - alias, err := oldKeyStore.GetKeyAlias(f) + pemBytes, alias, err := oldKeyStore.GetKey(f) if err != nil { return err } @@ -259,12 +254,7 @@ func moveKeysByGUN(oldKeyStore, newKeyStore *trustmanager.KeyFileStore, gun stri continue } - privKey, err := oldKeyStore.GetKey(relKeyPath) - if err != nil { - return err - } - - alias, err := oldKeyStore.GetKeyAlias(relKeyPath) + privKey, alias, err := oldKeyStore.GetKey(relKeyPath) if err != nil { return err } diff --git a/keystoremanager/import_export_test.go b/keystoremanager/import_export_test.go index 47247f95c9..4a62995796 100644 --- a/keystoremanager/import_export_test.go +++ b/keystoremanager/import_export_test.go @@ -85,7 +85,7 @@ func TestImportExportZip(t *testing.T) { // because the passwords were chosen by the newPassphraseRetriever. privKeyList := repo.KeyStoreManager.NonRootKeyStore().ListKeys() for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) assert.NoError(t, err, "privKey %s has no alias", privKeyName) relKeyPath := filepath.Join("private", "tuf_keys", privKeyName+"_"+alias+".key") @@ -156,7 +156,7 @@ func TestImportExportZip(t *testing.T) { // Look for keys in private. The filenames should match the key IDs // in the repo's private key store. for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) assert.NoError(t, err, "privKey %s has no alias", privKeyName) relKeyPath := filepath.Join("private", "tuf_keys", privKeyName+"_"+alias+".key") @@ -221,7 +221,7 @@ func TestImportExportGUN(t *testing.T) { // because they were formerly unencrypted. privKeyList := repo.KeyStoreManager.NonRootKeyStore().ListKeys() for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) if err != nil { t.Fatalf("privKey %s has no alias", privKeyName) } @@ -290,7 +290,7 @@ func TestImportExportGUN(t *testing.T) { // Look for keys in private. The filenames should match the key IDs // in the repo's private key store. for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) if err != nil { t.Fatalf("privKey %s has no alias", privKeyName) } diff --git a/keystoremanager/keystoremanager.go b/keystoremanager/keystoremanager.go index 6f4c1d6b20..ecc4c1bcec 100644 --- a/keystoremanager/keystoremanager.go +++ b/keystoremanager/keystoremanager.go @@ -173,7 +173,7 @@ func (km *KeyStoreManager) GenRootKey(algorithm string) (string, error) { // GetRootCryptoService retrieves a root key and a cryptoservice to use with it // TODO(mccauley): remove this as its no longer needed once we have key caching in the keystores func (km *KeyStoreManager) GetRootCryptoService(rootKeyID string) (*cryptoservice.UnlockedCryptoService, error) { - privKey, err := km.rootKeyStore.GetKey(rootKeyID) + privKey, _, err := km.rootKeyStore.GetKey(rootKeyID) if err != nil { return nil, fmt.Errorf("could not get decrypted root key with keyID: %s, %v", rootKeyID, err) } diff --git a/trustmanager/keyfilestore.go b/trustmanager/keyfilestore.go index 6a50ebf3eb..545345a508 100644 --- a/trustmanager/keyfilestore.go +++ b/trustmanager/keyfilestore.go @@ -3,6 +3,7 @@ package trustmanager import ( "path/filepath" "strings" + "sync" "errors" "fmt" @@ -20,24 +21,36 @@ type KeyStore interface { LimitedFileStore AddKey(name, alias string, privKey data.PrivateKey) error - GetKey(name string) (data.PrivateKey, error) - GetKeyAlias(name string) (string, error) + GetKey(name string) (data.PrivateKey, string, error) ListKeys() []string RemoveKey(name string) error } +type cachedKey struct { + alias string + key data.PrivateKey +} + +// PassphraseRetriever is a callback function that should retrieve a passphrase +// for a given named key. If it should be treated as new passphrase (e.g. with +// confirmation), createNew will be true. Attempts is passed in so that implementers +// decide how many chances to give to a human, for example. +type PassphraseRetriever func(keyId, alias string, createNew bool, attempts int) (passphrase string, giveup bool, err error) + // KeyFileStore persists and manages private keys on disk type KeyFileStore struct { + sync.Mutex SimpleFileStore PassphraseRetriever - cachedKeys map[string]data.PrivateKey + cachedKeys map[string]*cachedKey } // KeyMemoryStore manages private keys in memory type KeyMemoryStore struct { + sync.Mutex MemoryFileStore PassphraseRetriever - cachedKeys map[string]data.PrivateKey + cachedKeys map[string]*cachedKey } // NewKeyFileStore returns a new KeyFileStore creating a private directory to @@ -47,26 +60,27 @@ func NewKeyFileStore(baseDir string, passphraseRetriever passphrase.Retriever) ( if err != nil { return nil, err } - cachedKeys := make(map[string]data.PrivateKey) + cachedKeys := make(map[string]*cachedKey) - return &KeyFileStore{*fileStore, passphraseRetriever, cachedKeys}, nil + return &KeyFileStore{SimpleFileStore: *fileStore, + PassphraseRetriever: passphraseRetriever, + cachedKeys: cachedKeys}, nil } // AddKey stores the contents of a PEM-encoded private key as a PEM block func (s *KeyFileStore) AddKey(name, alias string, privKey data.PrivateKey) error { + s.Lock() + defer s.Unlock() return addKey(s, s.PassphraseRetriever, s.cachedKeys, name, alias, privKey) } // GetKey returns the PrivateKey given a KeyID -func (s *KeyFileStore) GetKey(name string) (data.PrivateKey, error) { +func (s *KeyFileStore) GetKey(name string) (data.PrivateKey, string, error) { + s.Lock() + defer s.Unlock() return getKey(s, s.PassphraseRetriever, s.cachedKeys, name) } -// GetKeyAlias returns the PrivateKey's alias given a KeyID -func (s *KeyFileStore) GetKeyAlias(name string) (string, error) { - return getKeyAlias(s, name) -} - // ListKeys returns a list of unique PublicKeys present on the KeyFileStore. // There might be symlinks associating Certificate IDs to Public Keys, so this // method only returns the IDs that aren't symlinks @@ -76,32 +90,35 @@ func (s *KeyFileStore) ListKeys() []string { // RemoveKey removes the key from the keyfilestore func (s *KeyFileStore) RemoveKey(name string) error { - return removeKey(s, name) + s.Lock() + defer s.Unlock() + return removeKey(s, s.cachedKeys, name) } // NewKeyMemoryStore returns a new KeyMemoryStore which holds keys in memory func NewKeyMemoryStore(passphraseRetriever passphrase.Retriever) *KeyMemoryStore { memStore := NewMemoryFileStore() - cachedKeys := make(map[string]data.PrivateKey) + cachedKeys := make(map[string]*cachedKey) - return &KeyMemoryStore{*memStore, passphraseRetriever, cachedKeys} + return &KeyMemoryStore{MemoryFileStore: *memStore, + PassphraseRetriever: passphraseRetriever, + cachedKeys: cachedKeys} } // AddKey stores the contents of a PEM-encoded private key as a PEM block func (s *KeyMemoryStore) AddKey(name, alias string, privKey data.PrivateKey) error { + s.Lock() + defer s.Unlock() return addKey(s, s.PassphraseRetriever, s.cachedKeys, name, alias, privKey) } // GetKey returns the PrivateKey given a KeyID -func (s *KeyMemoryStore) GetKey(name string) (data.PrivateKey, error) { +func (s *KeyMemoryStore) GetKey(name string) (data.PrivateKey, string, error) { + s.Lock() + defer s.Unlock() return getKey(s, s.PassphraseRetriever, s.cachedKeys, name) } -// GetKeyAlias returns the PrivateKey's alias given a KeyID -func (s *KeyMemoryStore) GetKeyAlias(name string) (string, error) { - return getKeyAlias(s, name) -} - // ListKeys returns a list of unique PublicKeys present on the KeyFileStore. // There might be symlinks associating Certificate IDs to Public Keys, so this // method only returns the IDs that aren't symlinks @@ -111,10 +128,12 @@ func (s *KeyMemoryStore) ListKeys() []string { // RemoveKey removes the key from the keystore func (s *KeyMemoryStore) RemoveKey(name string) error { - return removeKey(s, name) + s.Lock() + defer s.Unlock() + return removeKey(s, s.cachedKeys, name) } -func addKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedKeys map[string]data.PrivateKey, name, alias string, privKey data.PrivateKey) error { +func addKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedKeys map[string]*cachedKey, name, alias string, privKey data.PrivateKey) error { pemPrivKey, err := KeyToPEM(privKey) if err != nil { return err @@ -145,7 +164,7 @@ func addKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedK } } - cachedKeys[name] = privKey + cachedKeys[name] = &cachedKey{alias: alias, key: privKey} return s.Add(name+"_"+alias, pemPrivKey) } @@ -167,19 +186,19 @@ func getKeyAlias(s LimitedFileStore, keyID string) (string, error) { } // GetKey returns the PrivateKey given a KeyID -func getKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedKeys map[string]data.PrivateKey, name string) (data.PrivateKey, error) { - cachedKey, ok := cachedKeys[name] +func getKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedKeys map[string]*cachedKey, name string) (data.PrivateKey, string, error) { + cachedKeyEntry, ok := cachedKeys[name] if ok { - return cachedKey, nil + return cachedKeyEntry.key, cachedKeyEntry.alias, nil } keyAlias, err := getKeyAlias(s, name) if err != nil { - return nil, err + return nil, "", err } keyBytes, err := s.Get(name + "_" + keyAlias) if err != nil { - return nil, err + return nil, "", err } // See if the key is encrypted. If its encrypted we'll fail to parse the private key @@ -190,10 +209,10 @@ func getKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedK passphrase, giveup, err := passphraseRetriever(name, string(keyAlias), false, attempts) // Check if the passphrase retriever got an error or if it is telling us to give up if giveup || err != nil { - return nil, errors.New("obtaining passphrase failed") + return nil, "", errors.New("obtaining passphrase failed") } if attempts > 10 { - return nil, errors.New("maximum number of passphrase attempts exceeded") + return nil, "", errors.New("maximum number of passphrase attempts exceeded") } // Try to convert PEM encoded bytes back to a PrivateKey using the passphrase @@ -204,8 +223,8 @@ func getKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedK } } } - cachedKeys[name] = privKey - return privKey, nil + cachedKeys[name] = &cachedKey{alias: keyAlias, key: privKey} + return privKey, keyAlias, nil } // ListKeys returns a list of unique PublicKeys present on the KeyFileStore. @@ -223,11 +242,13 @@ func listKeys(s LimitedFileStore) []string { } // RemoveKey removes the key from the keyfilestore -func removeKey(s LimitedFileStore, name string) error { +func removeKey(s LimitedFileStore, cachedKeys map[string]*cachedKey, name string) error { keyAlias, err := getKeyAlias(s, name) if err != nil { return err } + delete(cachedKeys, name) + return s.Remove(name + "_" + keyAlias) } diff --git a/trustmanager/keyfilestore_test.go b/trustmanager/keyfilestore_test.go index 941ef09154..9eae4e96ca 100644 --- a/trustmanager/keyfilestore_test.go +++ b/trustmanager/keyfilestore_test.go @@ -4,12 +4,12 @@ import ( "bytes" "crypto/rand" "errors" + "github.com/docker/notary/Godeps/_workspace/src/github.com/stretchr/testify/assert" "io/ioutil" "os" "path/filepath" "strings" "testing" - "github.com/docker/notary/Godeps/_workspace/src/github.com/stretchr/testify/assert" ) var passphraseRetriever = func(keyID string, alias string, createNew bool, numAttempts int) (string, bool, error) { @@ -121,7 +121,7 @@ EMl3eFOJXjIch/wIesRSN+2dGOsl7neercjMh1i9RvpCwHDx/E0= } // Call the GetKey function - privKey, err := store.GetKey(testName) + privKey, _, err := store.GetKey(testName) if err != nil { t.Fatalf("failed to get file from store: %v", err) } @@ -155,13 +155,7 @@ func TestAddGetKeyMemStore(t *testing.T) { } // Check to see if file exists - retrievedKey, err := store.GetKey(testName) - if err != nil { - t.Fatalf("failed to get key from store: %v", err) - } - - // Check to see if alias exists - retrievedAlias, err := store.GetKeyAlias(testName) + retrievedKey, retrievedAlias, err := store.GetKey(testName) if err != nil { t.Fatalf("failed to get key from store: %v", err) } @@ -216,8 +210,11 @@ func TestGetDecryptedWithTamperedCipherText(t *testing.T) { // Tamper the file fp.WriteAt([]byte("a"), int64(1)) + // Recreate the KeyFileStore to avoid caching + store, err = NewKeyFileStore(tempBaseDir, passphraseRetriever) + // Try to decrypt the file - _, err = store.GetKey(privKey.ID()) + _, _, err = store.GetKey(privKey.ID()) if err == nil { t.Fatalf("expected error while decrypting the content due to invalid cipher text") } @@ -250,15 +247,15 @@ func TestGetDecryptedWithInvalidPassphrase(t *testing.T) { t.Fatalf("failed to create new key filestore: %v", err) } - testGetDecryptedWithInvalidPassphrase(t, fileStore) - - // Test with KeyMemoryStore - memStore := NewKeyMemoryStore(invalidPassphraseRetriever) + newFileStore, err := NewKeyFileStore(tempBaseDir, invalidPassphraseRetriever) if err != nil { - t.Fatalf("failed to create new key memorystore: %v", err) + t.Fatalf("failed to create new key filestore: %v", err) } - testGetDecryptedWithInvalidPassphrase(t, memStore) + testGetDecryptedWithInvalidPassphrase(t, fileStore, newFileStore) + + // Can't test with KeyMemoryStore because we cache the decrypted version of + // the key forever } func TestGetDecryptedWithConsistentlyInvalidPassphrase(t *testing.T) { @@ -283,17 +280,20 @@ func TestGetDecryptedWithConsistentlyInvalidPassphrase(t *testing.T) { t.Fatalf("failed to create new key filestore: %v", err) } - testGetDecryptedWithInvalidPassphrase(t, fileStore) - - // Test with KeyMemoryStore - memStore := NewKeyMemoryStore(consistentlyInvalidPassphraseRetriever) + newFileStore, err := NewKeyFileStore(tempBaseDir, consistentlyInvalidPassphraseRetriever) if err != nil { - t.Fatalf("failed to create new key memorystore: %v", err) + t.Fatalf("failed to create new key filestore: %v", err) } - testGetDecryptedWithInvalidPassphrase(t, memStore) + + testGetDecryptedWithInvalidPassphrase(t, fileStore, newFileStore) + + // Can't test with KeyMemoryStore because we cache the decrypted version of + // the key forever } -func testGetDecryptedWithInvalidPassphrase(t *testing.T, store KeyStore) { +// testGetDecryptedWithInvalidPassphrase takes two keystores so it can add to +// one and get from the other (to work around caching) +func testGetDecryptedWithInvalidPassphrase(t *testing.T, store KeyStore, newStore KeyStore) { testAlias := "root" // Generate a new random RSA Key @@ -309,7 +309,7 @@ func testGetDecryptedWithInvalidPassphrase(t *testing.T, store KeyStore) { } // Try to decrypt the file with an invalid passphrase - _, err = store.GetKey(privKey.ID()) + _, _, err = newStore.GetKey(privKey.ID()) if err == nil { t.Fatalf("expected error while decrypting the content due to invalid passphrase") } @@ -377,7 +377,6 @@ func TestKeysAreCached(t *testing.T) { } defer os.RemoveAll(tempBaseDir) - var countingPassphraseRetriever PassphraseRetriever numTimesCalled := 0 @@ -406,7 +405,7 @@ func TestKeysAreCached(t *testing.T) { assert.Equal(t, 1, numTimesCalled, "numTimesCalled should have been 1") // Call the AddKey function - privKey2, err := store.GetKey(testName) + privKey2, _, err := store.GetKey(testName) if err != nil { t.Fatalf("failed to add file to store: %v", err) } @@ -415,7 +414,6 @@ func TestKeysAreCached(t *testing.T) { assert.Equal(t, privKey.Private(), privKey2.Private(), "cachedPrivKey should be the same as the added privKey") assert.Equal(t, 1, numTimesCalled, "numTimesCalled should be 1 -- no additional call to passphraseRetriever") - // Create a new store store2, err := NewKeyFileStore(tempBaseDir, countingPassphraseRetriever) if err != nil { @@ -423,7 +421,7 @@ func TestKeysAreCached(t *testing.T) { } // Call the AddKey function - privKey3, err := store2.GetKey(testName) + privKey3, _, err := store2.GetKey(testName) if err != nil { t.Fatalf("failed to add file to store: %v", err) } @@ -434,7 +432,7 @@ func TestKeysAreCached(t *testing.T) { // Call the GetKey function a bunch of times for i := 0; i < 10; i++ { - _, err := store2.GetKey(testName) + _, _, err := store2.GetKey(testName) if err != nil { t.Fatalf("failed to add file to store: %v", err) }