From 8768c12901f9688752b9b5732652151660fb2c4b Mon Sep 17 00:00:00 2001 From: Ying Li Date: Thu, 25 Feb 2016 16:55:19 -0800 Subject: [PATCH 1/7] Return the creation date for GetChecksum and GetCurrent from the server database store. Signed-off-by: Ying Li --- cmd/notary/main_test.go | 5 +++-- server/handlers/default_test.go | 5 +++-- server/handlers/roles.go | 4 ++-- server/handlers/validation.go | 8 +++---- server/snapshot/snapshot.go | 2 +- server/storage/database.go | 18 +++++++++------- server/storage/database_test.go | 15 +++++++++---- server/storage/interface.go | 18 +++++++++------- server/storage/memory.go | 37 ++++++++++++++++++--------------- server/storage/memory_test.go | 6 +++--- server/timestamp/timestamp.go | 2 +- 11 files changed, 68 insertions(+), 52 deletions(-) diff --git a/cmd/notary/main_test.go b/cmd/notary/main_test.go index 296d093d63..978e1d464f 100644 --- a/cmd/notary/main_test.go +++ b/cmd/notary/main_test.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/docker/go-connections/tlsconfig" "github.com/docker/notary/passphrase" @@ -209,14 +210,14 @@ type recordingMetaStore struct { // GetCurrent gets the metadata from the underlying MetaStore, but also records // that the metadata was requested -func (r *recordingMetaStore) GetCurrent(gun, role string) (data []byte, err error) { +func (r *recordingMetaStore) GetCurrent(gun, role string) (*time.Time, []byte, error) { r.gotten = append(r.gotten, fmt.Sprintf("%s.%s", gun, role)) return r.MemStorage.GetCurrent(gun, role) } // GetChecksum gets the metadata from the underlying MetaStore, but also records // that the metadata was requested -func (r *recordingMetaStore) GetChecksum(gun, role, checksum string) (data []byte, err error) { +func (r *recordingMetaStore) GetChecksum(gun, role, checksum string) (*time.Time, []byte, error) { r.gotten = append(r.gotten, fmt.Sprintf("%s.%s", gun, role)) return r.MemStorage.GetChecksum(gun, role, checksum) } diff --git a/server/handlers/default_test.go b/server/handlers/default_test.go index abb6d61a7f..1f729028ba 100644 --- a/server/handlers/default_test.go +++ b/server/handlers/default_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "golang.org/x/net/context" @@ -354,8 +355,8 @@ type failStore struct { storage.MemStorage } -func (s *failStore) GetCurrent(_, _ string) ([]byte, error) { - return nil, fmt.Errorf("oh no! storage has failed") +func (s *failStore) GetCurrent(_, _ string) (*time.Time, []byte, error) { + return nil, nil, fmt.Errorf("oh no! storage has failed") } // a non-validation failure, such as the storage failing, will not be propagated diff --git a/server/handlers/roles.go b/server/handlers/roles.go index 15a1b3fbbe..dac0efdbfc 100644 --- a/server/handlers/roles.go +++ b/server/handlers/roles.go @@ -25,9 +25,9 @@ func getRole(ctx context.Context, w io.Writer, store storage.MetaStore, gun, rol case data.CanonicalTimestampRole, data.CanonicalSnapshotRole: return getMaybeServerSigned(ctx, w, store, gun, role) } - out, err = store.GetCurrent(gun, role) + _, out, err = store.GetCurrent(gun, role) } else { - out, err = store.GetChecksum(gun, role, checksum) + _, out, err = store.GetChecksum(gun, role, checksum) } if err != nil { diff --git a/server/handlers/validation.go b/server/handlers/validation.go index ed0b2e4235..7a8d426e55 100644 --- a/server/handlers/validation.go +++ b/server/handlers/validation.go @@ -41,7 +41,7 @@ func validateUpdate(cs signed.CryptoService, gun string, updates []storage.MetaU } var root *data.SignedRoot - oldRootJSON, err := store.GetCurrent(gun, rootRole) + _, oldRootJSON, err := store.GetCurrent(gun, rootRole) if _, ok := err.(storage.ErrNotFound); err != nil && !ok { // problem with storage. No expectation we can // write if we can't read so bail. @@ -92,7 +92,7 @@ func validateUpdate(cs signed.CryptoService, gun string, updates []storage.MetaU // At this point, root and targets must have been loaded into the repo if _, ok := roles[snapshotRole]; ok { var oldSnap *data.SignedSnapshot - oldSnapJSON, err := store.GetCurrent(gun, snapshotRole) + _, oldSnapJSON, err := store.GetCurrent(gun, snapshotRole) if _, ok := err.(storage.ErrNotFound); err != nil && !ok { // problem with storage. No expectation we can // write if we can't read so bail. @@ -180,7 +180,7 @@ func loadAndValidateTargets(gun string, repo *tuf.Repo, roles map[string]storage } func loadTargetsFromStore(gun, role string, repo *tuf.Repo, store storage.MetaStore) error { - tgtJSON, err := store.GetCurrent(gun, role) + _, tgtJSON, err := store.GetCurrent(gun, role) if err != nil { return err } @@ -217,7 +217,7 @@ func generateSnapshot(gun string, repo *tuf.Repo, store storage.MetaStore) (*sto Msg: "no snapshot was included in update and server does not hold current snapshot key for repository"} } - currentJSON, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) + _, currentJSON, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) if err != nil { if _, ok := err.(storage.ErrNotFound); !ok { return nil, validation.ErrValidation{Msg: err.Error()} diff --git a/server/snapshot/snapshot.go b/server/snapshot/snapshot.go index 5b71eb68a9..9492e2fd0d 100644 --- a/server/snapshot/snapshot.go +++ b/server/snapshot/snapshot.go @@ -46,7 +46,7 @@ func GetOrCreateSnapshotKey(gun string, store storage.KeyStore, crypto signed.Cr // whatever the most recent snapshot is to create the next one, only updating // the expiry time and version. func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ([]byte, error) { - d, err := store.GetCurrent(gun, "snapshot") + _, d, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) if err != nil { return nil, err } diff --git a/server/storage/database.go b/server/storage/database.go index 5456d4a724..bd62301760 100644 --- a/server/storage/database.go +++ b/server/storage/database.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/hex" "fmt" + "time" "github.com/Sirupsen/logrus" "github.com/go-sql-driver/mysql" @@ -117,16 +118,17 @@ func (db *SQLStorage) UpdateMany(gun string, updates []MetaUpdate) error { } // GetCurrent gets a specific TUF record -func (db *SQLStorage) GetCurrent(gun, tufRole string) ([]byte, error) { +func (db *SQLStorage) GetCurrent(gun, tufRole string) (*time.Time, []byte, error) { var row TUFFile - q := db.Select("data").Where(&TUFFile{Gun: gun, Role: tufRole}).Order("version desc").Limit(1).First(&row) + q := db.Select("created_at, data").Where( + &TUFFile{Gun: gun, Role: tufRole}).Order("version desc").Limit(1).First(&row) return returnRead(q, row) } // GetChecksum gets a specific TUF record by its hex checksum -func (db *SQLStorage) GetChecksum(gun, tufRole, checksum string) ([]byte, error) { +func (db *SQLStorage) GetChecksum(gun, tufRole, checksum string) (*time.Time, []byte, error) { var row TUFFile - q := db.Select("data").Where( + q := db.Select("created_at, data").Where( &TUFFile{ Gun: gun, Role: tufRole, @@ -136,13 +138,13 @@ func (db *SQLStorage) GetChecksum(gun, tufRole, checksum string) ([]byte, error) return returnRead(q, row) } -func returnRead(q *gorm.DB, row TUFFile) ([]byte, error) { +func returnRead(q *gorm.DB, row TUFFile) (*time.Time, []byte, error) { if q.RecordNotFound() { - return nil, ErrNotFound{} + return nil, nil, ErrNotFound{} } else if q.Error != nil { - return nil, q.Error + return nil, nil, q.Error } - return row.Data, nil + return &(row.CreatedAt), row.Data, nil } // Delete deletes all the records for a specific GUN diff --git a/server/storage/database_test.go b/server/storage/database_test.go index 37bd2bdc4d..6991950438 100644 --- a/server/storage/database_test.go +++ b/server/storage/database_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "testing" + "time" "github.com/docker/notary/tuf/data" "github.com/jinzhu/gorm" @@ -231,7 +232,7 @@ func TestSQLGetCurrent(t *testing.T) { gormDB, dbStore := SetUpSQLite(t, tempBaseDir) defer os.RemoveAll(tempBaseDir) - byt, err := dbStore.GetCurrent("testGUN", "root") + _, byt, err := dbStore.GetCurrent("testGUN", "root") require.Nil(t, byt) require.Error(t, err, "There should be an error Getting an empty table") require.IsType(t, ErrNotFound{}, err, "Should get a not found error") @@ -240,9 +241,12 @@ func TestSQLGetCurrent(t *testing.T) { query := gormDB.Create(&tuf) require.NoError(t, query.Error, "Creating a row in an empty DB failed.") - byt, err = dbStore.GetCurrent("testGUN", "root") + cDate, byt, err := dbStore.GetCurrent("testGUN", "root") require.NoError(t, err, "There should not be any errors getting.") require.Equal(t, []byte("1"), byt, "Returned data was incorrect") + // the creation date was sometime wthin the last minute + require.True(t, cDate.After(time.Now().Add(-1*time.Minute))) + require.True(t, cDate.Before(time.Now().Add(5*time.Second))) dbStore.DB.Close() } @@ -487,9 +491,12 @@ func TestDBGetChecksum(t *testing.T) { store.UpdateCurrent("gun", update) - data, err := store.GetChecksum("gun", data.CanonicalTimestampRole, checksum) + cDate, data, err := store.GetChecksum("gun", data.CanonicalTimestampRole, checksum) require.NoError(t, err) require.EqualValues(t, j, data) + // the creation date was sometime wthin the last minute + require.True(t, cDate.After(time.Now().Add(-1*time.Minute))) + require.True(t, cDate.Before(time.Now().Add(5*time.Second))) } func TestDBGetChecksumNotFound(t *testing.T) { @@ -497,7 +504,7 @@ func TestDBGetChecksumNotFound(t *testing.T) { _, store := SetUpSQLite(t, tempBaseDir) defer os.RemoveAll(tempBaseDir) - _, err = store.GetChecksum("gun", data.CanonicalTimestampRole, "12345") + _, _, err = store.GetChecksum("gun", data.CanonicalTimestampRole, "12345") require.Error(t, err) require.IsType(t, ErrNotFound{}, err) } diff --git a/server/storage/interface.go b/server/storage/interface.go index edc8ace153..f5d7691de2 100644 --- a/server/storage/interface.go +++ b/server/storage/interface.go @@ -1,5 +1,7 @@ package storage +import "time" + // KeyStore provides a minimal interface for managing key persistence type KeyStore interface { // GetKey returns the algorithm and public key for the given GUN and role. @@ -24,15 +26,15 @@ type MetaStore interface { // none of the metadata is added, and an error is be returned. UpdateMany(gun string, updates []MetaUpdate) error - // GetCurrent returns the data part of the metadata for the latest version - // of the given GUN and role. If there is no data for the given GUN and - // role, an error is returned. - GetCurrent(gun, tufRole string) (data []byte, err error) + // GetCurrent returns the creation date and data part of the metadata for + // the latest version of the given GUN and role. If there is no data for + // the given GUN and role, an error is returned. + GetCurrent(gun, tufRole string) (created *time.Time, data []byte, err error) - // GetChecksum return the given tuf role file for the GUN with the - // provided checksum. If the given (gun, role, checksum) are not - // found, it returns storage.ErrNotFound - GetChecksum(gun, tufRole, checksum string) (data []byte, err error) + // GetChecksum returns the given TUF role file and creation date for the + // GUN with the provided checksum. If the given (gun, role, checksum) are + // not found, it returns storage.ErrNotFound + GetChecksum(gun, tufRole, checksum string) (created *time.Time, data []byte, err error) // Delete removes all metadata for a given GUN. It does not return an // error if no metadata exists for the given GUN. diff --git a/server/storage/memory.go b/server/storage/memory.go index b46bd02302..373902b4c7 100644 --- a/server/storage/memory.go +++ b/server/storage/memory.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "sync" + "time" ) type key struct { @@ -14,8 +15,9 @@ type key struct { } type ver struct { - version int - data []byte + version int + data []byte + creation time.Time } // MemStorage is really just designed for dev and testing. It is very @@ -24,7 +26,7 @@ type MemStorage struct { lock sync.Mutex tufMeta map[string][]*ver keys map[string]map[string]*key - checksums map[string]map[string][]byte + checksums map[string]map[string]ver } // NewMemStorage instantiates a memStorage instance @@ -32,7 +34,7 @@ func NewMemStorage() *MemStorage { return &MemStorage{ tufMeta: make(map[string][]*ver), keys: make(map[string]map[string]*key), - checksums: make(map[string]map[string][]byte), + checksums: make(map[string]map[string]ver), } } @@ -48,15 +50,16 @@ func (st *MemStorage) UpdateCurrent(gun string, update MetaUpdate) error { } } } - st.tufMeta[id] = append(st.tufMeta[id], &ver{version: update.Version, data: update.Data}) + version := ver{version: update.Version, data: update.Data, creation: time.Now()} + st.tufMeta[id] = append(st.tufMeta[id], &version) checksumBytes := sha256.Sum256(update.Data) checksum := hex.EncodeToString(checksumBytes[:]) _, ok := st.checksums[gun] if !ok { - st.checksums[gun] = make(map[string][]byte) + st.checksums[gun] = make(map[string]ver) } - st.checksums[gun][checksum] = update.Data + st.checksums[gun][checksum] = version return nil } @@ -68,27 +71,27 @@ func (st *MemStorage) UpdateMany(gun string, updates []MetaUpdate) error { return nil } -// GetCurrent returns the metadata for a given role, under a GUN -func (st *MemStorage) GetCurrent(gun, role string) (data []byte, err error) { +// GetCurrent returns the creation date metadata for a given role, under a GUN. +func (st *MemStorage) GetCurrent(gun, role string) (*time.Time, []byte, error) { id := entryKey(gun, role) st.lock.Lock() defer st.lock.Unlock() space, ok := st.tufMeta[id] if !ok || len(space) == 0 { - return nil, ErrNotFound{} + return nil, nil, ErrNotFound{} } - return space[len(space)-1].data, nil + return &(space[len(space)-1].creation), space[len(space)-1].data, nil } -// GetChecksum returns the metadata for a given role, under a GUN -func (st *MemStorage) GetChecksum(gun, role, checksum string) (data []byte, err error) { +// GetChecksum returns the creation date and metadata for a given role, under a GUN. +func (st *MemStorage) GetChecksum(gun, role, checksum string) (*time.Time, []byte, error) { st.lock.Lock() defer st.lock.Unlock() - data, ok := st.checksums[gun][checksum] - if !ok || len(data) == 0 { - return nil, ErrNotFound{} + space, ok := st.checksums[gun][checksum] + if !ok || len(space.data) == 0 { + return nil, nil, ErrNotFound{} } - return data, nil + return &(space.creation), space.data, nil } // Delete deletes all the metadata for a given GUN diff --git a/server/storage/memory_test.go b/server/storage/memory_test.go index a91f527e56..9e740a7077 100644 --- a/server/storage/memory_test.go +++ b/server/storage/memory_test.go @@ -22,11 +22,11 @@ func TestUpdateCurrent(t *testing.T) { func TestGetCurrent(t *testing.T) { s := NewMemStorage() - _, err := s.GetCurrent("gun", "role") + _, _, err := s.GetCurrent("gun", "role") assert.IsType(t, ErrNotFound{}, err, "Expected error to be ErrNotFound") s.UpdateCurrent("gun", MetaUpdate{"role", 1, []byte("test")}) - d, err := s.GetCurrent("gun", "role") + _, d, err := s.GetCurrent("gun", "role") assert.Nil(t, err, "Expected error to be nil") assert.Equal(t, []byte("test"), d, "Data was incorrect") } @@ -97,7 +97,7 @@ func TestSetKeySameRoleGun(t *testing.T) { func TestGetChecksumNotFound(t *testing.T) { s := NewMemStorage() - _, err := s.GetChecksum("gun", "root", "12345") + _, _, err := s.GetChecksum("gun", "root", "12345") assert.Error(t, err) assert.IsType(t, ErrNotFound{}, err) } diff --git a/server/timestamp/timestamp.go b/server/timestamp/timestamp.go index 2cf3628f5c..645c27eea4 100644 --- a/server/timestamp/timestamp.go +++ b/server/timestamp/timestamp.go @@ -52,7 +52,7 @@ func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService sig if err != nil { return nil, err } - d, err := store.GetCurrent(gun, "timestamp") + _, d, err := store.GetCurrent(gun, "timestamp") if err != nil { if _, ok := err.(storage.ErrNotFound); !ok { logrus.Error("error retrieving timestamp: ", err.Error()) From 802673fc9dbbce3c938773c43ffd51130ea838e3 Mon Sep 17 00:00:00 2001 From: Ying Li Date: Fri, 26 Feb 2016 10:51:55 -0800 Subject: [PATCH 2/7] Add cache control headers to Getting metadata Signed-off-by: Ying Li --- server/handlers/default.go | 38 +++++++++++++++++- server/handlers/default_test.go | 63 +++++++++++++++++++++++++----- server/handlers/roles.go | 44 ++++++++++----------- server/handlers/roles_test.go | 8 +--- server/snapshot/snapshot.go | 22 ++++++----- server/snapshot/snapshot_test.go | 21 ++++++---- server/timestamp/timestamp.go | 27 +++++++------ server/timestamp/timestamp_test.go | 8 ++-- 8 files changed, 162 insertions(+), 69 deletions(-) diff --git a/server/handlers/default.go b/server/handlers/default.go index 96a601412c..1d43af467b 100644 --- a/server/handlers/default.go +++ b/server/handlers/default.go @@ -2,11 +2,16 @@ package handlers import ( "bytes" + "crypto/sha256" + "encoding/hex" "encoding/json" + "fmt" "io" "net/http" "strings" + "time" + "github.com/Sirupsen/logrus" ctxu "github.com/docker/distribution/context" "github.com/gorilla/mux" "golang.org/x/net/context" @@ -20,6 +25,14 @@ import ( "github.com/docker/notary/tuf/validation" ) +const ( + // ConsistentCacheMaxAge is the Cache-Control header's max age for consistent downloads + ConsistentCacheMaxAge int = 30 * 24 * 60 * 60 // 30 days + + // NonConsistentCacheMaxAge is the Cache-Control header's max age for current (non-consistent) downloads + NonConsistentCacheMaxAge int = 5 * 60 // five minutes +) + // MainHandler is the default handler for the server func MainHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) error { // For now it only supports `GET` @@ -122,7 +135,30 @@ func getHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, var return errors.ErrNoStorage.WithDetail(nil) } - return getRole(ctx, w, store, gun, tufRole, checksum) + creation, output, err := getRole(ctx, store, gun, tufRole, checksum) + if err != nil { + return err + } + if creation == nil { + // This shouldn't ever happen, but if it does, it just messes up the cache headers, so + // proceed anyway + logrus.Warnf("Got bytes out for %s's %s (checksum: %s), but missing creation date", + gun, tufRole, checksum) + creation = &time.Time{} // set the last modification date to the beginning of time + } + + maxAge := ConsistentCacheMaxAge + if checksum == "" { // maxAge should be much shorter for metadata without checksums + maxAge = NonConsistentCacheMaxAge + shasum := sha256.Sum256(output) + checksum = hex.EncodeToString(shasum[:]) + } + + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%v", maxAge)) + w.Header().Set("Last-Modified", creation.Format(time.RFC1123)) + w.Header().Set("ETag", checksum) + w.Write(output) + return nil } // DeleteHandler deletes all data for a GUN. A 200 responses indicates success. diff --git a/server/handlers/default_test.go b/server/handlers/default_test.go index 1f729028ba..5b794e9a1f 100644 --- a/server/handlers/default_test.go +++ b/server/handlers/default_test.go @@ -2,6 +2,8 @@ package handlers import ( "bytes" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io/ioutil" @@ -172,6 +174,28 @@ func TestGetKeyHandlerCreatesOnce(t *testing.T) { } } +// Verifies that the body is as expected, the ETag is as expected, +func verifyGetResponse(t *testing.T, rw *httptest.ResponseRecorder, expectedBytes []byte, + getWithChecksum bool, checksumHex string) { + + body, err := ioutil.ReadAll(rw.Body) + assert.NoError(t, err) + assert.True(t, bytes.Equal(expectedBytes, body)) + + lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) + assert.NoError(t, err) + assert.True(t, lastModified.After(time.Now().Add(-5*time.Minute))) + + cacheControl := rw.HeaderMap.Get("Cache-Control") + maxAge := NonConsistentCacheMaxAge + if getWithChecksum { + maxAge = ConsistentCacheMaxAge + } + assert.Equal(t, fmt.Sprintf("public, max-age=%v", maxAge), cacheControl) + + assert.Equal(t, rw.HeaderMap.Get("ETag"), checksumHex) +} + func TestGetHandlerRoot(t *testing.T) { metaStore := storage.NewMemStorage() repo, _, err := testutils.EmptyRepo("gun") @@ -194,10 +218,17 @@ func TestGetHandlerRoot(t *testing.T) { "tufRole": "root", } - rw := httptest.NewRecorder() + checksumBytes := sha256.Sum256(rootJSON) + checksumHex := hex.EncodeToString(checksumBytes[:]) - err = getHandler(ctx, rw, req, vars) - assert.NoError(t, err) + rw := httptest.NewRecorder() + assert.NoError(t, getHandler(ctx, rw, req, vars)) + verifyGetResponse(t, rw, rootJSON, false, checksumHex) + + vars["checksum"] = checksumHex + rw = httptest.NewRecorder() + assert.NoError(t, getHandler(ctx, rw, req, vars)) + verifyGetResponse(t, rw, rootJSON, true, checksumHex) } func TestGetHandlerTimestamp(t *testing.T) { @@ -228,10 +259,17 @@ func TestGetHandlerTimestamp(t *testing.T) { "tufRole": "timestamp", } - rw := httptest.NewRecorder() + checksumBytes := sha256.Sum256(tsJSON) + checksumHex := hex.EncodeToString(checksumBytes[:]) - err = getHandler(ctx, rw, req, vars) - assert.NoError(t, err) + rw := httptest.NewRecorder() + assert.NoError(t, getHandler(ctx, rw, req, vars)) + verifyGetResponse(t, rw, tsJSON, false, checksumHex) + + vars["checksum"] = checksumHex + rw = httptest.NewRecorder() + assert.NoError(t, getHandler(ctx, rw, req, vars)) + verifyGetResponse(t, rw, tsJSON, true, checksumHex) } func TestGetHandlerSnapshot(t *testing.T) { @@ -256,10 +294,17 @@ func TestGetHandlerSnapshot(t *testing.T) { "tufRole": "snapshot", } - rw := httptest.NewRecorder() + checksumBytes := sha256.Sum256(snJSON) + checksumHex := hex.EncodeToString(checksumBytes[:]) - err = getHandler(ctx, rw, req, vars) - assert.NoError(t, err) + rw := httptest.NewRecorder() + assert.NoError(t, getHandler(ctx, rw, req, vars)) + verifyGetResponse(t, rw, snJSON, false, checksumHex) + + vars["checksum"] = checksumHex + rw = httptest.NewRecorder() + assert.NoError(t, getHandler(ctx, rw, req, vars)) + verifyGetResponse(t, rw, snJSON, true, checksumHex) } func TestGetHandler404(t *testing.T) { diff --git a/server/handlers/roles.go b/server/handlers/roles.go index dac0efdbfc..415cf2bb85 100644 --- a/server/handlers/roles.go +++ b/server/handlers/roles.go @@ -1,7 +1,7 @@ package handlers import ( - "io" + "time" "golang.org/x/net/context" @@ -13,35 +13,35 @@ import ( "github.com/docker/notary/tuf/signed" ) -func getRole(ctx context.Context, w io.Writer, store storage.MetaStore, gun, role, checksum string) error { +func getRole(ctx context.Context, store storage.MetaStore, gun, role, checksum string) (*time.Time, []byte, error) { var ( - out []byte - err error + creation *time.Time + out []byte + err error ) if checksum == "" { // the timestamp and snapshot might be server signed so are // handled specially switch role { case data.CanonicalTimestampRole, data.CanonicalSnapshotRole: - return getMaybeServerSigned(ctx, w, store, gun, role) + return getMaybeServerSigned(ctx, store, gun, role) } - _, out, err = store.GetCurrent(gun, role) + creation, out, err = store.GetCurrent(gun, role) } else { - _, out, err = store.GetChecksum(gun, role, checksum) + creation, out, err = store.GetChecksum(gun, role, checksum) } if err != nil { if _, ok := err.(storage.ErrNotFound); ok { - return errors.ErrMetadataNotFound.WithDetail(err) + return nil, nil, errors.ErrMetadataNotFound.WithDetail(err) } - return errors.ErrUnknown.WithDetail(err) + return nil, nil, errors.ErrUnknown.WithDetail(err) } if out == nil { - return errors.ErrMetadataNotFound.WithDetail(nil) + return nil, nil, errors.ErrMetadataNotFound.WithDetail(nil) } - w.Write(out) - return nil + return creation, out, nil } // getMaybeServerSigned writes the current snapshot or timestamp (based on the @@ -49,32 +49,32 @@ func getRole(ctx context.Context, w io.Writer, store storage.MetaStore, gun, rol // the timestamp and snapshot, based on the keys held by the server, a new one // might be generated and signed due to expiry of the previous one or updates // to other roles. -func getMaybeServerSigned(ctx context.Context, w io.Writer, store storage.MetaStore, gun, role string) error { +func getMaybeServerSigned(ctx context.Context, store storage.MetaStore, gun, role string) (*time.Time, []byte, error) { cryptoServiceVal := ctx.Value("cryptoService") cryptoService, ok := cryptoServiceVal.(signed.CryptoService) if !ok { - return errors.ErrNoCryptoService.WithDetail(nil) + return nil, nil, errors.ErrNoCryptoService.WithDetail(nil) } var ( - out []byte - err error + creation *time.Time + out []byte + err error ) switch role { case data.CanonicalSnapshotRole: - out, err = snapshot.GetOrCreateSnapshot(gun, store, cryptoService) + creation, out, err = snapshot.GetOrCreateSnapshot(gun, store, cryptoService) case data.CanonicalTimestampRole: - out, err = timestamp.GetOrCreateTimestamp(gun, store, cryptoService) + creation, out, err = timestamp.GetOrCreateTimestamp(gun, store, cryptoService) } if err != nil { switch err.(type) { case *storage.ErrNoKey, storage.ErrNotFound: - return errors.ErrMetadataNotFound.WithDetail(err) + return nil, nil, errors.ErrMetadataNotFound.WithDetail(err) default: - return errors.ErrUnknown.WithDetail(err) + return nil, nil, errors.ErrUnknown.WithDetail(err) } } - w.Write(out) - return nil + return creation, out, nil } diff --git a/server/handlers/roles_test.go b/server/handlers/roles_test.go index ce6df0d4c5..3c12d292d7 100644 --- a/server/handlers/roles_test.go +++ b/server/handlers/roles_test.go @@ -14,10 +14,7 @@ import ( ) func TestGetMaybeServerSignedNoCrypto(t *testing.T) { - err := getMaybeServerSigned( - context.Background(), - nil, nil, "", "", - ) + _, _, err := getMaybeServerSigned(context.Background(), nil, "", "") require.Error(t, err) require.IsType(t, errcode.Error{}, err) @@ -33,9 +30,8 @@ func TestGetMaybeServerSignedNoKey(t *testing.T) { ctx = context.WithValue(ctx, "cryptoService", crypto) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) - err := getMaybeServerSigned( + _, _, err := getMaybeServerSigned( ctx, - nil, store, "gun", data.CanonicalTimestampRole, diff --git a/server/snapshot/snapshot.go b/server/snapshot/snapshot.go index 9492e2fd0d..be31c6fc8c 100644 --- a/server/snapshot/snapshot.go +++ b/server/snapshot/snapshot.go @@ -2,6 +2,7 @@ package snapshot import ( "encoding/json" + "time" "github.com/Sirupsen/logrus" @@ -45,10 +46,12 @@ func GetOrCreateSnapshotKey(gun string, store storage.KeyStore, crypto signed.Cr // GetOrCreateSnapshot either returns the exisiting latest snapshot, or uses // whatever the most recent snapshot is to create the next one, only updating // the expiry time and version. -func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ([]byte, error) { - _, d, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) +func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ( + *time.Time, []byte, error) { + + creation, d, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) if err != nil { - return nil, err + return nil, nil, err } sn := &data.SignedSnapshot{} @@ -56,29 +59,30 @@ func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService sign err := json.Unmarshal(d, sn) if err != nil { logrus.Error("Failed to unmarshal existing snapshot") - return nil, err + return nil, nil, err } if !snapshotExpired(sn) { - return d, nil + return creation, d, nil } } sgnd, version, err := createSnapshot(gun, sn, store, cryptoService) if err != nil { logrus.Error("Failed to create a new snapshot") - return nil, err + return nil, nil, err } out, err := json.Marshal(sgnd) if err != nil { logrus.Error("Failed to marshal new snapshot") - return nil, err + return nil, nil, err } err = store.UpdateCurrent(gun, storage.MetaUpdate{Role: "snapshot", Version: version, Data: out}) if err != nil { - return nil, err + return nil, nil, err } - return out, nil + c := time.Now() + return &c, out, nil } // snapshotExpired simply checks if the snapshot is past its expiry time diff --git a/server/snapshot/snapshot_test.go b/server/snapshot/snapshot_test.go index b2226b5c8f..6be143ae32 100644 --- a/server/snapshot/snapshot_test.go +++ b/server/snapshot/snapshot_test.go @@ -118,7 +118,7 @@ func TestGetSnapshotNotExists(t *testing.T) { store := storage.NewMemStorage() crypto := signed.NewEd25519() - _, err := GetOrCreateSnapshot("gun", store, crypto) + _, _, err := GetOrCreateSnapshot("gun", store, crypto) assert.Error(t, err) } @@ -144,18 +144,23 @@ func TestGetSnapshotCurrValid(t *testing.T) { // test when db is missing the role data store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON}) - _, err = GetOrCreateSnapshot("gun", store, crypto) + c1, result, err := GetOrCreateSnapshot("gun", store, crypto) assert.NoError(t, err) + assert.True(t, bytes.Equal(snapJSON, result)) // test when db has the role data store.UpdateCurrent("gun", storage.MetaUpdate{Role: "root", Version: 0, Data: newData}) - _, err = GetOrCreateSnapshot("gun", store, crypto) + c2, result, err := GetOrCreateSnapshot("gun", store, crypto) assert.NoError(t, err) + assert.True(t, bytes.Equal(snapJSON, result)) + assert.True(t, c1.Equal(*c2)) - // test when db role data is expired + // test when db role data is corrupt store.UpdateCurrent("gun", storage.MetaUpdate{Role: "root", Version: 1, Data: []byte{3}}) - _, err = GetOrCreateSnapshot("gun", store, crypto) + c2, result, err = GetOrCreateSnapshot("gun", store, crypto) assert.NoError(t, err) + assert.True(t, bytes.Equal(snapJSON, result)) + assert.True(t, c1.Equal(*c2)) } func TestGetSnapshotCurrExpired(t *testing.T) { @@ -168,8 +173,10 @@ func TestGetSnapshotCurrExpired(t *testing.T) { snapJSON, _ := json.Marshal(snapshot) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON}) - _, err = GetOrCreateSnapshot("gun", store, crypto) + c1, newJSON, err := GetOrCreateSnapshot("gun", store, crypto) assert.NoError(t, err) + assert.False(t, bytes.Equal(snapJSON, newJSON)) + assert.True(t, c1.After(time.Now().Add(-1*time.Minute))) } func TestGetSnapshotCurrCorrupt(t *testing.T) { @@ -182,7 +189,7 @@ func TestGetSnapshotCurrCorrupt(t *testing.T) { snapJSON, _ := json.Marshal(snapshot) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON[1:]}) - _, err = GetOrCreateSnapshot("gun", store, crypto) + _, _, err = GetOrCreateSnapshot("gun", store, crypto) assert.Error(t, err) } diff --git a/server/timestamp/timestamp.go b/server/timestamp/timestamp.go index 645c27eea4..5b1555deea 100644 --- a/server/timestamp/timestamp.go +++ b/server/timestamp/timestamp.go @@ -1,6 +1,8 @@ package timestamp import ( + "time" + "github.com/docker/go/canonical/json" "github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/signed" @@ -47,16 +49,18 @@ func GetOrCreateTimestampKey(gun string, store storage.MetaStore, crypto signed. // GetOrCreateTimestamp returns the current timestamp for the gun. This may mean // a new timestamp is generated either because none exists, or because the current // one has expired. Once generated, the timestamp is saved in the store. -func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ([]byte, error) { - snapshot, err := snapshot.GetOrCreateSnapshot(gun, store, cryptoService) +func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ( + *time.Time, []byte, error) { + + _, snapshot, err := snapshot.GetOrCreateSnapshot(gun, store, cryptoService) if err != nil { - return nil, err + return nil, nil, err } - _, d, err := store.GetCurrent(gun, "timestamp") + creation, d, err := store.GetCurrent(gun, data.CanonicalTimestampRole) if err != nil { if _, ok := err.(storage.ErrNotFound); !ok { logrus.Error("error retrieving timestamp: ", err.Error()) - return nil, err + return nil, nil, err } logrus.Debug("No timestamp found, will proceed to create first timestamp") } @@ -65,27 +69,28 @@ func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService sig err := json.Unmarshal(d, ts) if err != nil { logrus.Error("Failed to unmarshal existing timestamp") - return nil, err + return nil, nil, err } if !timestampExpired(ts) && !snapshotExpired(ts, snapshot) { - return d, nil + return creation, d, nil } } sgnd, version, err := CreateTimestamp(gun, ts, snapshot, store, cryptoService) if err != nil { logrus.Error("Failed to create a new timestamp") - return nil, err + return nil, nil, err } out, err := json.Marshal(sgnd) if err != nil { logrus.Error("Failed to marshal new timestamp") - return nil, err + return nil, nil, err } err = store.UpdateCurrent(gun, storage.MetaUpdate{Role: "timestamp", Version: version, Data: out}) if err != nil { - return nil, err + return nil, nil, err } - return out, nil + c := time.Now() + return &c, out, nil } // timestampExpired compares the current time to the expiry time of the timestamp diff --git a/server/timestamp/timestamp_test.go b/server/timestamp/timestamp_test.go index fe67456e10..2ec28b59b6 100644 --- a/server/timestamp/timestamp_test.go +++ b/server/timestamp/timestamp_test.go @@ -64,7 +64,7 @@ func TestGetTimestamp(t *testing.T) { _, err := GetOrCreateTimestampKey("gun", store, crypto, data.ED25519Key) assert.Nil(t, err, "GetKey errored") - _, err = GetOrCreateTimestamp("gun", store, crypto) + _, _, err = GetOrCreateTimestamp("gun", store, crypto) assert.Nil(t, err, "GetTimestamp errored") } @@ -85,7 +85,7 @@ func TestGetTimestampNewSnapshot(t *testing.T) { _, err := GetOrCreateTimestampKey("gun", store, crypto, data.ED25519Key) assert.Nil(t, err, "GetKey errored") - ts1, err := GetOrCreateTimestamp("gun", store, crypto) + c1, ts1, err := GetOrCreateTimestamp("gun", store, crypto) assert.Nil(t, err, "GetTimestamp errored") snapshot = &data.SignedSnapshot{ @@ -98,8 +98,8 @@ func TestGetTimestampNewSnapshot(t *testing.T) { store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 1, Data: snapJSON}) - ts2, err := GetOrCreateTimestamp("gun", store, crypto) + c2, ts2, err := GetOrCreateTimestamp("gun", store, crypto) assert.NoError(t, err, "GetTimestamp errored") - assert.NotEqual(t, ts1, ts2, "Timestamp was not regenerated when snapshot changed") + assert.True(t, c1.Before(*c2), "Timestamp creation time incorrect") } From 9b022a9cdaf60bc248edf44318089efc0157dcaf Mon Sep 17 00:00:00 2001 From: Ying Li Date: Fri, 26 Feb 2016 16:45:12 -0800 Subject: [PATCH 3/7] Modify server handler to set cache headers based upon a cache configuration object Signed-off-by: Ying Li --- server/handlers/default.go | 95 +++++++++++++++++++++++++++------ server/handlers/default_test.go | 72 ++++++++++++++++++------- 2 files changed, 133 insertions(+), 34 deletions(-) diff --git a/server/handlers/default.go b/server/handlers/default.go index 1d43af467b..025dcb6bb0 100644 --- a/server/handlers/default.go +++ b/server/handlers/default.go @@ -25,13 +25,68 @@ import ( "github.com/docker/notary/tuf/validation" ) -const ( - // ConsistentCacheMaxAge is the Cache-Control header's max age for consistent downloads - ConsistentCacheMaxAge int = 30 * 24 * 60 * 60 // 30 days +// NewCacheControlConfig creates a new configuration for Cache-Control headers, +// which by default, sets cache max-age values for consistent (by checksum) +// downloads 30 days and non-consistent (current) downloads to 5 minutes +func NewCacheControlConfig() *CacheControlConfig { + return &CacheControlConfig{ + headerVals: map[string]int{ + "consistent": 30 * 24 * 60 * 60, // 30 days + "current": 5 * 60, // 5 minutes + }, + } +} - // NonConsistentCacheMaxAge is the Cache-Control header's max age for current (non-consistent) downloads - NonConsistentCacheMaxAge int = 5 * 60 // five minutes -) +// CacheControlConfig is the configuration for the max cache age for +// cache control headers. +type CacheControlConfig struct { + headerVals map[string]int +} + +// SetConsistentCacheMaxAge sets the Cache-Control header value for consistent +// downloads +func (c *CacheControlConfig) SetConsistentCacheMaxAge(value int) { + c.headerVals["consistent"] = value +} + +// SetCurrentCacheMaxAge sets the Cache-Control header value for current +// (non-consistent) downloads +func (c *CacheControlConfig) SetCurrentCacheMaxAge(value int) { + c.headerVals["current"] = value +} + +// UpdateConsistentHeaders updates the given Headers object with the Cache-Control +// headers for consistent downloads +func (c *CacheControlConfig) UpdateConsistentHeaders(headers http.Header, lastModified time.Time) { + c.updateHeaders(headers, lastModified, true) +} + +// UpdateCurrentHeaders updates the given Headers object with th eCache-Control +// headers for current (non-consistent) downloads +func (c *CacheControlConfig) UpdateCurrentHeaders(headers http.Header, lastModified time.Time) { + c.updateHeaders(headers, lastModified, false) +} + +func (c *CacheControlConfig) updateHeaders(headers http.Header, lastModified time.Time, consistent bool) { + var seconds int + var cacheHeader string + + if consistent { + seconds = c.headerVals["consistent"] + cacheHeader = fmt.Sprintf("public, max-age=%v, s-maxage=%v, must-revalidate", seconds, seconds) + } else { + seconds = c.headerVals["current"] + cacheHeader = fmt.Sprintf("public, max-age=%v, s-maxage=%v", seconds, seconds) + } + + if seconds > 0 { + headers.Set("Cache-Control", cacheHeader) + headers.Set("Last-Modified", lastModified.Format(time.RFC1123)) + } else { + headers.Set("Cache-Control", "max-age=0, no-cache, no-store") + headers.Set("Pragma", "no-cache") + } +} // MainHandler is the default handler for the server func MainHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) error { @@ -130,32 +185,42 @@ func getHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, var checksum := vars["checksum"] tufRole := vars["tufRole"] s := ctx.Value("metaStore") + c := ctx.Value("cacheConfig") + store, ok := s.(storage.MetaStore) if !ok { return errors.ErrNoStorage.WithDetail(nil) } - creation, output, err := getRole(ctx, store, gun, tufRole, checksum) + // If cache control headers were not provided, just use the default values + cacheConfig, ok := c.(*CacheControlConfig) + if !ok { + cacheConfig = NewCacheControlConfig() + } + + lastModified, output, err := getRole(ctx, store, gun, tufRole, checksum) if err != nil { return err } - if creation == nil { + if lastModified == nil { // This shouldn't ever happen, but if it does, it just messes up the cache headers, so // proceed anyway - logrus.Warnf("Got bytes out for %s's %s (checksum: %s), but missing creation date", + logrus.Warnf("Got bytes out for %s's %s (checksum: %s), but missing lastModified date", gun, tufRole, checksum) - creation = &time.Time{} // set the last modification date to the beginning of time + lastModified = &time.Time{} // set the last modification date to the beginning of time } - maxAge := ConsistentCacheMaxAge - if checksum == "" { // maxAge should be much shorter for metadata without checksums - maxAge = NonConsistentCacheMaxAge + switch checksum { + case "": + cacheConfig.UpdateCurrentHeaders(w.Header(), *lastModified) + shasum := sha256.Sum256(output) checksum = hex.EncodeToString(shasum[:]) + + default: + cacheConfig.UpdateConsistentHeaders(w.Header(), *lastModified) } - w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%v", maxAge)) - w.Header().Set("Last-Modified", creation.Format(time.RFC1123)) w.Header().Set("ETag", checksum) w.Write(output) return nil diff --git a/server/handlers/default_test.go b/server/handlers/default_test.go index 5b794e9a1f..e470195efd 100644 --- a/server/handlers/default_test.go +++ b/server/handlers/default_test.go @@ -174,29 +174,51 @@ func TestGetKeyHandlerCreatesOnce(t *testing.T) { } } -// Verifies that the body is as expected, the ETag is as expected, +type expectedCacheSetting int + +const ( + checksumCaching expectedCacheSetting = iota + currentCaching + noCaching +) + +// Verifies that the body is as expected, the ETag is as expected, and cache control headers +// are as expected func verifyGetResponse(t *testing.T, rw *httptest.ResponseRecorder, expectedBytes []byte, - getWithChecksum bool, checksumHex string) { + checksumHex string, cacheType expectedCacheSetting, cacheConfig *CacheControlConfig) { body, err := ioutil.ReadAll(rw.Body) assert.NoError(t, err) assert.True(t, bytes.Equal(expectedBytes, body)) - lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) - assert.NoError(t, err) - assert.True(t, lastModified.After(time.Now().Add(-5*time.Minute))) + assert.Equal(t, rw.HeaderMap.Get("ETag"), checksumHex) cacheControl := rw.HeaderMap.Get("Cache-Control") - maxAge := NonConsistentCacheMaxAge - if getWithChecksum { - maxAge = ConsistentCacheMaxAge + switch cacheType { + case checksumCaching: + maxAge := cacheConfig.headerVals["consistent"] + assert.Equal(t, fmt.Sprintf("public, max-age=%v, s-maxage=%v, must-revalidate", maxAge, maxAge), cacheControl) + case currentCaching: + maxAge := cacheConfig.headerVals["current"] + assert.Equal(t, fmt.Sprintf("public, max-age=%v, s-maxage=%v", maxAge, maxAge), cacheControl) + default: + assert.Equal(t, "max-age=0, no-cache, no-store", cacheControl) } - assert.Equal(t, fmt.Sprintf("public, max-age=%v", maxAge), cacheControl) - assert.Equal(t, rw.HeaderMap.Get("ETag"), checksumHex) + switch cacheType { + case checksumCaching, currentCaching: + lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) + assert.NoError(t, err) + assert.True(t, lastModified.After(time.Now().Add(-5*time.Minute))) + + assert.Equal(t, "", rw.HeaderMap.Get("Pragma")) + default: + assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified")) + assert.Equal(t, "no-cache", rw.HeaderMap.Get("Pragma")) + } } -func TestGetHandlerRoot(t *testing.T) { +func TestGetHandlerRootAndNoCacheConfigProvided(t *testing.T) { metaStore := storage.NewMemStorage() repo, _, err := testutils.EmptyRepo("gun") assert.NoError(t, err) @@ -221,17 +243,19 @@ func TestGetHandlerRoot(t *testing.T) { checksumBytes := sha256.Sum256(rootJSON) checksumHex := hex.EncodeToString(checksumBytes[:]) + cacheConfig := NewCacheControlConfig() + rw := httptest.NewRecorder() assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, rootJSON, false, checksumHex) + verifyGetResponse(t, rw, rootJSON, checksumHex, currentCaching, cacheConfig) vars["checksum"] = checksumHex rw = httptest.NewRecorder() assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, rootJSON, true, checksumHex) + verifyGetResponse(t, rw, rootJSON, checksumHex, checksumCaching, cacheConfig) } -func TestGetHandlerTimestamp(t *testing.T) { +func TestGetHandlerTimestampWithCacheValues(t *testing.T) { metaStore := storage.NewMemStorage() repo, crypto, err := testutils.EmptyRepo("gun") assert.NoError(t, err) @@ -259,20 +283,25 @@ func TestGetHandlerTimestamp(t *testing.T) { "tufRole": "timestamp", } + cacheConfig := NewCacheControlConfig() + cacheConfig.SetConsistentCacheMaxAge(365 * 24 * 60 * 60) + cacheConfig.SetCurrentCacheMaxAge(1) + ctx = context.WithValue(ctx, "cacheConfig", cacheConfig) + checksumBytes := sha256.Sum256(tsJSON) checksumHex := hex.EncodeToString(checksumBytes[:]) rw := httptest.NewRecorder() assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, tsJSON, false, checksumHex) + verifyGetResponse(t, rw, tsJSON, checksumHex, currentCaching, cacheConfig) vars["checksum"] = checksumHex rw = httptest.NewRecorder() assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, tsJSON, true, checksumHex) + verifyGetResponse(t, rw, tsJSON, checksumHex, checksumCaching, cacheConfig) } -func TestGetHandlerSnapshot(t *testing.T) { +func TestGetHandlerSnapshotWithNoCaching(t *testing.T) { metaStore := storage.NewMemStorage() repo, crypto, err := testutils.EmptyRepo("gun") assert.NoError(t, err) @@ -294,17 +323,22 @@ func TestGetHandlerSnapshot(t *testing.T) { "tufRole": "snapshot", } + cacheConfig := NewCacheControlConfig() + cacheConfig.SetConsistentCacheMaxAge(0) + cacheConfig.SetCurrentCacheMaxAge(-1) + ctx = context.WithValue(ctx, "cacheConfig", cacheConfig) + checksumBytes := sha256.Sum256(snJSON) checksumHex := hex.EncodeToString(checksumBytes[:]) rw := httptest.NewRecorder() assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, snJSON, false, checksumHex) + verifyGetResponse(t, rw, snJSON, checksumHex, noCaching, cacheConfig) vars["checksum"] = checksumHex rw = httptest.NewRecorder() assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, snJSON, true, checksumHex) + verifyGetResponse(t, rw, snJSON, checksumHex, noCaching, cacheConfig) } func TestGetHandler404(t *testing.T) { From 329b47d2531179322623490e8ed51d103deb00a0 Mon Sep 17 00:00:00 2001 From: Ying Li Date: Fri, 26 Feb 2016 17:11:19 -0800 Subject: [PATCH 4/7] Parse for cache control options in the server config file Signed-off-by: Ying Li --- cmd/notary-server/main.go | 32 ++++++++++++++++++++++++++++++++ cmd/notary-server/main_test.go | 23 +++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/cmd/notary-server/main.go b/cmd/notary-server/main.go index 46635356b8..ecc54678df 100644 --- a/cmd/notary-server/main.go +++ b/cmd/notary-server/main.go @@ -8,6 +8,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "strconv" "time" "github.com/Sirupsen/logrus" @@ -23,6 +24,7 @@ import ( "github.com/docker/go-connections/tlsconfig" "github.com/docker/notary/server" + "github.com/docker/notary/server/handlers" "github.com/docker/notary/utils" "github.com/docker/notary/version" "github.com/spf13/viper" @@ -32,6 +34,7 @@ import ( const ( jsonLogFormat = "json" DebugAddress = "localhost:8080" + maxMaxAge = 31536000 ) var ( @@ -170,6 +173,29 @@ func getTrustService(configuration *viper.Viper, sFactory signerFactory, return notarySigner, keyAlgo, nil } +func getCacheConfig(configuration *viper.Viper) (*handlers.CacheControlConfig, error) { + cacheConfig := handlers.NewCacheControlConfig() + for _, option := range []string{"current_metadata", "metadata_by_checksum"} { + m := configuration.GetString(fmt.Sprintf("caching.max_age.%s", option)) + if m == "" { + continue + } + seconds, err := strconv.Atoi(m) + if err != nil || seconds < 0 || seconds > maxMaxAge { + return nil, fmt.Errorf( + "must specify a cache-control max-age between 0 and %v", maxMaxAge) + } + + switch option { + case "current_metadata": + cacheConfig.SetCurrentCacheMaxAge(seconds) + default: + cacheConfig.SetConsistentCacheMaxAge(seconds) + } + } + return cacheConfig, nil +} + func main() { flag.Usage = usage flag.Parse() @@ -215,6 +241,12 @@ func main() { } ctx = context.WithValue(ctx, "metaStore", store) + cacheConfig, err := getCacheConfig(mainViper) + if err != nil { + logrus.Fatal(err.Error()) + } + ctx = context.WithValue(ctx, "cacheConfig", cacheConfig) + httpAddr, tlsConfig, err := getAddrAndTLSConfig(mainViper) if err != nil { logrus.Fatal(err.Error()) diff --git a/cmd/notary-server/main_test.go b/cmd/notary-server/main_test.go index ee8450ad76..be23f35ab7 100644 --- a/cmd/notary-server/main_test.go +++ b/cmd/notary-server/main_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "io/ioutil" + "net/http" "os" "reflect" "strings" @@ -328,3 +329,25 @@ func TestGetMemoryStore(t *testing.T) { _, ok := store.(*storage.MemStorage) assert.True(t, ok) } + +func TestGetCacheConfig(t *testing.T) { + valid := `{"caching": {"max_age": {"current_metadata": 0, "metadata_by_checksum": 31536000}}}` + invalids := []string{ + `{"caching": {"max_age": {"current_metadata": 0, "metadata_by_checksum": 31539000}}}`, + `{"caching": {"max_age": {"current_metadata": -1, "metadata_by_checksum": 300}}}`, + `{"caching": {"max_age": {"current_metadata": "hello", "metadata_by_checksum": 300}}}`, + } + + cacheConfig, err := getCacheConfig(configure(valid)) + assert.NoError(t, err) + h := http.Header{} + cacheConfig.UpdateCurrentHeaders(h, time.Now()) + assert.True(t, strings.Contains(h.Get("Cache-Control"), "max-age=0")) + cacheConfig.UpdateConsistentHeaders(h, time.Now()) + assert.True(t, strings.Contains(h.Get("Cache-Control"), "max-age=31536000")) + + for _, invalid := range invalids { + _, err := getCacheConfig(configure(invalid)) + assert.Error(t, err) + } +} From e1397f4b035e344a1b06a16a6392193c3f413bb9 Mon Sep 17 00:00:00 2001 From: Ying Li Date: Fri, 26 Feb 2016 17:23:30 -0800 Subject: [PATCH 5/7] Use updated-at for last modification date for getting current metadata Signed-off-by: Ying Li --- cmd/notary-server/main.go | 12 +++++++++++- server/handlers/default.go | 9 ++++++--- server/handlers/roles.go | 24 ++++++++++++------------ server/snapshot/snapshot.go | 4 ++-- server/storage/database.go | 26 ++++++++++++++++---------- server/storage/database_test.go | 4 +++- server/storage/interface.go | 2 +- server/storage/memory.go | 16 ++++++++-------- server/timestamp/timestamp.go | 4 ++-- server/timestamp/timestamp_test.go | 2 +- 10 files changed, 62 insertions(+), 41 deletions(-) diff --git a/cmd/notary-server/main.go b/cmd/notary-server/main.go index ecc54678df..af75aefd54 100644 --- a/cmd/notary-server/main.go +++ b/cmd/notary-server/main.go @@ -34,7 +34,10 @@ import ( const ( jsonLogFormat = "json" DebugAddress = "localhost:8080" - maxMaxAge = 31536000 + // This is the generally recommended maximum age for Cache-Control headers + // (one year, in seconds, since one year is forever in terms of internet + // content) + maxMaxAge = 31536000 ) var ( @@ -173,6 +176,13 @@ func getTrustService(configuration *viper.Viper, sFactory signerFactory, return notarySigner, keyAlgo, nil } +// Gets the cache configuration for GET-ting metadata. This is the max-age +// (an integer in seconds, just like in the Cache-Control header) for consistent +// (content-addressable) downloads and current (latest version) downloads. +// The max-age must be between 0 and 31536000 (one year in seconds, which is +// the recommended maximum time data is cached), else parsing will return an +// error. A max-age of 0 will disable caching for that type of download +// (consistent or current). func getCacheConfig(configuration *viper.Viper) (*handlers.CacheControlConfig, error) { cacheConfig := handlers.NewCacheControlConfig() for _, option := range []string{"current_metadata", "metadata_by_checksum"} { diff --git a/server/handlers/default.go b/server/handlers/default.go index 025dcb6bb0..50526d49c0 100644 --- a/server/handlers/default.go +++ b/server/handlers/default.go @@ -26,8 +26,11 @@ import ( ) // NewCacheControlConfig creates a new configuration for Cache-Control headers, -// which by default, sets cache max-age values for consistent (by checksum) -// downloads 30 days and non-consistent (current) downloads to 5 minutes +// which by default, sets cache max-age values for consistent +// (content-addressable, by checksum) downloads 30 days and non-consistent +// (current/latest version) downloads to 5 minutes. +// If a max-age of <=0 is supplied, then caching will be disabled for that type +// of download (this may be desirable for the current downloads, for example). func NewCacheControlConfig() *CacheControlConfig { return &CacheControlConfig{ headerVals: map[string]int{ @@ -61,7 +64,7 @@ func (c *CacheControlConfig) UpdateConsistentHeaders(headers http.Header, lastMo c.updateHeaders(headers, lastModified, true) } -// UpdateCurrentHeaders updates the given Headers object with th eCache-Control +// UpdateCurrentHeaders updates the given Headers object with the Cache-Control // headers for current (non-consistent) downloads func (c *CacheControlConfig) UpdateCurrentHeaders(headers http.Header, lastModified time.Time) { c.updateHeaders(headers, lastModified, false) diff --git a/server/handlers/roles.go b/server/handlers/roles.go index 415cf2bb85..be2b788c1d 100644 --- a/server/handlers/roles.go +++ b/server/handlers/roles.go @@ -15,9 +15,9 @@ import ( func getRole(ctx context.Context, store storage.MetaStore, gun, role, checksum string) (*time.Time, []byte, error) { var ( - creation *time.Time - out []byte - err error + lastModified *time.Time + out []byte + err error ) if checksum == "" { // the timestamp and snapshot might be server signed so are @@ -26,9 +26,9 @@ func getRole(ctx context.Context, store storage.MetaStore, gun, role, checksum s case data.CanonicalTimestampRole, data.CanonicalSnapshotRole: return getMaybeServerSigned(ctx, store, gun, role) } - creation, out, err = store.GetCurrent(gun, role) + lastModified, out, err = store.GetCurrent(gun, role) } else { - creation, out, err = store.GetChecksum(gun, role, checksum) + lastModified, out, err = store.GetChecksum(gun, role, checksum) } if err != nil { @@ -41,7 +41,7 @@ func getRole(ctx context.Context, store storage.MetaStore, gun, role, checksum s return nil, nil, errors.ErrMetadataNotFound.WithDetail(nil) } - return creation, out, nil + return lastModified, out, nil } // getMaybeServerSigned writes the current snapshot or timestamp (based on the @@ -57,15 +57,15 @@ func getMaybeServerSigned(ctx context.Context, store storage.MetaStore, gun, rol } var ( - creation *time.Time - out []byte - err error + lastModified *time.Time + out []byte + err error ) switch role { case data.CanonicalSnapshotRole: - creation, out, err = snapshot.GetOrCreateSnapshot(gun, store, cryptoService) + lastModified, out, err = snapshot.GetOrCreateSnapshot(gun, store, cryptoService) case data.CanonicalTimestampRole: - creation, out, err = timestamp.GetOrCreateTimestamp(gun, store, cryptoService) + lastModified, out, err = timestamp.GetOrCreateTimestamp(gun, store, cryptoService) } if err != nil { switch err.(type) { @@ -76,5 +76,5 @@ func getMaybeServerSigned(ctx context.Context, store storage.MetaStore, gun, rol } } - return creation, out, nil + return lastModified, out, nil } diff --git a/server/snapshot/snapshot.go b/server/snapshot/snapshot.go index be31c6fc8c..2d1602d02e 100644 --- a/server/snapshot/snapshot.go +++ b/server/snapshot/snapshot.go @@ -49,7 +49,7 @@ func GetOrCreateSnapshotKey(gun string, store storage.KeyStore, crypto signed.Cr func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ( *time.Time, []byte, error) { - creation, d, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) + lastModified, d, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) if err != nil { return nil, nil, err } @@ -63,7 +63,7 @@ func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService sign } if !snapshotExpired(sn) { - return creation, d, nil + return lastModified, d, nil } } diff --git a/server/storage/database.go b/server/storage/database.go index bd62301760..7c4f2ae10e 100644 --- a/server/storage/database.go +++ b/server/storage/database.go @@ -120,9 +120,12 @@ func (db *SQLStorage) UpdateMany(gun string, updates []MetaUpdate) error { // GetCurrent gets a specific TUF record func (db *SQLStorage) GetCurrent(gun, tufRole string) (*time.Time, []byte, error) { var row TUFFile - q := db.Select("created_at, data").Where( + q := db.Select("updated_at, data").Where( &TUFFile{Gun: gun, Role: tufRole}).Order("version desc").Limit(1).First(&row) - return returnRead(q, row) + if err := isReadErr(q, row); err != nil { + return nil, nil, err + } + return &(row.UpdatedAt), row.Data, nil } // GetChecksum gets a specific TUF record by its hex checksum @@ -135,18 +138,21 @@ func (db *SQLStorage) GetChecksum(gun, tufRole, checksum string) (*time.Time, [] Sha256: checksum, }, ).First(&row) - return returnRead(q, row) -} - -func returnRead(q *gorm.DB, row TUFFile) (*time.Time, []byte, error) { - if q.RecordNotFound() { - return nil, nil, ErrNotFound{} - } else if q.Error != nil { - return nil, nil, q.Error + if err := isReadErr(q, row); err != nil { + return nil, nil, err } return &(row.CreatedAt), row.Data, nil } +func isReadErr(q *gorm.DB, row TUFFile) error { + if q.RecordNotFound() { + return ErrNotFound{} + } else if q.Error != nil { + return q.Error + } + return nil +} + // Delete deletes all the records for a specific GUN func (db *SQLStorage) Delete(gun string) error { return db.Where(&TUFFile{Gun: gun}).Delete(TUFFile{}).Error diff --git a/server/storage/database_test.go b/server/storage/database_test.go index 6991950438..c1957daaa2 100644 --- a/server/storage/database_test.go +++ b/server/storage/database_test.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "fmt" "io/ioutil" "os" "testing" @@ -244,7 +245,8 @@ func TestSQLGetCurrent(t *testing.T) { cDate, byt, err := dbStore.GetCurrent("testGUN", "root") require.NoError(t, err, "There should not be any errors getting.") require.Equal(t, []byte("1"), byt, "Returned data was incorrect") - // the creation date was sometime wthin the last minute + // the update date was sometime wthin the last minute + fmt.Println(cDate) require.True(t, cDate.After(time.Now().Add(-1*time.Minute))) require.True(t, cDate.Before(time.Now().Add(5*time.Second))) diff --git a/server/storage/interface.go b/server/storage/interface.go index f5d7691de2..965e263814 100644 --- a/server/storage/interface.go +++ b/server/storage/interface.go @@ -26,7 +26,7 @@ type MetaStore interface { // none of the metadata is added, and an error is be returned. UpdateMany(gun string, updates []MetaUpdate) error - // GetCurrent returns the creation date and data part of the metadata for + // GetCurrent returns the modification date and data part of the metadata for // the latest version of the given GUN and role. If there is no data for // the given GUN and role, an error is returned. GetCurrent(gun, tufRole string) (created *time.Time, data []byte, err error) diff --git a/server/storage/memory.go b/server/storage/memory.go index 373902b4c7..3103e6043d 100644 --- a/server/storage/memory.go +++ b/server/storage/memory.go @@ -15,9 +15,9 @@ type key struct { } type ver struct { - version int - data []byte - creation time.Time + version int + data []byte + createupdate time.Time } // MemStorage is really just designed for dev and testing. It is very @@ -50,7 +50,7 @@ func (st *MemStorage) UpdateCurrent(gun string, update MetaUpdate) error { } } } - version := ver{version: update.Version, data: update.Data, creation: time.Now()} + version := ver{version: update.Version, data: update.Data, createupdate: time.Now()} st.tufMeta[id] = append(st.tufMeta[id], &version) checksumBytes := sha256.Sum256(update.Data) checksum := hex.EncodeToString(checksumBytes[:]) @@ -71,7 +71,7 @@ func (st *MemStorage) UpdateMany(gun string, updates []MetaUpdate) error { return nil } -// GetCurrent returns the creation date metadata for a given role, under a GUN. +// GetCurrent returns the createupdate date metadata for a given role, under a GUN. func (st *MemStorage) GetCurrent(gun, role string) (*time.Time, []byte, error) { id := entryKey(gun, role) st.lock.Lock() @@ -80,10 +80,10 @@ func (st *MemStorage) GetCurrent(gun, role string) (*time.Time, []byte, error) { if !ok || len(space) == 0 { return nil, nil, ErrNotFound{} } - return &(space[len(space)-1].creation), space[len(space)-1].data, nil + return &(space[len(space)-1].createupdate), space[len(space)-1].data, nil } -// GetChecksum returns the creation date and metadata for a given role, under a GUN. +// GetChecksum returns the createupdate date and metadata for a given role, under a GUN. func (st *MemStorage) GetChecksum(gun, role, checksum string) (*time.Time, []byte, error) { st.lock.Lock() defer st.lock.Unlock() @@ -91,7 +91,7 @@ func (st *MemStorage) GetChecksum(gun, role, checksum string) (*time.Time, []byt if !ok || len(space.data) == 0 { return nil, nil, ErrNotFound{} } - return &(space.creation), space.data, nil + return &(space.createupdate), space.data, nil } // Delete deletes all the metadata for a given GUN diff --git a/server/timestamp/timestamp.go b/server/timestamp/timestamp.go index 5b1555deea..ab99102630 100644 --- a/server/timestamp/timestamp.go +++ b/server/timestamp/timestamp.go @@ -56,7 +56,7 @@ func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService sig if err != nil { return nil, nil, err } - creation, d, err := store.GetCurrent(gun, data.CanonicalTimestampRole) + lastModified, d, err := store.GetCurrent(gun, data.CanonicalTimestampRole) if err != nil { if _, ok := err.(storage.ErrNotFound); !ok { logrus.Error("error retrieving timestamp: ", err.Error()) @@ -72,7 +72,7 @@ func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService sig return nil, nil, err } if !timestampExpired(ts) && !snapshotExpired(ts, snapshot) { - return creation, d, nil + return lastModified, d, nil } } sgnd, version, err := CreateTimestamp(gun, ts, snapshot, store, cryptoService) diff --git a/server/timestamp/timestamp_test.go b/server/timestamp/timestamp_test.go index 2ec28b59b6..0ca272016c 100644 --- a/server/timestamp/timestamp_test.go +++ b/server/timestamp/timestamp_test.go @@ -101,5 +101,5 @@ func TestGetTimestampNewSnapshot(t *testing.T) { c2, ts2, err := GetOrCreateTimestamp("gun", store, crypto) assert.NoError(t, err, "GetTimestamp errored") assert.NotEqual(t, ts1, ts2, "Timestamp was not regenerated when snapshot changed") - assert.True(t, c1.Before(*c2), "Timestamp creation time incorrect") + assert.True(t, c1.Before(*c2), "Timestamp modification time incorrect") } From 84f5ed28d2f1eaeb8488c3c02e199a099311fbab Mon Sep 17 00:00:00 2001 From: Ying Li Date: Tue, 1 Mar 2016 18:17:04 -0500 Subject: [PATCH 6/7] Move the configuration parsing for notary-server to its own file Signed-off-by: Ying Li --- cmd/notary-server/config.go | 165 ++++++++++++++++++++++++++++++++++++ cmd/notary-server/main.go | 158 +--------------------------------- 2 files changed, 166 insertions(+), 157 deletions(-) create mode 100644 cmd/notary-server/config.go diff --git a/cmd/notary-server/config.go b/cmd/notary-server/config.go new file mode 100644 index 0000000000..74ffab2e79 --- /dev/null +++ b/cmd/notary-server/config.go @@ -0,0 +1,165 @@ +package main + +import ( + "crypto/tls" + "fmt" + "strconv" + "time" + + "github.com/Sirupsen/logrus" + "github.com/docker/distribution/health" + _ "github.com/docker/distribution/registry/auth/htpasswd" + _ "github.com/docker/distribution/registry/auth/token" + "github.com/docker/go-connections/tlsconfig" + "github.com/docker/notary/server/handlers" + "github.com/docker/notary/server/storage" + "github.com/docker/notary/signer/client" + "github.com/docker/notary/tuf/data" + "github.com/docker/notary/tuf/signed" + "github.com/docker/notary/utils" + _ "github.com/go-sql-driver/mysql" + "github.com/spf13/viper" +) + +// get the address for the HTTP server, and parses the optional TLS +// configuration for the server - if no TLS configuration is specified, +// TLS is not enabled. +func getAddrAndTLSConfig(configuration *viper.Viper) (string, *tls.Config, error) { + httpAddr := configuration.GetString("server.http_addr") + if httpAddr == "" { + return "", nil, fmt.Errorf("http listen address required for server") + } + + tlsConfig, err := utils.ParseServerTLS(configuration, false) + if err != nil { + return "", nil, fmt.Errorf(err.Error()) + } + return httpAddr, tlsConfig, nil +} + +// sets up TLS for the GRPC connection to notary-signer +func grpcTLS(configuration *viper.Viper) (*tls.Config, error) { + rootCA := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_ca_file") + clientCert := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_cert") + clientKey := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_key") + + if clientCert == "" && clientKey != "" || clientCert != "" && clientKey == "" { + return nil, fmt.Errorf("either pass both client key and cert, or neither") + } + + tlsConfig, err := tlsconfig.Client(tlsconfig.Options{ + CAFile: rootCA, + CertFile: clientCert, + KeyFile: clientKey, + }) + if err != nil { + return nil, fmt.Errorf( + "Unable to configure TLS to the trust service: %s", err.Error()) + } + return tlsConfig, nil +} + +// parses the configuration and returns a backing store for the TUF files +func getStore(configuration *viper.Viper, allowedBackends []string) ( + storage.MetaStore, error) { + + storeConfig, err := utils.ParseStorage(configuration, allowedBackends) + if err != nil { + return nil, err + } + logrus.Infof("Using %s backend", storeConfig.Backend) + + if storeConfig.Backend == utils.MemoryBackend { + return storage.NewMemStorage(), nil + } + + store, err := storage.NewSQLStorage(storeConfig.Backend, storeConfig.Source) + if err != nil { + return nil, fmt.Errorf("Error starting DB driver: %s", err.Error()) + } + health.RegisterPeriodicFunc( + "DB operational", store.CheckHealth, time.Second*60) + return store, nil +} + +type signerFactory func(hostname, port string, tlsConfig *tls.Config) *client.NotarySigner +type healthRegister func(name string, checkFunc func() error, duration time.Duration) + +// parses the configuration and determines which trust service and key algorithm +// to return +func getTrustService(configuration *viper.Viper, sFactory signerFactory, + hRegister healthRegister) (signed.CryptoService, string, error) { + + switch configuration.GetString("trust_service.type") { + case "local": + logrus.Info("Using local signing service, which requires ED25519. " + + "Ignoring all other trust_service parameters, including keyAlgorithm") + return signed.NewEd25519(), data.ED25519Key, nil + case "remote": + default: + return nil, "", fmt.Errorf( + "must specify either a \"local\" or \"remote\" type for trust_service") + } + + keyAlgo := configuration.GetString("trust_service.key_algorithm") + if keyAlgo != data.ED25519Key && keyAlgo != data.ECDSAKey && keyAlgo != data.RSAKey { + return nil, "", fmt.Errorf("invalid key algorithm configured: %s", keyAlgo) + } + + clientTLS, err := grpcTLS(configuration) + if err != nil { + return nil, "", err + } + + logrus.Info("Using remote signing service") + + notarySigner := sFactory( + configuration.GetString("trust_service.hostname"), + configuration.GetString("trust_service.port"), + clientTLS, + ) + + minute := 1 * time.Minute + hRegister( + "Trust operational", + // If the trust service fails, the server is degraded but not + // exactly unheatlthy, so always return healthy and just log an + // error. + func() error { + err := notarySigner.CheckHealth(minute) + if err != nil { + logrus.Error("Trust not fully operational: ", err.Error()) + } + return nil + }, + minute) + return notarySigner, keyAlgo, nil +} + +// Gets the cache configuration for GET-ting metadata. This is the max-age +// (an integer in seconds, just like in the Cache-Control header) for consistent +// (content-addressable) downloads and current (latest version) downloads. +// The max-age must be between 0 and 31536000 (one year in seconds, which is +// the recommended maximum time data is cached), else parsing will return an +// error. A max-age of 0 will disable caching for that type of download +// (consistent or current). +func getCacheConfig(configuration *viper.Viper) (*handlers.CacheControlConfig, error) { + cacheConfig := handlers.NewCacheControlConfig() + for option, setMaxAge := range map[string]func(int){ + "current_metadata": cacheConfig.SetCurrentCacheMaxAge, + "metadata_by_checksum": cacheConfig.SetConsistentCacheMaxAge, + } { + m := configuration.GetString(fmt.Sprintf("caching.max_age.%s", option)) + if m == "" { + continue + } + seconds, err := strconv.Atoi(m) + if err != nil || seconds < 0 || seconds > maxMaxAge { + return nil, fmt.Errorf( + "must specify a cache-control max-age between 0 and %v", maxMaxAge) + } + + setMaxAge(seconds) + } + return cacheConfig, nil +} diff --git a/cmd/notary-server/main.go b/cmd/notary-server/main.go index af75aefd54..2c6aa23112 100644 --- a/cmd/notary-server/main.go +++ b/cmd/notary-server/main.go @@ -1,30 +1,19 @@ package main import ( - "crypto/tls" _ "expvar" "flag" "fmt" "net/http" _ "net/http/pprof" "os" - "strconv" - "time" "github.com/Sirupsen/logrus" "github.com/docker/distribution/health" - _ "github.com/docker/distribution/registry/auth/htpasswd" - _ "github.com/docker/distribution/registry/auth/token" - "github.com/docker/notary/server/storage" "github.com/docker/notary/signer/client" - "github.com/docker/notary/tuf/data" - "github.com/docker/notary/tuf/signed" - _ "github.com/go-sql-driver/mysql" "golang.org/x/net/context" - "github.com/docker/go-connections/tlsconfig" "github.com/docker/notary/server" - "github.com/docker/notary/server/handlers" "github.com/docker/notary/utils" "github.com/docker/notary/version" "github.com/spf13/viper" @@ -37,7 +26,7 @@ const ( // This is the generally recommended maximum age for Cache-Control headers // (one year, in seconds, since one year is forever in terms of internet // content) - maxMaxAge = 31536000 + maxMaxAge = 60 * 60 * 24 * 365 ) var ( @@ -61,151 +50,6 @@ func init() { } } -// get the address for the HTTP server, and parses the optional TLS -// configuration for the server - if no TLS configuration is specified, -// TLS is not enabled. -func getAddrAndTLSConfig(configuration *viper.Viper) (string, *tls.Config, error) { - httpAddr := configuration.GetString("server.http_addr") - if httpAddr == "" { - return "", nil, fmt.Errorf("http listen address required for server") - } - - tlsConfig, err := utils.ParseServerTLS(configuration, false) - if err != nil { - return "", nil, fmt.Errorf(err.Error()) - } - return httpAddr, tlsConfig, nil -} - -// sets up TLS for the GRPC connection to notary-signer -func grpcTLS(configuration *viper.Viper) (*tls.Config, error) { - rootCA := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_ca_file") - clientCert := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_cert") - clientKey := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_key") - - if clientCert == "" && clientKey != "" || clientCert != "" && clientKey == "" { - return nil, fmt.Errorf("either pass both client key and cert, or neither") - } - - tlsConfig, err := tlsconfig.Client(tlsconfig.Options{ - CAFile: rootCA, - CertFile: clientCert, - KeyFile: clientKey, - }) - if err != nil { - return nil, fmt.Errorf( - "Unable to configure TLS to the trust service: %s", err.Error()) - } - return tlsConfig, nil -} - -// parses the configuration and returns a backing store for the TUF files -func getStore(configuration *viper.Viper, allowedBackends []string) ( - storage.MetaStore, error) { - - storeConfig, err := utils.ParseStorage(configuration, allowedBackends) - if err != nil { - return nil, err - } - logrus.Infof("Using %s backend", storeConfig.Backend) - - if storeConfig.Backend == utils.MemoryBackend { - return storage.NewMemStorage(), nil - } - - store, err := storage.NewSQLStorage(storeConfig.Backend, storeConfig.Source) - if err != nil { - return nil, fmt.Errorf("Error starting DB driver: %s", err.Error()) - } - health.RegisterPeriodicFunc( - "DB operational", store.CheckHealth, time.Second*60) - return store, nil -} - -type signerFactory func(hostname, port string, tlsConfig *tls.Config) *client.NotarySigner -type healthRegister func(name string, checkFunc func() error, duration time.Duration) - -// parses the configuration and determines which trust service and key algorithm -// to return -func getTrustService(configuration *viper.Viper, sFactory signerFactory, - hRegister healthRegister) (signed.CryptoService, string, error) { - - switch configuration.GetString("trust_service.type") { - case "local": - logrus.Info("Using local signing service, which requires ED25519. " + - "Ignoring all other trust_service parameters, including keyAlgorithm") - return signed.NewEd25519(), data.ED25519Key, nil - case "remote": - default: - return nil, "", fmt.Errorf( - "must specify either a \"local\" or \"remote\" type for trust_service") - } - - keyAlgo := configuration.GetString("trust_service.key_algorithm") - if keyAlgo != data.ED25519Key && keyAlgo != data.ECDSAKey && keyAlgo != data.RSAKey { - return nil, "", fmt.Errorf("invalid key algorithm configured: %s", keyAlgo) - } - - clientTLS, err := grpcTLS(configuration) - if err != nil { - return nil, "", err - } - - logrus.Info("Using remote signing service") - - notarySigner := sFactory( - configuration.GetString("trust_service.hostname"), - configuration.GetString("trust_service.port"), - clientTLS, - ) - - minute := 1 * time.Minute - hRegister( - "Trust operational", - // If the trust service fails, the server is degraded but not - // exactly unheatlthy, so always return healthy and just log an - // error. - func() error { - err := notarySigner.CheckHealth(minute) - if err != nil { - logrus.Error("Trust not fully operational: ", err.Error()) - } - return nil - }, - minute) - return notarySigner, keyAlgo, nil -} - -// Gets the cache configuration for GET-ting metadata. This is the max-age -// (an integer in seconds, just like in the Cache-Control header) for consistent -// (content-addressable) downloads and current (latest version) downloads. -// The max-age must be between 0 and 31536000 (one year in seconds, which is -// the recommended maximum time data is cached), else parsing will return an -// error. A max-age of 0 will disable caching for that type of download -// (consistent or current). -func getCacheConfig(configuration *viper.Viper) (*handlers.CacheControlConfig, error) { - cacheConfig := handlers.NewCacheControlConfig() - for _, option := range []string{"current_metadata", "metadata_by_checksum"} { - m := configuration.GetString(fmt.Sprintf("caching.max_age.%s", option)) - if m == "" { - continue - } - seconds, err := strconv.Atoi(m) - if err != nil || seconds < 0 || seconds > maxMaxAge { - return nil, fmt.Errorf( - "must specify a cache-control max-age between 0 and %v", maxMaxAge) - } - - switch option { - case "current_metadata": - cacheConfig.SetCurrentCacheMaxAge(seconds) - default: - cacheConfig.SetConsistentCacheMaxAge(seconds) - } - } - return cacheConfig, nil -} - func main() { flag.Usage = usage flag.Parse() From e25746dac3b68a2e341f4d5011b7a7496442a127 Mon Sep 17 00:00:00 2001 From: Ying Li Date: Fri, 11 Mar 2016 16:23:19 -0800 Subject: [PATCH 7/7] Use a CacheControlHandler that wraps other handlers instead Signed-off-by: Ying Li --- client/client_test.go | 2 +- cmd/notary-server/config.go | 36 +++--- cmd/notary-server/main.go | 17 +-- cmd/notary-server/main_test.go | 12 +- cmd/notary/integration_test.go | 2 +- cmd/notary/keys.go | 2 +- cmd/notary/main_test.go | 8 +- server/handlers/default.go | 100 ++------------- server/handlers/default_test.go | 97 ++------------- server/integration_test.go | 2 +- server/server.go | 40 +++--- server/server_test.go | 115 +++++++++++++----- utils/http.go | 101 +++++++++++++++ utils/http_test.go | 209 +++++++++++++++++++++++++++++--- 14 files changed, 457 insertions(+), 286 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 0bffadb2aa..39de2cd6f6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -141,7 +141,7 @@ func fullTestServer(t *testing.T) *httptest.Server { cryptoService := cryptoservice.NewCryptoService( "", trustmanager.NewKeyMemoryStore(passphraseRetriever)) - return httptest.NewServer(server.RootHandler(nil, ctx, cryptoService)) + return httptest.NewServer(server.RootHandler(nil, ctx, cryptoService, nil, nil)) } // server that returns some particular error code all the time diff --git a/cmd/notary-server/config.go b/cmd/notary-server/config.go index 74ffab2e79..06f00ddae1 100644 --- a/cmd/notary-server/config.go +++ b/cmd/notary-server/config.go @@ -11,7 +11,6 @@ import ( _ "github.com/docker/distribution/registry/auth/htpasswd" _ "github.com/docker/distribution/registry/auth/token" "github.com/docker/go-connections/tlsconfig" - "github.com/docker/notary/server/handlers" "github.com/docker/notary/server/storage" "github.com/docker/notary/signer/client" "github.com/docker/notary/tuf/data" @@ -123,7 +122,7 @@ func getTrustService(configuration *viper.Viper, sFactory signerFactory, hRegister( "Trust operational", // If the trust service fails, the server is degraded but not - // exactly unheatlthy, so always return healthy and just log an + // exactly unhealthy, so always return healthy and just log an // error. func() error { err := notarySigner.CheckHealth(minute) @@ -136,30 +135,29 @@ func getTrustService(configuration *viper.Viper, sFactory signerFactory, return notarySigner, keyAlgo, nil } -// Gets the cache configuration for GET-ting metadata. This is the max-age -// (an integer in seconds, just like in the Cache-Control header) for consistent -// (content-addressable) downloads and current (latest version) downloads. -// The max-age must be between 0 and 31536000 (one year in seconds, which is -// the recommended maximum time data is cached), else parsing will return an -// error. A max-age of 0 will disable caching for that type of download -// (consistent or current). -func getCacheConfig(configuration *viper.Viper) (*handlers.CacheControlConfig, error) { - cacheConfig := handlers.NewCacheControlConfig() - for option, setMaxAge := range map[string]func(int){ - "current_metadata": cacheConfig.SetCurrentCacheMaxAge, - "metadata_by_checksum": cacheConfig.SetConsistentCacheMaxAge, - } { - m := configuration.GetString(fmt.Sprintf("caching.max_age.%s", option)) +// Gets the cache configuration for GET-ting current and checksummed metadata +// This is mainly the max-age (an integer in seconds, just like in the +// Cache-Control header) for consistent (content-addressable) downloads and +// current (latest version) downloads. The max-age must be between 0 and 31536000 +// (one year in seconds, which is the recommended maximum time data is cached), +// else parsing will return an error. A max-age of 0 will disable caching for +// that type of download (consistent or current). +func getCacheConfig(configuration *viper.Viper) (utils.CacheControlConfig, utils.CacheControlConfig, error) { + var cccs []utils.CacheControlConfig + types := []string{"current_metadata", "metadata_by_checksum"} + + for _, optionName := range types { + m := configuration.GetString(fmt.Sprintf("caching.max_age.%s", optionName)) if m == "" { continue } seconds, err := strconv.Atoi(m) if err != nil || seconds < 0 || seconds > maxMaxAge { - return nil, fmt.Errorf( + return nil, nil, fmt.Errorf( "must specify a cache-control max-age between 0 and %v", maxMaxAge) } - setMaxAge(seconds) + cccs = append(cccs, utils.NewCacheControlConfig(seconds, optionName == "current_metadata")) } - return cacheConfig, nil + return cccs[0], cccs[1], nil } diff --git a/cmd/notary-server/main.go b/cmd/notary-server/main.go index 2c6aa23112..206200569f 100644 --- a/cmd/notary-server/main.go +++ b/cmd/notary-server/main.go @@ -95,11 +95,10 @@ func main() { } ctx = context.WithValue(ctx, "metaStore", store) - cacheConfig, err := getCacheConfig(mainViper) + currentCache, consistentCache, err := getCacheConfig(mainViper) if err != nil { logrus.Fatal(err.Error()) } - ctx = context.WithValue(ctx, "cacheConfig", cacheConfig) httpAddr, tlsConfig, err := getAddrAndTLSConfig(mainViper) if err != nil { @@ -109,11 +108,15 @@ func main() { logrus.Info("Starting Server") err = server.Run( ctx, - httpAddr, - tlsConfig, - trust, - mainViper.GetString("auth.type"), - mainViper.Get("auth.options"), + server.Config{ + Addr: httpAddr, + TLSConfig: tlsConfig, + Trust: trust, + AuthMethod: mainViper.GetString("auth.type"), + AuthOpts: mainViper.Get("auth.options"), + CurrentCacheControlConfig: currentCache, + ConsistentCacheControlConfig: consistentCache, + }, ) logrus.Error(err.Error()) diff --git a/cmd/notary-server/main_test.go b/cmd/notary-server/main_test.go index be23f35ab7..7192b19721 100644 --- a/cmd/notary-server/main_test.go +++ b/cmd/notary-server/main_test.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "fmt" "io/ioutil" - "net/http" "os" "reflect" "strings" @@ -338,16 +337,13 @@ func TestGetCacheConfig(t *testing.T) { `{"caching": {"max_age": {"current_metadata": "hello", "metadata_by_checksum": 300}}}`, } - cacheConfig, err := getCacheConfig(configure(valid)) + current, consistent, err := getCacheConfig(configure(valid)) assert.NoError(t, err) - h := http.Header{} - cacheConfig.UpdateCurrentHeaders(h, time.Now()) - assert.True(t, strings.Contains(h.Get("Cache-Control"), "max-age=0")) - cacheConfig.UpdateConsistentHeaders(h, time.Now()) - assert.True(t, strings.Contains(h.Get("Cache-Control"), "max-age=31536000")) + assert.IsType(t, utils.NoCacheControl{}, current) + assert.IsType(t, utils.PublicCacheControl{}, consistent) for _, invalid := range invalids { - _, err := getCacheConfig(configure(invalid)) + _, _, err := getCacheConfig(configure(invalid)) assert.Error(t, err) } } diff --git a/cmd/notary/integration_test.go b/cmd/notary/integration_test.go index b9ed1ffc7f..10c4bdff90 100644 --- a/cmd/notary/integration_test.go +++ b/cmd/notary/integration_test.go @@ -71,7 +71,7 @@ func setupServerHandler(metaStore storage.MetaStore) http.Handler { cryptoService := cryptoservice.NewCryptoService( "", trustmanager.NewKeyMemoryStore(passphrase.ConstantRetriever("pass"))) - return server.RootHandler(nil, ctx, cryptoService) + return server.RootHandler(nil, ctx, cryptoService, nil, nil) } // makes a testing notary-server diff --git a/cmd/notary/keys.go b/cmd/notary/keys.go index 0e991fb297..e461a8ba9d 100644 --- a/cmd/notary/keys.go +++ b/cmd/notary/keys.go @@ -5,6 +5,7 @@ import ( "bufio" "fmt" "io" + "io/ioutil" "net/http" "os" "path/filepath" @@ -20,7 +21,6 @@ import ( "github.com/docker/notary/tuf/data" "github.com/spf13/cobra" "github.com/spf13/viper" - "io/ioutil" ) var cmdKeyTemplate = usageTemplate{ diff --git a/cmd/notary/main_test.go b/cmd/notary/main_test.go index 978e1d464f..85c350ec79 100644 --- a/cmd/notary/main_test.go +++ b/cmd/notary/main_test.go @@ -196,7 +196,7 @@ func TestBareCommandPrintsUsageAndNoError(t *testing.T) { cmd := NewNotaryCommand() cmd.SetOutput(b) - cmd.SetArgs([]string{"-c", filepath.Join(tempdir, "idonotexist.json"), bareCommand}) + cmd.SetArgs([]string{"-c", filepath.Join(tempdir, "idonotexist.json"), "-d", tempdir, bareCommand}) require.NoError(t, cmd.Execute(), "Expected no error from a help request") // usage is printed require.Contains(t, b.String(), "Usage:", "expected usage when running `notary %s`", bareCommand) @@ -256,7 +256,7 @@ func TestConfigFileTLSCannotBeRelativeToCWD(t *testing.T) { // set a config file, so it doesn't check ~/.notary/config.json by default, // and execute a random command so that the flags are parsed cmd := NewNotaryCommand() - cmd.SetArgs([]string{"-c", configFile, "list", "repo"}) + cmd.SetArgs([]string{"-c", configFile, "-d", tempDir, "list", "repo"}) cmd.SetOutput(new(bytes.Buffer)) // eat the output err = cmd.Execute() assert.Error(t, err, "expected a failure due to TLS") @@ -310,7 +310,7 @@ func TestConfigFileTLSCanBeRelativeToConfigOrAbsolute(t *testing.T) { // set a config file, so it doesn't check ~/.notary/config.json by default, // and execute a random command so that the flags are parsed cmd := NewNotaryCommand() - cmd.SetArgs([]string{"-c", configFile.Name(), "list", "repo"}) + cmd.SetArgs([]string{"-c", configFile.Name(), "-d", tempDir, "list", "repo"}) cmd.SetOutput(new(bytes.Buffer)) // eat the output err = cmd.Execute() assert.Error(t, err, "there was no repository, so list should have failed") @@ -357,7 +357,7 @@ func TestConfigFileOverridenByCmdLineFlags(t *testing.T) { cmd := NewNotaryCommand() cmd.SetArgs([]string{ - "-c", configFile, "list", "repo", + "-c", configFile, "-d", tempDir, "list", "repo", "--tlscacert", "../../fixtures/root-ca.crt", "--tlscert", filepath.Clean(filepath.Join(cwd, "../../fixtures/notary-server.crt")), "--tlskey", "../../fixtures/notary-server.key"}) diff --git a/server/handlers/default.go b/server/handlers/default.go index 50526d49c0..769f7fe6a2 100644 --- a/server/handlers/default.go +++ b/server/handlers/default.go @@ -2,14 +2,10 @@ package handlers import ( "bytes" - "crypto/sha256" - "encoding/hex" "encoding/json" - "fmt" "io" "net/http" "strings" - "time" "github.com/Sirupsen/logrus" ctxu "github.com/docker/distribution/context" @@ -23,74 +19,9 @@ import ( "github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/signed" "github.com/docker/notary/tuf/validation" + "github.com/docker/notary/utils" ) -// NewCacheControlConfig creates a new configuration for Cache-Control headers, -// which by default, sets cache max-age values for consistent -// (content-addressable, by checksum) downloads 30 days and non-consistent -// (current/latest version) downloads to 5 minutes. -// If a max-age of <=0 is supplied, then caching will be disabled for that type -// of download (this may be desirable for the current downloads, for example). -func NewCacheControlConfig() *CacheControlConfig { - return &CacheControlConfig{ - headerVals: map[string]int{ - "consistent": 30 * 24 * 60 * 60, // 30 days - "current": 5 * 60, // 5 minutes - }, - } -} - -// CacheControlConfig is the configuration for the max cache age for -// cache control headers. -type CacheControlConfig struct { - headerVals map[string]int -} - -// SetConsistentCacheMaxAge sets the Cache-Control header value for consistent -// downloads -func (c *CacheControlConfig) SetConsistentCacheMaxAge(value int) { - c.headerVals["consistent"] = value -} - -// SetCurrentCacheMaxAge sets the Cache-Control header value for current -// (non-consistent) downloads -func (c *CacheControlConfig) SetCurrentCacheMaxAge(value int) { - c.headerVals["current"] = value -} - -// UpdateConsistentHeaders updates the given Headers object with the Cache-Control -// headers for consistent downloads -func (c *CacheControlConfig) UpdateConsistentHeaders(headers http.Header, lastModified time.Time) { - c.updateHeaders(headers, lastModified, true) -} - -// UpdateCurrentHeaders updates the given Headers object with the Cache-Control -// headers for current (non-consistent) downloads -func (c *CacheControlConfig) UpdateCurrentHeaders(headers http.Header, lastModified time.Time) { - c.updateHeaders(headers, lastModified, false) -} - -func (c *CacheControlConfig) updateHeaders(headers http.Header, lastModified time.Time, consistent bool) { - var seconds int - var cacheHeader string - - if consistent { - seconds = c.headerVals["consistent"] - cacheHeader = fmt.Sprintf("public, max-age=%v, s-maxage=%v, must-revalidate", seconds, seconds) - } else { - seconds = c.headerVals["current"] - cacheHeader = fmt.Sprintf("public, max-age=%v, s-maxage=%v", seconds, seconds) - } - - if seconds > 0 { - headers.Set("Cache-Control", cacheHeader) - headers.Set("Last-Modified", lastModified.Format(time.RFC1123)) - } else { - headers.Set("Cache-Control", "max-age=0, no-cache, no-store") - headers.Set("Pragma", "no-cache") - } -} - // MainHandler is the default handler for the server func MainHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) error { // For now it only supports `GET` @@ -188,43 +119,26 @@ func getHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, var checksum := vars["checksum"] tufRole := vars["tufRole"] s := ctx.Value("metaStore") - c := ctx.Value("cacheConfig") store, ok := s.(storage.MetaStore) if !ok { return errors.ErrNoStorage.WithDetail(nil) } - // If cache control headers were not provided, just use the default values - cacheConfig, ok := c.(*CacheControlConfig) - if !ok { - cacheConfig = NewCacheControlConfig() - } - lastModified, output, err := getRole(ctx, store, gun, tufRole, checksum) if err != nil { return err } - if lastModified == nil { - // This shouldn't ever happen, but if it does, it just messes up the cache headers, so - // proceed anyway + if lastModified != nil { + // This shouldn't always be true, but in case it is nil, and the last modified headers + // are not set, the cache control handler should set the last modified date to the beginning + // of time. + utils.SetLastModifiedHeader(w.Header(), *lastModified) + } else { logrus.Warnf("Got bytes out for %s's %s (checksum: %s), but missing lastModified date", gun, tufRole, checksum) - lastModified = &time.Time{} // set the last modification date to the beginning of time } - switch checksum { - case "": - cacheConfig.UpdateCurrentHeaders(w.Header(), *lastModified) - - shasum := sha256.Sum256(output) - checksum = hex.EncodeToString(shasum[:]) - - default: - cacheConfig.UpdateConsistentHeaders(w.Header(), *lastModified) - } - - w.Header().Set("ETag", checksum) w.Write(output) return nil } diff --git a/server/handlers/default_test.go b/server/handlers/default_test.go index e470195efd..1f729028ba 100644 --- a/server/handlers/default_test.go +++ b/server/handlers/default_test.go @@ -2,8 +2,6 @@ package handlers import ( "bytes" - "crypto/sha256" - "encoding/hex" "encoding/json" "fmt" "io/ioutil" @@ -174,51 +172,7 @@ func TestGetKeyHandlerCreatesOnce(t *testing.T) { } } -type expectedCacheSetting int - -const ( - checksumCaching expectedCacheSetting = iota - currentCaching - noCaching -) - -// Verifies that the body is as expected, the ETag is as expected, and cache control headers -// are as expected -func verifyGetResponse(t *testing.T, rw *httptest.ResponseRecorder, expectedBytes []byte, - checksumHex string, cacheType expectedCacheSetting, cacheConfig *CacheControlConfig) { - - body, err := ioutil.ReadAll(rw.Body) - assert.NoError(t, err) - assert.True(t, bytes.Equal(expectedBytes, body)) - - assert.Equal(t, rw.HeaderMap.Get("ETag"), checksumHex) - - cacheControl := rw.HeaderMap.Get("Cache-Control") - switch cacheType { - case checksumCaching: - maxAge := cacheConfig.headerVals["consistent"] - assert.Equal(t, fmt.Sprintf("public, max-age=%v, s-maxage=%v, must-revalidate", maxAge, maxAge), cacheControl) - case currentCaching: - maxAge := cacheConfig.headerVals["current"] - assert.Equal(t, fmt.Sprintf("public, max-age=%v, s-maxage=%v", maxAge, maxAge), cacheControl) - default: - assert.Equal(t, "max-age=0, no-cache, no-store", cacheControl) - } - - switch cacheType { - case checksumCaching, currentCaching: - lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) - assert.NoError(t, err) - assert.True(t, lastModified.After(time.Now().Add(-5*time.Minute))) - - assert.Equal(t, "", rw.HeaderMap.Get("Pragma")) - default: - assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified")) - assert.Equal(t, "no-cache", rw.HeaderMap.Get("Pragma")) - } -} - -func TestGetHandlerRootAndNoCacheConfigProvided(t *testing.T) { +func TestGetHandlerRoot(t *testing.T) { metaStore := storage.NewMemStorage() repo, _, err := testutils.EmptyRepo("gun") assert.NoError(t, err) @@ -240,22 +194,13 @@ func TestGetHandlerRootAndNoCacheConfigProvided(t *testing.T) { "tufRole": "root", } - checksumBytes := sha256.Sum256(rootJSON) - checksumHex := hex.EncodeToString(checksumBytes[:]) - - cacheConfig := NewCacheControlConfig() - rw := httptest.NewRecorder() - assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, rootJSON, checksumHex, currentCaching, cacheConfig) - vars["checksum"] = checksumHex - rw = httptest.NewRecorder() - assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, rootJSON, checksumHex, checksumCaching, cacheConfig) + err = getHandler(ctx, rw, req, vars) + assert.NoError(t, err) } -func TestGetHandlerTimestampWithCacheValues(t *testing.T) { +func TestGetHandlerTimestamp(t *testing.T) { metaStore := storage.NewMemStorage() repo, crypto, err := testutils.EmptyRepo("gun") assert.NoError(t, err) @@ -283,25 +228,13 @@ func TestGetHandlerTimestampWithCacheValues(t *testing.T) { "tufRole": "timestamp", } - cacheConfig := NewCacheControlConfig() - cacheConfig.SetConsistentCacheMaxAge(365 * 24 * 60 * 60) - cacheConfig.SetCurrentCacheMaxAge(1) - ctx = context.WithValue(ctx, "cacheConfig", cacheConfig) - - checksumBytes := sha256.Sum256(tsJSON) - checksumHex := hex.EncodeToString(checksumBytes[:]) - rw := httptest.NewRecorder() - assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, tsJSON, checksumHex, currentCaching, cacheConfig) - vars["checksum"] = checksumHex - rw = httptest.NewRecorder() - assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, tsJSON, checksumHex, checksumCaching, cacheConfig) + err = getHandler(ctx, rw, req, vars) + assert.NoError(t, err) } -func TestGetHandlerSnapshotWithNoCaching(t *testing.T) { +func TestGetHandlerSnapshot(t *testing.T) { metaStore := storage.NewMemStorage() repo, crypto, err := testutils.EmptyRepo("gun") assert.NoError(t, err) @@ -323,22 +256,10 @@ func TestGetHandlerSnapshotWithNoCaching(t *testing.T) { "tufRole": "snapshot", } - cacheConfig := NewCacheControlConfig() - cacheConfig.SetConsistentCacheMaxAge(0) - cacheConfig.SetCurrentCacheMaxAge(-1) - ctx = context.WithValue(ctx, "cacheConfig", cacheConfig) - - checksumBytes := sha256.Sum256(snJSON) - checksumHex := hex.EncodeToString(checksumBytes[:]) - rw := httptest.NewRecorder() - assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, snJSON, checksumHex, noCaching, cacheConfig) - vars["checksum"] = checksumHex - rw = httptest.NewRecorder() - assert.NoError(t, getHandler(ctx, rw, req, vars)) - verifyGetResponse(t, rw, snJSON, checksumHex, noCaching, cacheConfig) + err = getHandler(ctx, rw, req, vars) + assert.NoError(t, err) } func TestGetHandler404(t *testing.T) { diff --git a/server/integration_test.go b/server/integration_test.go index e150ca020e..1b3d851aa2 100644 --- a/server/integration_test.go +++ b/server/integration_test.go @@ -24,7 +24,7 @@ func TestValidationErrorFormat(t *testing.T) { context.Background(), "metaStore", storage.NewMemStorage()) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) - handler := RootHandler(nil, ctx, signed.NewEd25519()) + handler := RootHandler(nil, ctx, signed.NewEd25519(), nil, nil) server := httptest.NewServer(handler) defer server.Close() diff --git a/server/server.go b/server/server.go index 2f7ccb580c..05ac52195e 100644 --- a/server/server.go +++ b/server/server.go @@ -31,12 +31,22 @@ func prometheusOpts(operation string) prometheus.SummaryOpts { } } +// Config tells Run how to configure a server +type Config struct { + Addr string + TLSConfig *tls.Config + Trust signed.CryptoService + AuthMethod string + AuthOpts interface{} + ConsistentCacheControlConfig utils.CacheControlConfig + CurrentCacheControlConfig utils.CacheControlConfig +} + // Run sets up and starts a TLS server that can be cancelled using the // given configuration. The context it is passed is the context it should // use directly for the TLS server, and generate children off for requests -func Run(ctx context.Context, addr string, tlsConfig *tls.Config, trust signed.CryptoService, authMethod string, authOpts interface{}) error { - - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) +func Run(ctx context.Context, conf Config) error { + tcpAddr, err := net.ResolveTCPAddr("tcp", conf.Addr) if err != nil { return err } @@ -46,29 +56,29 @@ func Run(ctx context.Context, addr string, tlsConfig *tls.Config, trust signed.C return err } - if tlsConfig != nil { + if conf.TLSConfig != nil { logrus.Info("Enabling TLS") - lsnr = tls.NewListener(lsnr, tlsConfig) + lsnr = tls.NewListener(lsnr, conf.TLSConfig) } var ac auth.AccessController - if authMethod == "token" { - authOptions, ok := authOpts.(map[string]interface{}) + if conf.AuthMethod == "token" { + authOptions, ok := conf.AuthOpts.(map[string]interface{}) if !ok { return fmt.Errorf("auth.options must be a map[string]interface{}") } - ac, err = auth.GetAccessController(authMethod, authOptions) + ac, err = auth.GetAccessController(conf.AuthMethod, authOptions) if err != nil { return err } } svr := http.Server{ - Addr: addr, - Handler: RootHandler(ac, ctx, trust), + Addr: conf.Addr, + Handler: RootHandler(ac, ctx, conf.Trust, conf.ConsistentCacheControlConfig, conf.CurrentCacheControlConfig), } - logrus.Info("Starting on ", addr) + logrus.Info("Starting on ", conf.Addr) err = svr.Serve(lsnr) @@ -77,7 +87,9 @@ func Run(ctx context.Context, addr string, tlsConfig *tls.Config, trust signed.C // RootHandler returns the handler that routes all the paths from / for the // server. -func RootHandler(ac auth.AccessController, ctx context.Context, trust signed.CryptoService) http.Handler { +func RootHandler(ac auth.AccessController, ctx context.Context, trust signed.CryptoService, + consistent, current utils.CacheControlConfig) http.Handler { + hand := utils.RootHandlerFactory(ac, ctx, trust) r := mux.NewRouter() @@ -89,11 +101,11 @@ func RootHandler(ac auth.AccessController, ctx context.Context, trust signed.Cry r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/{tufRole:root|targets(?:/[^/\\s]+)*|snapshot|timestamp}.{checksum:[a-fA-F0-9]{64}|[a-fA-F0-9]{96}|[a-fA-F0-9]{128}}.json").Handler( prometheus.InstrumentHandlerWithOpts( prometheusOpts("GetRoleByHash"), - hand(handlers.GetHandler, "pull"))) + utils.WrapWithCacheHandler(consistent, hand(handlers.GetHandler, "pull")))) r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/{tufRole:root|targets(?:/[^/\\s]+)*|snapshot|timestamp}.json").Handler( prometheus.InstrumentHandlerWithOpts( prometheusOpts("GetRole"), - hand(handlers.GetHandler, "pull"))) + utils.WrapWithCacheHandler(current, hand(handlers.GetHandler, "pull")))) r.Methods("GET").Path( "/v2/{imageName:.*}/_trust/tuf/{tufRole:snapshot|timestamp}.key").Handler( prometheus.InstrumentHandlerWithOpts( diff --git a/server/server_test.go b/server/server_test.go index 81d9724f62..a1cf45057c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "crypto/sha256" "encoding/hex" "encoding/json" @@ -16,6 +17,8 @@ import ( "github.com/docker/notary/server/storage" "github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/signed" + "github.com/docker/notary/tuf/testutils" + "github.com/docker/notary/utils" "github.com/stretchr/testify/assert" "golang.org/x/net/context" ) @@ -23,11 +26,10 @@ import ( func TestRunBadAddr(t *testing.T) { err := Run( context.Background(), - "testAddr", - nil, - signed.NewEd25519(), - "", - nil, + Config{ + Addr: "testAddr", + Trust: signed.NewEd25519(), + }, ) assert.Error(t, err, "Passed bad addr, Run should have failed") } @@ -37,11 +39,10 @@ func TestRunReservedPort(t *testing.T) { err := Run( ctx, - "localhost:80", - nil, - signed.NewEd25519(), - "", - nil, + Config{ + Addr: "localhost:80", + Trust: signed.NewEd25519(), + }, ) assert.Error(t, err) @@ -55,7 +56,8 @@ func TestRunReservedPort(t *testing.T) { } func TestMetricsEndpoint(t *testing.T) { - handler := RootHandler(nil, context.Background(), signed.NewEd25519()) + handler := RootHandler(nil, context.Background(), signed.NewEd25519(), + nil, nil) ts := httptest.NewServer(handler) defer ts.Close() @@ -70,7 +72,7 @@ func TestGetKeysEndpoint(t *testing.T) { context.Background(), "metaStore", storage.NewMemStorage()) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) - handler := RootHandler(nil, ctx, signed.NewEd25519()) + handler := RootHandler(nil, ctx, signed.NewEd25519(), nil, nil) ts := httptest.NewServer(handler) defer ts.Close() @@ -90,7 +92,7 @@ func TestGetKeysEndpoint(t *testing.T) { } } -// This just checks the URL routing is working correctly. +// This just checks the URL routing is working correctly and cache headers are set correctly. // More detailed tests for this path including negative // tests are located in /server/handlers/ func TestGetRoleByHash(t *testing.T) { @@ -99,48 +101,46 @@ func TestGetRoleByHash(t *testing.T) { ts := data.SignedTimestamp{ Signatures: make([]data.Signature, 0), Signed: data.Timestamp{ - Type: data.TUFTypes["timestamp"], + Type: data.TUFTypes[data.CanonicalTimestampRole], Version: 1, - Expires: data.DefaultExpires("timestamp"), + Expires: data.DefaultExpires(data.CanonicalTimestampRole), }, } j, err := json.Marshal(&ts) assert.NoError(t, err) - update := storage.MetaUpdate{ + store.UpdateCurrent("gun", storage.MetaUpdate{ Role: data.CanonicalTimestampRole, Version: 1, Data: j, - } + }) checksumBytes := sha256.Sum256(j) checksum := hex.EncodeToString(checksumBytes[:]) - store.UpdateCurrent("gun", update) - // create and add a newer timestamp. We're going to try and request // the older version we created above. ts = data.SignedTimestamp{ Signatures: make([]data.Signature, 0), Signed: data.Timestamp{ - Type: data.TUFTypes["timestamp"], + Type: data.TUFTypes[data.CanonicalTimestampRole], Version: 2, - Expires: data.DefaultExpires("timestamp"), + Expires: data.DefaultExpires(data.CanonicalTimestampRole), }, } - newJ, err := json.Marshal(&ts) + newTS, err := json.Marshal(&ts) assert.NoError(t, err) - update = storage.MetaUpdate{ + store.UpdateCurrent("gun", storage.MetaUpdate{ Role: data.CanonicalTimestampRole, - Version: 2, - Data: newJ, - } - store.UpdateCurrent("gun", update) + Version: 1, + Data: newTS, + }) ctx := context.WithValue( context.Background(), "metaStore", store) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) - handler := RootHandler(nil, ctx, signed.NewEd25519()) + ccc := utils.NewCacheControlConfig(10, false) + handler := RootHandler(nil, ctx, signed.NewEd25519(), ccc, ccc) serv := httptest.NewServer(handler) defer serv.Close() @@ -152,10 +152,59 @@ func TestGetRoleByHash(t *testing.T) { )) assert.NoError(t, err) assert.Equal(t, http.StatusOK, res.StatusCode) - - body, err := ioutil.ReadAll(res.Body) - assert.NoError(t, err) - defer res.Body.Close() // if content is equal, checksums are guaranteed to be equal - assert.EqualValues(t, j, body) + verifyGetResponse(t, res, j) +} + +// This just checks the URL routing is working correctly and cache headers are set correctly. +// More detailed tests for this path including negative +// tests are located in /server/handlers/ +func TestGetCurrentRole(t *testing.T) { + store := storage.NewMemStorage() + metadata, _, err := testutils.NewRepoMetadata("gun") + assert.NoError(t, err) + + // need both the snapshot and the timestamp, because when getting the current + // timestamp the server checks to see if it's out of date (there's a new snapshot) + // and if so, generates a new one + store.UpdateCurrent("gun", storage.MetaUpdate{ + Role: data.CanonicalSnapshotRole, + Version: 1, + Data: metadata[data.CanonicalSnapshotRole], + }) + store.UpdateCurrent("gun", storage.MetaUpdate{ + Role: data.CanonicalTimestampRole, + Version: 1, + Data: metadata[data.CanonicalTimestampRole], + }) + + ctx := context.WithValue( + context.Background(), "metaStore", store) + + ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) + + ccc := utils.NewCacheControlConfig(10, false) + handler := RootHandler(nil, ctx, signed.NewEd25519(), ccc, ccc) + serv := httptest.NewServer(handler) + defer serv.Close() + + res, err := http.Get(fmt.Sprintf( + "%s/v2/gun/_trust/tuf/%s.json", + serv.URL, + data.CanonicalTimestampRole, + )) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + verifyGetResponse(t, res, metadata[data.CanonicalTimestampRole]) +} + +// Verifies that the body is as expected and that there are cache control headers +func verifyGetResponse(t *testing.T, r *http.Response, expectedBytes []byte) { + body, err := ioutil.ReadAll(r.Body) + assert.NoError(t, err) + assert.True(t, bytes.Equal(expectedBytes, body)) + + assert.NotEqual(t, "", r.Header.Get("Cache-Control")) + assert.NotEqual(t, "", r.Header.Get("Last-Modified")) + assert.Equal(t, "", r.Header.Get("Pragma")) } diff --git a/utils/http.go b/utils/http.go index 5bf191fd46..d9a1357972 100644 --- a/utils/http.go +++ b/utils/http.go @@ -1,7 +1,9 @@ package utils import ( + "fmt" "net/http" + "time" "github.com/Sirupsen/logrus" ctxu "github.com/docker/distribution/context" @@ -94,3 +96,102 @@ func buildAccessRecords(repo string, actions ...string) []auth.Access { } return requiredAccess } + +// CacheControlConfig is an interface for something that knows how to set cache +// control headers +type CacheControlConfig interface { + // SetHeaders will actually set the cache control headers on a Headers object + SetHeaders(headers http.Header) +} + +// NewCacheControlConfig returns CacheControlConfig interface for either setting +// cache control or disabling cache control entirely +func NewCacheControlConfig(maxAgeInSeconds int, mustRevalidate bool) CacheControlConfig { + if maxAgeInSeconds > 0 { + return PublicCacheControl{MustReValidate: mustRevalidate, MaxAgeInSeconds: maxAgeInSeconds} + } + return NoCacheControl{} +} + +// PublicCacheControl is a set of options that we will set to enable cache control +type PublicCacheControl struct { + MustReValidate bool + MaxAgeInSeconds int +} + +// SetHeaders sets the public headers with an optional must-revalidate header +func (p PublicCacheControl) SetHeaders(headers http.Header) { + cacheControlValue := fmt.Sprintf("public, max-age=%v, s-maxage=%v", + p.MaxAgeInSeconds, p.MaxAgeInSeconds) + + if p.MustReValidate { + cacheControlValue = fmt.Sprintf("%s, must-revalidate", cacheControlValue) + } + headers.Set("Cache-Control", cacheControlValue) + // delete the Pragma directive, because the only valid value in HTTP is + // "no-cache" + headers.Del("Pragma") + if headers.Get("Last-Modified") == "" { + SetLastModifiedHeader(headers, time.Time{}) + } +} + +// NoCacheControl is an object which represents a directive to cache nothing +type NoCacheControl struct{} + +// SetHeaders sets the public headers cache-control headers and pragma to no-cache +func (n NoCacheControl) SetHeaders(headers http.Header) { + headers.Set("Cache-Control", "max-age=0, no-cache, no-store") + headers.Set("Pragma", "no-cache") +} + +// cacheControlResponseWriter wraps an existing response writer, and if Write is +// called, will try to set the cache control headers if it can +type cacheControlResponseWriter struct { + http.ResponseWriter + config CacheControlConfig + statusCode int +} + +// WriteHeader stores the header before writing it, so we can tell if it's been set +// to a non-200 status code +func (c *cacheControlResponseWriter) WriteHeader(statusCode int) { + c.statusCode = statusCode + c.ResponseWriter.WriteHeader(statusCode) +} + +// Write will set the cache headers if they haven't already been set and if the status +// code has either not been set or set to 200 +func (c *cacheControlResponseWriter) Write(data []byte) (int, error) { + if c.statusCode == http.StatusOK || c.statusCode == 0 { + headers := c.ResponseWriter.Header() + if headers.Get("Cache-Control") == "" { + c.config.SetHeaders(headers) + } + } + return c.ResponseWriter.Write(data) +} + +type cacheControlHandler struct { + http.Handler + config CacheControlConfig +} + +func (c cacheControlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.Handler.ServeHTTP(&cacheControlResponseWriter{ResponseWriter: w, config: c.config}, r) +} + +// WrapWithCacheHandler wraps another handler in one that can add cache control headers +// given a 200 response +func WrapWithCacheHandler(ccc CacheControlConfig, handler http.Handler) http.Handler { + if ccc != nil { + return cacheControlHandler{Handler: handler, config: ccc} + } + return handler +} + +// SetLastModifiedHeader takes a time and uses it to set the LastModified header using +// the right date format +func SetLastModifiedHeader(headers http.Header, lmt time.Time) { + headers.Set("Last-Modified", lmt.Format(time.RFC1123)) +} diff --git a/utils/http_test.go b/utils/http_test.go index 62d991cbde..8a20e4c85b 100644 --- a/utils/http_test.go +++ b/utils/http_test.go @@ -1,14 +1,18 @@ package utils import ( + "bytes" "io/ioutil" "net/http" "net/http/httptest" + "net/url" "strings" "testing" + "time" "github.com/docker/distribution/registry/api/errcode" "github.com/docker/notary/tuf/signed" + "github.com/stretchr/testify/assert" "golang.org/x/net/context" ) @@ -39,22 +43,6 @@ func TestRootHandlerFactory(t *testing.T) { } } -//func TestRootHandlerUnauthorized(t *testing.T) { -// hand := RootHandlerFactory(nil, context.Background(), &signed.Ed25519{}) -// handler := hand(MockContextHandler) -// -// ts := httptest.NewServer(handler) -// defer ts.Close() -// -// res, err := http.Get(ts.URL) -// if err != nil { -// t.Fatal(err) -// } -// if res.StatusCode != http.StatusUnauthorized { -// t.Fatalf("Expected 401, received %d", res.StatusCode) -// } -//} - func TestRootHandlerError(t *testing.T) { hand := RootHandlerFactory(nil, context.Background(), &signed.Ed25519{}) handler := hand(MockBetterErrorHandler) @@ -75,3 +63,192 @@ func TestRootHandlerError(t *testing.T) { t.Fatalf("Error Body Incorrect: `%s`", content) } } + +// If no CacheControlConfig is passed, wrapping the handler just returns the handler +func TestWrapWithCacheHeaderNilCacheControlConfig(t *testing.T) { + mux := http.NewServeMux() + wrapped := WrapWithCacheHandler(nil, mux) + assert.Equal(t, mux, wrapped) +} + +// If the wrapped handler returns a non-200, no matter which CacheControlConfig is +// used, the Cache-Control header not set. +func TestWrapWithCacheHeaderNon200Response(t *testing.T) { + mux := http.NewServeMux() + configs := []CacheControlConfig{NewCacheControlConfig(10, true), NewCacheControlConfig(0, true)} + + for _, conf := range configs { + req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))} + + wrapped := WrapWithCacheHandler(conf, mux) + assert.NotEqual(t, mux, wrapped) + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + assert.Equal(t, "", rw.HeaderMap.Get("Cache-Control")) + assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified")) + assert.Equal(t, "", rw.HeaderMap.Get("Pragma")) + } +} + +// If the wrapped handler writes no cache headers whatsoever, and a PublicCacheControl +// is used, the Cache-Control header is set with the given maxAge and re-validate value. +// The Last-Modified header is also set to the beginning of (computer) time. If a +// Pragma header is written is deleted +func TestWrapWithCacheHeaderPublicCacheControlNoCacheHeaders(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello!")) + }) + mux.HandleFunc("/a", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Pragma", "no-cache") + w.Write([]byte("hello!")) + }) + + for _, path := range []string{"/", "/a"} { + req := &http.Request{URL: &url.URL{Path: path}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))} + + // must-revalidate is set if revalidate is set to true, and not if revalidate is set to false + for _, revalidate := range []bool{true, false} { + wrapped := WrapWithCacheHandler(NewCacheControlConfig(10, revalidate), mux) + assert.NotEqual(t, mux, wrapped) + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + cacheControl := "public, max-age=10, s-maxage=10" + if revalidate { + cacheControl = cacheControl + ", must-revalidate" + } + assert.Equal(t, cacheControl, rw.HeaderMap.Get("Cache-Control")) + + lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) + assert.NoError(t, err) + assert.True(t, lastModified.Equal(time.Time{})) + assert.Equal(t, "", rw.HeaderMap.Get("Pragma")) + } + } +} + +// If the wrapped handler writes a last modified header, and a PublicCacheControl +// is used, the Cache-Control header is set with the given maxAge and re-validate value. +// The Last-Modified header is not replaced. The Pragma header is deleted though. +func TestWrapWithCacheHeaderPublicCacheControlLastModifiedHeader(t *testing.T) { + now := time.Now() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + SetLastModifiedHeader(w.Header(), now) + w.Header().Set("Pragma", "no-cache") + w.Write([]byte("hello!")) + }) + req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))} + + wrapped := WrapWithCacheHandler(NewCacheControlConfig(10, true), mux) + assert.NotEqual(t, mux, wrapped) + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + assert.Equal(t, "public, max-age=10, s-maxage=10, must-revalidate", rw.HeaderMap.Get("Cache-Control")) + lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) + assert.NoError(t, err) + // RFC1123 does not include nanoseconds + nowToNearestSecond := now.Add(time.Duration(-1 * now.Nanosecond())) + assert.True(t, lastModified.Equal(nowToNearestSecond)) + assert.Equal(t, "", rw.HeaderMap.Get("Pragma")) +} + +// If the wrapped handler writes a Cache-Control header, even if the last modified +// header is not written, then the Cache-Control header is not written, nor is a +// Last-Modified header written. The Pragma header is not deleted. +func TestWrapWithCacheHeaderPublicCacheControlCacheControlHeader(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "some invalid cache control value") + w.Header().Set("Pragma", "invalid value") + w.Write([]byte("hello!")) + }) + req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))} + + wrapped := WrapWithCacheHandler(NewCacheControlConfig(10, true), mux) + assert.NotEqual(t, mux, wrapped) + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + assert.Equal(t, "some invalid cache control value", rw.HeaderMap.Get("Cache-Control")) + assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified")) + assert.Equal(t, "invalid value", rw.HeaderMap.Get("Pragma")) +} + +// If the wrapped handler writes no cache headers whatsoever, and NoCacheControl +// is used, the Cache-Control and Pragma headers are set with no-cache. +func TestWrapWithCacheHeaderNoCacheControlNoCacheHeaders(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Pragma", "invalid value") + w.Write([]byte("hello!")) + }) + req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))} + + wrapped := WrapWithCacheHandler(NewCacheControlConfig(0, false), mux) + assert.NotEqual(t, mux, wrapped) + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + assert.Equal(t, "max-age=0, no-cache, no-store", rw.HeaderMap.Get("Cache-Control")) + assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified")) + assert.Equal(t, "no-cache", rw.HeaderMap.Get("Pragma")) +} + +// If the wrapped handler writes a last modified header, and NoCacheControl +// is used, the Cache-Control and Pragma headers are set with no-cache without +// messing with the Last-Modified header. +func TestWrapWithCacheHeaderNoCacheControlLastModifiedHeader(t *testing.T) { + now := time.Now() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + SetLastModifiedHeader(w.Header(), now) + w.Write([]byte("hello!")) + }) + req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))} + + wrapped := WrapWithCacheHandler(NewCacheControlConfig(0, true), mux) + assert.NotEqual(t, mux, wrapped) + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + assert.Equal(t, "max-age=0, no-cache, no-store", rw.HeaderMap.Get("Cache-Control")) + assert.Equal(t, "no-cache", rw.HeaderMap.Get("Pragma")) + + lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) + assert.NoError(t, err) + // RFC1123 does not include nanoseconds + nowToNearestSecond := now.Add(time.Duration(-1 * now.Nanosecond())) + assert.True(t, lastModified.Equal(nowToNearestSecond)) +} + +// If the wrapped handler writes a Cache-Control header, even if the last modified +// header is not written, then the Cache-Control header is not written, nor is a +// Pragma added. The Last-Modified header is untouched. +func TestWrapWithCacheHeaderNoCacheControlCacheControlHeader(t *testing.T) { + now := time.Now() + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "some invalid cache control value") + SetLastModifiedHeader(w.Header(), now) + w.Write([]byte("hello!")) + }) + req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))} + + wrapped := WrapWithCacheHandler(NewCacheControlConfig(0, true), mux) + assert.NotEqual(t, mux, wrapped) + rw := httptest.NewRecorder() + wrapped.ServeHTTP(rw, req) + + assert.Equal(t, "some invalid cache control value", rw.HeaderMap.Get("Cache-Control")) + assert.Equal(t, "", rw.HeaderMap.Get("Pragma")) + + lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified")) + assert.NoError(t, err) + // RFC1123 does not include nanoseconds + nowToNearestSecond := now.Add(time.Duration(-1 * now.Nanosecond())) + assert.True(t, lastModified.Equal(nowToNearestSecond)) +}