diff --git a/cmd/notary-server/main.go b/cmd/notary-server/main.go index 46635356b8..ecc54678df 100644 --- a/cmd/notary-server/main.go +++ b/cmd/notary-server/main.go @@ -8,6 +8,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "strconv" "time" "github.com/Sirupsen/logrus" @@ -23,6 +24,7 @@ import ( "github.com/docker/go-connections/tlsconfig" "github.com/docker/notary/server" + "github.com/docker/notary/server/handlers" "github.com/docker/notary/utils" "github.com/docker/notary/version" "github.com/spf13/viper" @@ -32,6 +34,7 @@ import ( const ( jsonLogFormat = "json" DebugAddress = "localhost:8080" + maxMaxAge = 31536000 ) var ( @@ -170,6 +173,29 @@ func getTrustService(configuration *viper.Viper, sFactory signerFactory, return notarySigner, keyAlgo, nil } +func getCacheConfig(configuration *viper.Viper) (*handlers.CacheControlConfig, error) { + cacheConfig := handlers.NewCacheControlConfig() + for _, option := range []string{"current_metadata", "metadata_by_checksum"} { + m := configuration.GetString(fmt.Sprintf("caching.max_age.%s", option)) + if m == "" { + continue + } + seconds, err := strconv.Atoi(m) + if err != nil || seconds < 0 || seconds > maxMaxAge { + return nil, fmt.Errorf( + "must specify a cache-control max-age between 0 and %v", maxMaxAge) + } + + switch option { + case "current_metadata": + cacheConfig.SetCurrentCacheMaxAge(seconds) + default: + cacheConfig.SetConsistentCacheMaxAge(seconds) + } + } + return cacheConfig, nil +} + func main() { flag.Usage = usage flag.Parse() @@ -215,6 +241,12 @@ func main() { } ctx = context.WithValue(ctx, "metaStore", store) + cacheConfig, err := getCacheConfig(mainViper) + if err != nil { + logrus.Fatal(err.Error()) + } + ctx = context.WithValue(ctx, "cacheConfig", cacheConfig) + httpAddr, tlsConfig, err := getAddrAndTLSConfig(mainViper) if err != nil { logrus.Fatal(err.Error()) diff --git a/cmd/notary-server/main_test.go b/cmd/notary-server/main_test.go index ee8450ad76..be23f35ab7 100644 --- a/cmd/notary-server/main_test.go +++ b/cmd/notary-server/main_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "io/ioutil" + "net/http" "os" "reflect" "strings" @@ -328,3 +329,25 @@ func TestGetMemoryStore(t *testing.T) { _, ok := store.(*storage.MemStorage) assert.True(t, ok) } + +func TestGetCacheConfig(t *testing.T) { + valid := `{"caching": {"max_age": {"current_metadata": 0, "metadata_by_checksum": 31536000}}}` + invalids := []string{ + `{"caching": {"max_age": {"current_metadata": 0, "metadata_by_checksum": 31539000}}}`, + `{"caching": {"max_age": {"current_metadata": -1, "metadata_by_checksum": 300}}}`, + `{"caching": {"max_age": {"current_metadata": "hello", "metadata_by_checksum": 300}}}`, + } + + cacheConfig, err := getCacheConfig(configure(valid)) + assert.NoError(t, err) + h := http.Header{} + cacheConfig.UpdateCurrentHeaders(h, time.Now()) + assert.True(t, strings.Contains(h.Get("Cache-Control"), "max-age=0")) + cacheConfig.UpdateConsistentHeaders(h, time.Now()) + assert.True(t, strings.Contains(h.Get("Cache-Control"), "max-age=31536000")) + + for _, invalid := range invalids { + _, err := getCacheConfig(configure(invalid)) + assert.Error(t, err) + } +}