From 1421f4725851cc7c80a5953c0b4c200741931201 Mon Sep 17 00:00:00 2001 From: Nathan McCauley Date: Mon, 20 Jul 2015 01:40:55 -0700 Subject: [PATCH] keystore caching Signed-off-by: Nathan McCauley --- trustmanager/keyfilestore.go | 30 ++++++++---- trustmanager/keyfilestore_test.go | 77 +++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 10 deletions(-) diff --git a/trustmanager/keyfilestore.go b/trustmanager/keyfilestore.go index ae941088bd..6a50ebf3eb 100644 --- a/trustmanager/keyfilestore.go +++ b/trustmanager/keyfilestore.go @@ -29,13 +29,15 @@ type KeyStore interface { // KeyFileStore persists and manages private keys on disk type KeyFileStore struct { SimpleFileStore - PassphraseRetriever passphrase.Retriever + PassphraseRetriever + cachedKeys map[string]data.PrivateKey } // KeyMemoryStore manages private keys in memory type KeyMemoryStore struct { MemoryFileStore - PassphraseRetriever passphrase.Retriever + PassphraseRetriever + cachedKeys map[string]data.PrivateKey } // NewKeyFileStore returns a new KeyFileStore creating a private directory to @@ -45,18 +47,19 @@ func NewKeyFileStore(baseDir string, passphraseRetriever passphrase.Retriever) ( if err != nil { return nil, err } + cachedKeys := make(map[string]data.PrivateKey) - return &KeyFileStore{*fileStore, passphraseRetriever}, nil + return &KeyFileStore{*fileStore, passphraseRetriever, cachedKeys}, nil } // AddKey stores the contents of a PEM-encoded private key as a PEM block func (s *KeyFileStore) AddKey(name, alias string, privKey data.PrivateKey) error { - return addKey(s, s.PassphraseRetriever, name, alias, privKey) + return addKey(s, s.PassphraseRetriever, s.cachedKeys, name, alias, privKey) } // GetKey returns the PrivateKey given a KeyID func (s *KeyFileStore) GetKey(name string) (data.PrivateKey, error) { - return getKey(s, s.PassphraseRetriever, name) + return getKey(s, s.PassphraseRetriever, s.cachedKeys, name) } // GetKeyAlias returns the PrivateKey's alias given a KeyID @@ -79,18 +82,19 @@ func (s *KeyFileStore) RemoveKey(name string) error { // NewKeyMemoryStore returns a new KeyMemoryStore which holds keys in memory func NewKeyMemoryStore(passphraseRetriever passphrase.Retriever) *KeyMemoryStore { memStore := NewMemoryFileStore() + cachedKeys := make(map[string]data.PrivateKey) - return &KeyMemoryStore{*memStore, passphraseRetriever} + return &KeyMemoryStore{*memStore, passphraseRetriever, cachedKeys} } // AddKey stores the contents of a PEM-encoded private key as a PEM block func (s *KeyMemoryStore) AddKey(name, alias string, privKey data.PrivateKey) error { - return addKey(s, s.PassphraseRetriever, name, alias, privKey) + return addKey(s, s.PassphraseRetriever, s.cachedKeys, name, alias, privKey) } // GetKey returns the PrivateKey given a KeyID func (s *KeyMemoryStore) GetKey(name string) (data.PrivateKey, error) { - return getKey(s, s.PassphraseRetriever, name) + return getKey(s, s.PassphraseRetriever, s.cachedKeys, name) } // GetKeyAlias returns the PrivateKey's alias given a KeyID @@ -110,7 +114,7 @@ func (s *KeyMemoryStore) RemoveKey(name string) error { return removeKey(s, name) } -func addKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name, alias string, privKey data.PrivateKey) error { +func addKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedKeys map[string]data.PrivateKey, name, alias string, privKey data.PrivateKey) error { pemPrivKey, err := KeyToPEM(privKey) if err != nil { return err @@ -141,6 +145,7 @@ func addKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name, } } + cachedKeys[name] = privKey return s.Add(name+"_"+alias, pemPrivKey) } @@ -162,7 +167,11 @@ func getKeyAlias(s LimitedFileStore, keyID string) (string, error) { } // GetKey returns the PrivateKey given a KeyID -func getKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name string) (data.PrivateKey, error) { +func getKey(s LimitedFileStore, passphraseRetriever PassphraseRetriever, cachedKeys map[string]data.PrivateKey, name string) (data.PrivateKey, error) { + cachedKey, ok := cachedKeys[name] + if ok { + return cachedKey, nil + } keyAlias, err := getKeyAlias(s, name) if err != nil { return nil, err @@ -195,6 +204,7 @@ func getKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name s } } } + cachedKeys[name] = privKey return privKey, nil } diff --git a/trustmanager/keyfilestore_test.go b/trustmanager/keyfilestore_test.go index a04f026183..941ef09154 100644 --- a/trustmanager/keyfilestore_test.go +++ b/trustmanager/keyfilestore_test.go @@ -9,6 +9,7 @@ import ( "path/filepath" "strings" "testing" + "github.com/docker/notary/Godeps/_workspace/src/github.com/stretchr/testify/assert" ) var passphraseRetriever = func(keyID string, alias string, createNew bool, numAttempts int) (string, bool, error) { @@ -364,3 +365,79 @@ func TestRemoveKey(t *testing.T) { t.Fatalf("file should not exist %s", expectedFilePath) } } + +func TestKeysAreCached(t *testing.T) { + testName := "docker.com/notary/root" + testAlias := "alias" + + // Temporary directory where test files will be created + tempBaseDir, err := ioutil.TempDir("", "notary-test-") + if err != nil { + t.Fatalf("failed to create a temporary directory: %v", err) + } + defer os.RemoveAll(tempBaseDir) + + + var countingPassphraseRetriever PassphraseRetriever + + numTimesCalled := 0 + countingPassphraseRetriever = func(keyId, alias string, createNew bool, attempts int) (passphrase string, giveup bool, err error) { + numTimesCalled++ + return "password", false, nil + } + + // Create our store + store, err := NewKeyFileStore(tempBaseDir, countingPassphraseRetriever) + if err != nil { + t.Fatalf("failed to create new key filestore: %v", err) + } + + privKey, err := GenerateRSAKey(rand.Reader, 512) + if err != nil { + t.Fatalf("could not generate private key: %v", err) + } + + // Call the AddKey function + err = store.AddKey(testName, testAlias, privKey) + if err != nil { + t.Fatalf("failed to add file to store: %v", err) + } + + assert.Equal(t, 1, numTimesCalled, "numTimesCalled should have been 1") + + // Call the AddKey function + privKey2, err := store.GetKey(testName) + if err != nil { + t.Fatalf("failed to add file to store: %v", err) + } + + assert.Equal(t, privKey.Public(), privKey2.Public(), "cachedPrivKey should be the same as the added privKey") + assert.Equal(t, privKey.Private(), privKey2.Private(), "cachedPrivKey should be the same as the added privKey") + assert.Equal(t, 1, numTimesCalled, "numTimesCalled should be 1 -- no additional call to passphraseRetriever") + + + // Create a new store + store2, err := NewKeyFileStore(tempBaseDir, countingPassphraseRetriever) + if err != nil { + t.Fatalf("failed to create new key filestore: %v", err) + } + + // Call the AddKey function + privKey3, err := store2.GetKey(testName) + if err != nil { + t.Fatalf("failed to add file to store: %v", err) + } + + assert.Equal(t, privKey2.Private(), privKey3.Private(), "privkey from store1 should be the same as privkey from store2") + assert.Equal(t, privKey2.Public(), privKey3.Public(), "privkey from store1 should be the same as privkey from store2") + assert.Equal(t, 2, numTimesCalled, "numTimesCalled should be 2 -- one additional call to passphraseRetriever") + + // Call the GetKey function a bunch of times + for i := 0; i < 10; i++ { + _, err := store2.GetKey(testName) + if err != nil { + t.Fatalf("failed to add file to store: %v", err) + } + } + assert.Equal(t, 2, numTimesCalled, "numTimesCalled should be 2 -- no additional call to passphraseRetriever") +}