Merge pull request #601 from docker/cache-headers

Return cache control headers when returning metadata from server
This commit is contained in:
David Lawrence 2016-03-14 17:56:59 -07:00
commit c74fab9401
26 changed files with 774 additions and 310 deletions

View File

@ -141,7 +141,7 @@ func fullTestServer(t *testing.T) *httptest.Server {
cryptoService := cryptoservice.NewCryptoService( cryptoService := cryptoservice.NewCryptoService(
"", trustmanager.NewKeyMemoryStore(passphraseRetriever)) "", trustmanager.NewKeyMemoryStore(passphraseRetriever))
return httptest.NewServer(server.RootHandler(nil, ctx, cryptoService)) return httptest.NewServer(server.RootHandler(nil, ctx, cryptoService, nil, nil))
} }
// server that returns some particular error code all the time // server that returns some particular error code all the time

163
cmd/notary-server/config.go Normal file
View File

@ -0,0 +1,163 @@
package main
import (
"crypto/tls"
"fmt"
"strconv"
"time"
"github.com/Sirupsen/logrus"
"github.com/docker/distribution/health"
_ "github.com/docker/distribution/registry/auth/htpasswd"
_ "github.com/docker/distribution/registry/auth/token"
"github.com/docker/go-connections/tlsconfig"
"github.com/docker/notary/server/storage"
"github.com/docker/notary/signer/client"
"github.com/docker/notary/tuf/data"
"github.com/docker/notary/tuf/signed"
"github.com/docker/notary/utils"
_ "github.com/go-sql-driver/mysql"
"github.com/spf13/viper"
)
// get the address for the HTTP server, and parses the optional TLS
// configuration for the server - if no TLS configuration is specified,
// TLS is not enabled.
func getAddrAndTLSConfig(configuration *viper.Viper) (string, *tls.Config, error) {
httpAddr := configuration.GetString("server.http_addr")
if httpAddr == "" {
return "", nil, fmt.Errorf("http listen address required for server")
}
tlsConfig, err := utils.ParseServerTLS(configuration, false)
if err != nil {
return "", nil, fmt.Errorf(err.Error())
}
return httpAddr, tlsConfig, nil
}
// sets up TLS for the GRPC connection to notary-signer
func grpcTLS(configuration *viper.Viper) (*tls.Config, error) {
rootCA := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_ca_file")
clientCert := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_cert")
clientKey := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_key")
if clientCert == "" && clientKey != "" || clientCert != "" && clientKey == "" {
return nil, fmt.Errorf("either pass both client key and cert, or neither")
}
tlsConfig, err := tlsconfig.Client(tlsconfig.Options{
CAFile: rootCA,
CertFile: clientCert,
KeyFile: clientKey,
})
if err != nil {
return nil, fmt.Errorf(
"Unable to configure TLS to the trust service: %s", err.Error())
}
return tlsConfig, nil
}
// parses the configuration and returns a backing store for the TUF files
func getStore(configuration *viper.Viper, allowedBackends []string) (
storage.MetaStore, error) {
storeConfig, err := utils.ParseStorage(configuration, allowedBackends)
if err != nil {
return nil, err
}
logrus.Infof("Using %s backend", storeConfig.Backend)
if storeConfig.Backend == utils.MemoryBackend {
return storage.NewMemStorage(), nil
}
store, err := storage.NewSQLStorage(storeConfig.Backend, storeConfig.Source)
if err != nil {
return nil, fmt.Errorf("Error starting DB driver: %s", err.Error())
}
health.RegisterPeriodicFunc(
"DB operational", store.CheckHealth, time.Second*60)
return store, nil
}
type signerFactory func(hostname, port string, tlsConfig *tls.Config) *client.NotarySigner
type healthRegister func(name string, checkFunc func() error, duration time.Duration)
// parses the configuration and determines which trust service and key algorithm
// to return
func getTrustService(configuration *viper.Viper, sFactory signerFactory,
hRegister healthRegister) (signed.CryptoService, string, error) {
switch configuration.GetString("trust_service.type") {
case "local":
logrus.Info("Using local signing service, which requires ED25519. " +
"Ignoring all other trust_service parameters, including keyAlgorithm")
return signed.NewEd25519(), data.ED25519Key, nil
case "remote":
default:
return nil, "", fmt.Errorf(
"must specify either a \"local\" or \"remote\" type for trust_service")
}
keyAlgo := configuration.GetString("trust_service.key_algorithm")
if keyAlgo != data.ED25519Key && keyAlgo != data.ECDSAKey && keyAlgo != data.RSAKey {
return nil, "", fmt.Errorf("invalid key algorithm configured: %s", keyAlgo)
}
clientTLS, err := grpcTLS(configuration)
if err != nil {
return nil, "", err
}
logrus.Info("Using remote signing service")
notarySigner := sFactory(
configuration.GetString("trust_service.hostname"),
configuration.GetString("trust_service.port"),
clientTLS,
)
minute := 1 * time.Minute
hRegister(
"Trust operational",
// If the trust service fails, the server is degraded but not
// exactly unhealthy, so always return healthy and just log an
// error.
func() error {
err := notarySigner.CheckHealth(minute)
if err != nil {
logrus.Error("Trust not fully operational: ", err.Error())
}
return nil
},
minute)
return notarySigner, keyAlgo, nil
}
// Gets the cache configuration for GET-ting current and checksummed metadata
// This is mainly the max-age (an integer in seconds, just like in the
// Cache-Control header) for consistent (content-addressable) downloads and
// current (latest version) downloads. The max-age must be between 0 and 31536000
// (one year in seconds, which is the recommended maximum time data is cached),
// else parsing will return an error. A max-age of 0 will disable caching for
// that type of download (consistent or current).
func getCacheConfig(configuration *viper.Viper) (utils.CacheControlConfig, utils.CacheControlConfig, error) {
var cccs []utils.CacheControlConfig
types := []string{"current_metadata", "metadata_by_checksum"}
for _, optionName := range types {
m := configuration.GetString(fmt.Sprintf("caching.max_age.%s", optionName))
if m == "" {
continue
}
seconds, err := strconv.Atoi(m)
if err != nil || seconds < 0 || seconds > maxMaxAge {
return nil, nil, fmt.Errorf(
"must specify a cache-control max-age between 0 and %v", maxMaxAge)
}
cccs = append(cccs, utils.NewCacheControlConfig(seconds, optionName == "current_metadata"))
}
return cccs[0], cccs[1], nil
}

View File

@ -1,27 +1,18 @@
package main package main
import ( import (
"crypto/tls"
_ "expvar" _ "expvar"
"flag" "flag"
"fmt" "fmt"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"time"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
"github.com/docker/distribution/health" "github.com/docker/distribution/health"
_ "github.com/docker/distribution/registry/auth/htpasswd"
_ "github.com/docker/distribution/registry/auth/token"
"github.com/docker/notary/server/storage"
"github.com/docker/notary/signer/client" "github.com/docker/notary/signer/client"
"github.com/docker/notary/tuf/data"
"github.com/docker/notary/tuf/signed"
_ "github.com/go-sql-driver/mysql"
"golang.org/x/net/context" "golang.org/x/net/context"
"github.com/docker/go-connections/tlsconfig"
"github.com/docker/notary/server" "github.com/docker/notary/server"
"github.com/docker/notary/utils" "github.com/docker/notary/utils"
"github.com/docker/notary/version" "github.com/docker/notary/version"
@ -32,6 +23,10 @@ import (
const ( const (
jsonLogFormat = "json" jsonLogFormat = "json"
DebugAddress = "localhost:8080" DebugAddress = "localhost:8080"
// This is the generally recommended maximum age for Cache-Control headers
// (one year, in seconds, since one year is forever in terms of internet
// content)
maxMaxAge = 60 * 60 * 24 * 365
) )
var ( var (
@ -55,121 +50,6 @@ func init() {
} }
} }
// get the address for the HTTP server, and parses the optional TLS
// configuration for the server - if no TLS configuration is specified,
// TLS is not enabled.
func getAddrAndTLSConfig(configuration *viper.Viper) (string, *tls.Config, error) {
httpAddr := configuration.GetString("server.http_addr")
if httpAddr == "" {
return "", nil, fmt.Errorf("http listen address required for server")
}
tlsConfig, err := utils.ParseServerTLS(configuration, false)
if err != nil {
return "", nil, fmt.Errorf(err.Error())
}
return httpAddr, tlsConfig, nil
}
// sets up TLS for the GRPC connection to notary-signer
func grpcTLS(configuration *viper.Viper) (*tls.Config, error) {
rootCA := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_ca_file")
clientCert := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_cert")
clientKey := utils.GetPathRelativeToConfig(configuration, "trust_service.tls_client_key")
if clientCert == "" && clientKey != "" || clientCert != "" && clientKey == "" {
return nil, fmt.Errorf("either pass both client key and cert, or neither")
}
tlsConfig, err := tlsconfig.Client(tlsconfig.Options{
CAFile: rootCA,
CertFile: clientCert,
KeyFile: clientKey,
})
if err != nil {
return nil, fmt.Errorf(
"Unable to configure TLS to the trust service: %s", err.Error())
}
return tlsConfig, nil
}
// parses the configuration and returns a backing store for the TUF files
func getStore(configuration *viper.Viper, allowedBackends []string) (
storage.MetaStore, error) {
storeConfig, err := utils.ParseStorage(configuration, allowedBackends)
if err != nil {
return nil, err
}
logrus.Infof("Using %s backend", storeConfig.Backend)
if storeConfig.Backend == utils.MemoryBackend {
return storage.NewMemStorage(), nil
}
store, err := storage.NewSQLStorage(storeConfig.Backend, storeConfig.Source)
if err != nil {
return nil, fmt.Errorf("Error starting DB driver: %s", err.Error())
}
health.RegisterPeriodicFunc(
"DB operational", store.CheckHealth, time.Second*60)
return store, nil
}
type signerFactory func(hostname, port string, tlsConfig *tls.Config) *client.NotarySigner
type healthRegister func(name string, checkFunc func() error, duration time.Duration)
// parses the configuration and determines which trust service and key algorithm
// to return
func getTrustService(configuration *viper.Viper, sFactory signerFactory,
hRegister healthRegister) (signed.CryptoService, string, error) {
switch configuration.GetString("trust_service.type") {
case "local":
logrus.Info("Using local signing service, which requires ED25519. " +
"Ignoring all other trust_service parameters, including keyAlgorithm")
return signed.NewEd25519(), data.ED25519Key, nil
case "remote":
default:
return nil, "", fmt.Errorf(
"must specify either a \"local\" or \"remote\" type for trust_service")
}
keyAlgo := configuration.GetString("trust_service.key_algorithm")
if keyAlgo != data.ED25519Key && keyAlgo != data.ECDSAKey && keyAlgo != data.RSAKey {
return nil, "", fmt.Errorf("invalid key algorithm configured: %s", keyAlgo)
}
clientTLS, err := grpcTLS(configuration)
if err != nil {
return nil, "", err
}
logrus.Info("Using remote signing service")
notarySigner := sFactory(
configuration.GetString("trust_service.hostname"),
configuration.GetString("trust_service.port"),
clientTLS,
)
minute := 1 * time.Minute
hRegister(
"Trust operational",
// If the trust service fails, the server is degraded but not
// exactly unheatlthy, so always return healthy and just log an
// error.
func() error {
err := notarySigner.CheckHealth(minute)
if err != nil {
logrus.Error("Trust not fully operational: ", err.Error())
}
return nil
},
minute)
return notarySigner, keyAlgo, nil
}
func main() { func main() {
flag.Usage = usage flag.Usage = usage
flag.Parse() flag.Parse()
@ -215,6 +95,11 @@ func main() {
} }
ctx = context.WithValue(ctx, "metaStore", store) ctx = context.WithValue(ctx, "metaStore", store)
currentCache, consistentCache, err := getCacheConfig(mainViper)
if err != nil {
logrus.Fatal(err.Error())
}
httpAddr, tlsConfig, err := getAddrAndTLSConfig(mainViper) httpAddr, tlsConfig, err := getAddrAndTLSConfig(mainViper)
if err != nil { if err != nil {
logrus.Fatal(err.Error()) logrus.Fatal(err.Error())
@ -223,11 +108,15 @@ func main() {
logrus.Info("Starting Server") logrus.Info("Starting Server")
err = server.Run( err = server.Run(
ctx, ctx,
httpAddr, server.Config{
tlsConfig, Addr: httpAddr,
trust, TLSConfig: tlsConfig,
mainViper.GetString("auth.type"), Trust: trust,
mainViper.Get("auth.options"), AuthMethod: mainViper.GetString("auth.type"),
AuthOpts: mainViper.Get("auth.options"),
CurrentCacheControlConfig: currentCache,
ConsistentCacheControlConfig: consistentCache,
},
) )
logrus.Error(err.Error()) logrus.Error(err.Error())

View File

@ -328,3 +328,22 @@ func TestGetMemoryStore(t *testing.T) {
_, ok := store.(*storage.MemStorage) _, ok := store.(*storage.MemStorage)
assert.True(t, ok) assert.True(t, ok)
} }
func TestGetCacheConfig(t *testing.T) {
valid := `{"caching": {"max_age": {"current_metadata": 0, "metadata_by_checksum": 31536000}}}`
invalids := []string{
`{"caching": {"max_age": {"current_metadata": 0, "metadata_by_checksum": 31539000}}}`,
`{"caching": {"max_age": {"current_metadata": -1, "metadata_by_checksum": 300}}}`,
`{"caching": {"max_age": {"current_metadata": "hello", "metadata_by_checksum": 300}}}`,
}
current, consistent, err := getCacheConfig(configure(valid))
assert.NoError(t, err)
assert.IsType(t, utils.NoCacheControl{}, current)
assert.IsType(t, utils.PublicCacheControl{}, consistent)
for _, invalid := range invalids {
_, _, err := getCacheConfig(configure(invalid))
assert.Error(t, err)
}
}

View File

@ -71,7 +71,7 @@ func setupServerHandler(metaStore storage.MetaStore) http.Handler {
cryptoService := cryptoservice.NewCryptoService( cryptoService := cryptoservice.NewCryptoService(
"", trustmanager.NewKeyMemoryStore(passphrase.ConstantRetriever("pass"))) "", trustmanager.NewKeyMemoryStore(passphrase.ConstantRetriever("pass")))
return server.RootHandler(nil, ctx, cryptoService) return server.RootHandler(nil, ctx, cryptoService, nil, nil)
} }
// makes a testing notary-server // makes a testing notary-server

View File

@ -5,6 +5,7 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -20,7 +21,6 @@ import (
"github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/data"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"io/ioutil"
) )
var cmdKeyTemplate = usageTemplate{ var cmdKeyTemplate = usageTemplate{

View File

@ -10,6 +10,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"time"
"github.com/docker/go-connections/tlsconfig" "github.com/docker/go-connections/tlsconfig"
"github.com/docker/notary/passphrase" "github.com/docker/notary/passphrase"
@ -195,7 +196,7 @@ func TestBareCommandPrintsUsageAndNoError(t *testing.T) {
cmd := NewNotaryCommand() cmd := NewNotaryCommand()
cmd.SetOutput(b) cmd.SetOutput(b)
cmd.SetArgs([]string{"-c", filepath.Join(tempdir, "idonotexist.json"), bareCommand}) cmd.SetArgs([]string{"-c", filepath.Join(tempdir, "idonotexist.json"), "-d", tempdir, bareCommand})
require.NoError(t, cmd.Execute(), "Expected no error from a help request") require.NoError(t, cmd.Execute(), "Expected no error from a help request")
// usage is printed // usage is printed
require.Contains(t, b.String(), "Usage:", "expected usage when running `notary %s`", bareCommand) require.Contains(t, b.String(), "Usage:", "expected usage when running `notary %s`", bareCommand)
@ -209,14 +210,14 @@ type recordingMetaStore struct {
// GetCurrent gets the metadata from the underlying MetaStore, but also records // GetCurrent gets the metadata from the underlying MetaStore, but also records
// that the metadata was requested // that the metadata was requested
func (r *recordingMetaStore) GetCurrent(gun, role string) (data []byte, err error) { func (r *recordingMetaStore) GetCurrent(gun, role string) (*time.Time, []byte, error) {
r.gotten = append(r.gotten, fmt.Sprintf("%s.%s", gun, role)) r.gotten = append(r.gotten, fmt.Sprintf("%s.%s", gun, role))
return r.MemStorage.GetCurrent(gun, role) return r.MemStorage.GetCurrent(gun, role)
} }
// GetChecksum gets the metadata from the underlying MetaStore, but also records // GetChecksum gets the metadata from the underlying MetaStore, but also records
// that the metadata was requested // that the metadata was requested
func (r *recordingMetaStore) GetChecksum(gun, role, checksum string) (data []byte, err error) { func (r *recordingMetaStore) GetChecksum(gun, role, checksum string) (*time.Time, []byte, error) {
r.gotten = append(r.gotten, fmt.Sprintf("%s.%s", gun, role)) r.gotten = append(r.gotten, fmt.Sprintf("%s.%s", gun, role))
return r.MemStorage.GetChecksum(gun, role, checksum) return r.MemStorage.GetChecksum(gun, role, checksum)
} }
@ -255,7 +256,7 @@ func TestConfigFileTLSCannotBeRelativeToCWD(t *testing.T) {
// set a config file, so it doesn't check ~/.notary/config.json by default, // set a config file, so it doesn't check ~/.notary/config.json by default,
// and execute a random command so that the flags are parsed // and execute a random command so that the flags are parsed
cmd := NewNotaryCommand() cmd := NewNotaryCommand()
cmd.SetArgs([]string{"-c", configFile, "list", "repo"}) cmd.SetArgs([]string{"-c", configFile, "-d", tempDir, "list", "repo"})
cmd.SetOutput(new(bytes.Buffer)) // eat the output cmd.SetOutput(new(bytes.Buffer)) // eat the output
err = cmd.Execute() err = cmd.Execute()
assert.Error(t, err, "expected a failure due to TLS") assert.Error(t, err, "expected a failure due to TLS")
@ -309,7 +310,7 @@ func TestConfigFileTLSCanBeRelativeToConfigOrAbsolute(t *testing.T) {
// set a config file, so it doesn't check ~/.notary/config.json by default, // set a config file, so it doesn't check ~/.notary/config.json by default,
// and execute a random command so that the flags are parsed // and execute a random command so that the flags are parsed
cmd := NewNotaryCommand() cmd := NewNotaryCommand()
cmd.SetArgs([]string{"-c", configFile.Name(), "list", "repo"}) cmd.SetArgs([]string{"-c", configFile.Name(), "-d", tempDir, "list", "repo"})
cmd.SetOutput(new(bytes.Buffer)) // eat the output cmd.SetOutput(new(bytes.Buffer)) // eat the output
err = cmd.Execute() err = cmd.Execute()
assert.Error(t, err, "there was no repository, so list should have failed") assert.Error(t, err, "there was no repository, so list should have failed")
@ -356,7 +357,7 @@ func TestConfigFileOverridenByCmdLineFlags(t *testing.T) {
cmd := NewNotaryCommand() cmd := NewNotaryCommand()
cmd.SetArgs([]string{ cmd.SetArgs([]string{
"-c", configFile, "list", "repo", "-c", configFile, "-d", tempDir, "list", "repo",
"--tlscacert", "../../fixtures/root-ca.crt", "--tlscacert", "../../fixtures/root-ca.crt",
"--tlscert", filepath.Clean(filepath.Join(cwd, "../../fixtures/notary-server.crt")), "--tlscert", filepath.Clean(filepath.Join(cwd, "../../fixtures/notary-server.crt")),
"--tlskey", "../../fixtures/notary-server.key"}) "--tlskey", "../../fixtures/notary-server.key"})

View File

@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/Sirupsen/logrus"
ctxu "github.com/docker/distribution/context" ctxu "github.com/docker/distribution/context"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -18,6 +19,7 @@ import (
"github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/data"
"github.com/docker/notary/tuf/signed" "github.com/docker/notary/tuf/signed"
"github.com/docker/notary/tuf/validation" "github.com/docker/notary/tuf/validation"
"github.com/docker/notary/utils"
) )
// MainHandler is the default handler for the server // MainHandler is the default handler for the server
@ -117,12 +119,28 @@ func getHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, var
checksum := vars["checksum"] checksum := vars["checksum"]
tufRole := vars["tufRole"] tufRole := vars["tufRole"]
s := ctx.Value("metaStore") s := ctx.Value("metaStore")
store, ok := s.(storage.MetaStore) store, ok := s.(storage.MetaStore)
if !ok { if !ok {
return errors.ErrNoStorage.WithDetail(nil) return errors.ErrNoStorage.WithDetail(nil)
} }
return getRole(ctx, w, store, gun, tufRole, checksum) lastModified, output, err := getRole(ctx, store, gun, tufRole, checksum)
if err != nil {
return err
}
if lastModified != nil {
// This shouldn't always be true, but in case it is nil, and the last modified headers
// are not set, the cache control handler should set the last modified date to the beginning
// of time.
utils.SetLastModifiedHeader(w.Header(), *lastModified)
} else {
logrus.Warnf("Got bytes out for %s's %s (checksum: %s), but missing lastModified date",
gun, tufRole, checksum)
}
w.Write(output)
return nil
} }
// DeleteHandler deletes all data for a GUN. A 200 responses indicates success. // DeleteHandler deletes all data for a GUN. A 200 responses indicates success.

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -354,8 +355,8 @@ type failStore struct {
storage.MemStorage storage.MemStorage
} }
func (s *failStore) GetCurrent(_, _ string) ([]byte, error) { func (s *failStore) GetCurrent(_, _ string) (*time.Time, []byte, error) {
return nil, fmt.Errorf("oh no! storage has failed") return nil, nil, fmt.Errorf("oh no! storage has failed")
} }
// a non-validation failure, such as the storage failing, will not be propagated // a non-validation failure, such as the storage failing, will not be propagated

View File

@ -1,7 +1,7 @@
package handlers package handlers
import ( import (
"io" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -13,8 +13,9 @@ import (
"github.com/docker/notary/tuf/signed" "github.com/docker/notary/tuf/signed"
) )
func getRole(ctx context.Context, w io.Writer, store storage.MetaStore, gun, role, checksum string) error { func getRole(ctx context.Context, store storage.MetaStore, gun, role, checksum string) (*time.Time, []byte, error) {
var ( var (
lastModified *time.Time
out []byte out []byte
err error err error
) )
@ -23,25 +24,24 @@ func getRole(ctx context.Context, w io.Writer, store storage.MetaStore, gun, rol
// handled specially // handled specially
switch role { switch role {
case data.CanonicalTimestampRole, data.CanonicalSnapshotRole: case data.CanonicalTimestampRole, data.CanonicalSnapshotRole:
return getMaybeServerSigned(ctx, w, store, gun, role) return getMaybeServerSigned(ctx, store, gun, role)
} }
out, err = store.GetCurrent(gun, role) lastModified, out, err = store.GetCurrent(gun, role)
} else { } else {
out, err = store.GetChecksum(gun, role, checksum) lastModified, out, err = store.GetChecksum(gun, role, checksum)
} }
if err != nil { if err != nil {
if _, ok := err.(storage.ErrNotFound); ok { if _, ok := err.(storage.ErrNotFound); ok {
return errors.ErrMetadataNotFound.WithDetail(err) return nil, nil, errors.ErrMetadataNotFound.WithDetail(err)
} }
return errors.ErrUnknown.WithDetail(err) return nil, nil, errors.ErrUnknown.WithDetail(err)
} }
if out == nil { if out == nil {
return errors.ErrMetadataNotFound.WithDetail(nil) return nil, nil, errors.ErrMetadataNotFound.WithDetail(nil)
} }
w.Write(out)
return nil return lastModified, out, nil
} }
// getMaybeServerSigned writes the current snapshot or timestamp (based on the // getMaybeServerSigned writes the current snapshot or timestamp (based on the
@ -49,32 +49,32 @@ func getRole(ctx context.Context, w io.Writer, store storage.MetaStore, gun, rol
// the timestamp and snapshot, based on the keys held by the server, a new one // the timestamp and snapshot, based on the keys held by the server, a new one
// might be generated and signed due to expiry of the previous one or updates // might be generated and signed due to expiry of the previous one or updates
// to other roles. // to other roles.
func getMaybeServerSigned(ctx context.Context, w io.Writer, store storage.MetaStore, gun, role string) error { func getMaybeServerSigned(ctx context.Context, store storage.MetaStore, gun, role string) (*time.Time, []byte, error) {
cryptoServiceVal := ctx.Value("cryptoService") cryptoServiceVal := ctx.Value("cryptoService")
cryptoService, ok := cryptoServiceVal.(signed.CryptoService) cryptoService, ok := cryptoServiceVal.(signed.CryptoService)
if !ok { if !ok {
return errors.ErrNoCryptoService.WithDetail(nil) return nil, nil, errors.ErrNoCryptoService.WithDetail(nil)
} }
var ( var (
lastModified *time.Time
out []byte out []byte
err error err error
) )
switch role { switch role {
case data.CanonicalSnapshotRole: case data.CanonicalSnapshotRole:
out, err = snapshot.GetOrCreateSnapshot(gun, store, cryptoService) lastModified, out, err = snapshot.GetOrCreateSnapshot(gun, store, cryptoService)
case data.CanonicalTimestampRole: case data.CanonicalTimestampRole:
out, err = timestamp.GetOrCreateTimestamp(gun, store, cryptoService) lastModified, out, err = timestamp.GetOrCreateTimestamp(gun, store, cryptoService)
} }
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case *storage.ErrNoKey, storage.ErrNotFound: case *storage.ErrNoKey, storage.ErrNotFound:
return errors.ErrMetadataNotFound.WithDetail(err) return nil, nil, errors.ErrMetadataNotFound.WithDetail(err)
default: default:
return errors.ErrUnknown.WithDetail(err) return nil, nil, errors.ErrUnknown.WithDetail(err)
} }
} }
w.Write(out) return lastModified, out, nil
return nil
} }

View File

@ -14,10 +14,7 @@ import (
) )
func TestGetMaybeServerSignedNoCrypto(t *testing.T) { func TestGetMaybeServerSignedNoCrypto(t *testing.T) {
err := getMaybeServerSigned( _, _, err := getMaybeServerSigned(context.Background(), nil, "", "")
context.Background(),
nil, nil, "", "",
)
require.Error(t, err) require.Error(t, err)
require.IsType(t, errcode.Error{}, err) require.IsType(t, errcode.Error{}, err)
@ -33,9 +30,8 @@ func TestGetMaybeServerSignedNoKey(t *testing.T) {
ctx = context.WithValue(ctx, "cryptoService", crypto) ctx = context.WithValue(ctx, "cryptoService", crypto)
ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key)
err := getMaybeServerSigned( _, _, err := getMaybeServerSigned(
ctx, ctx,
nil,
store, store,
"gun", "gun",
data.CanonicalTimestampRole, data.CanonicalTimestampRole,

View File

@ -41,7 +41,7 @@ func validateUpdate(cs signed.CryptoService, gun string, updates []storage.MetaU
} }
var root *data.SignedRoot var root *data.SignedRoot
oldRootJSON, err := store.GetCurrent(gun, rootRole) _, oldRootJSON, err := store.GetCurrent(gun, rootRole)
if _, ok := err.(storage.ErrNotFound); err != nil && !ok { if _, ok := err.(storage.ErrNotFound); err != nil && !ok {
// problem with storage. No expectation we can // problem with storage. No expectation we can
// write if we can't read so bail. // write if we can't read so bail.
@ -92,7 +92,7 @@ func validateUpdate(cs signed.CryptoService, gun string, updates []storage.MetaU
// At this point, root and targets must have been loaded into the repo // At this point, root and targets must have been loaded into the repo
if _, ok := roles[snapshotRole]; ok { if _, ok := roles[snapshotRole]; ok {
var oldSnap *data.SignedSnapshot var oldSnap *data.SignedSnapshot
oldSnapJSON, err := store.GetCurrent(gun, snapshotRole) _, oldSnapJSON, err := store.GetCurrent(gun, snapshotRole)
if _, ok := err.(storage.ErrNotFound); err != nil && !ok { if _, ok := err.(storage.ErrNotFound); err != nil && !ok {
// problem with storage. No expectation we can // problem with storage. No expectation we can
// write if we can't read so bail. // write if we can't read so bail.
@ -180,7 +180,7 @@ func loadAndValidateTargets(gun string, repo *tuf.Repo, roles map[string]storage
} }
func loadTargetsFromStore(gun, role string, repo *tuf.Repo, store storage.MetaStore) error { func loadTargetsFromStore(gun, role string, repo *tuf.Repo, store storage.MetaStore) error {
tgtJSON, err := store.GetCurrent(gun, role) _, tgtJSON, err := store.GetCurrent(gun, role)
if err != nil { if err != nil {
return err return err
} }
@ -217,7 +217,7 @@ func generateSnapshot(gun string, repo *tuf.Repo, store storage.MetaStore) (*sto
Msg: "no snapshot was included in update and server does not hold current snapshot key for repository"} Msg: "no snapshot was included in update and server does not hold current snapshot key for repository"}
} }
currentJSON, err := store.GetCurrent(gun, data.CanonicalSnapshotRole) _, currentJSON, err := store.GetCurrent(gun, data.CanonicalSnapshotRole)
if err != nil { if err != nil {
if _, ok := err.(storage.ErrNotFound); !ok { if _, ok := err.(storage.ErrNotFound); !ok {
return nil, validation.ErrValidation{Msg: err.Error()} return nil, validation.ErrValidation{Msg: err.Error()}

View File

@ -24,7 +24,7 @@ func TestValidationErrorFormat(t *testing.T) {
context.Background(), "metaStore", storage.NewMemStorage()) context.Background(), "metaStore", storage.NewMemStorage())
ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key)
handler := RootHandler(nil, ctx, signed.NewEd25519()) handler := RootHandler(nil, ctx, signed.NewEd25519(), nil, nil)
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()

View File

@ -31,12 +31,22 @@ func prometheusOpts(operation string) prometheus.SummaryOpts {
} }
} }
// Config tells Run how to configure a server
type Config struct {
Addr string
TLSConfig *tls.Config
Trust signed.CryptoService
AuthMethod string
AuthOpts interface{}
ConsistentCacheControlConfig utils.CacheControlConfig
CurrentCacheControlConfig utils.CacheControlConfig
}
// Run sets up and starts a TLS server that can be cancelled using the // Run sets up and starts a TLS server that can be cancelled using the
// given configuration. The context it is passed is the context it should // given configuration. The context it is passed is the context it should
// use directly for the TLS server, and generate children off for requests // use directly for the TLS server, and generate children off for requests
func Run(ctx context.Context, addr string, tlsConfig *tls.Config, trust signed.CryptoService, authMethod string, authOpts interface{}) error { func Run(ctx context.Context, conf Config) error {
tcpAddr, err := net.ResolveTCPAddr("tcp", conf.Addr)
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil { if err != nil {
return err return err
} }
@ -46,29 +56,29 @@ func Run(ctx context.Context, addr string, tlsConfig *tls.Config, trust signed.C
return err return err
} }
if tlsConfig != nil { if conf.TLSConfig != nil {
logrus.Info("Enabling TLS") logrus.Info("Enabling TLS")
lsnr = tls.NewListener(lsnr, tlsConfig) lsnr = tls.NewListener(lsnr, conf.TLSConfig)
} }
var ac auth.AccessController var ac auth.AccessController
if authMethod == "token" { if conf.AuthMethod == "token" {
authOptions, ok := authOpts.(map[string]interface{}) authOptions, ok := conf.AuthOpts.(map[string]interface{})
if !ok { if !ok {
return fmt.Errorf("auth.options must be a map[string]interface{}") return fmt.Errorf("auth.options must be a map[string]interface{}")
} }
ac, err = auth.GetAccessController(authMethod, authOptions) ac, err = auth.GetAccessController(conf.AuthMethod, authOptions)
if err != nil { if err != nil {
return err return err
} }
} }
svr := http.Server{ svr := http.Server{
Addr: addr, Addr: conf.Addr,
Handler: RootHandler(ac, ctx, trust), Handler: RootHandler(ac, ctx, conf.Trust, conf.ConsistentCacheControlConfig, conf.CurrentCacheControlConfig),
} }
logrus.Info("Starting on ", addr) logrus.Info("Starting on ", conf.Addr)
err = svr.Serve(lsnr) err = svr.Serve(lsnr)
@ -77,7 +87,9 @@ func Run(ctx context.Context, addr string, tlsConfig *tls.Config, trust signed.C
// RootHandler returns the handler that routes all the paths from / for the // RootHandler returns the handler that routes all the paths from / for the
// server. // server.
func RootHandler(ac auth.AccessController, ctx context.Context, trust signed.CryptoService) http.Handler { func RootHandler(ac auth.AccessController, ctx context.Context, trust signed.CryptoService,
consistent, current utils.CacheControlConfig) http.Handler {
hand := utils.RootHandlerFactory(ac, ctx, trust) hand := utils.RootHandlerFactory(ac, ctx, trust)
r := mux.NewRouter() r := mux.NewRouter()
@ -89,11 +101,11 @@ func RootHandler(ac auth.AccessController, ctx context.Context, trust signed.Cry
r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/{tufRole:root|targets(?:/[^/\\s]+)*|snapshot|timestamp}.{checksum:[a-fA-F0-9]{64}|[a-fA-F0-9]{96}|[a-fA-F0-9]{128}}.json").Handler( r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/{tufRole:root|targets(?:/[^/\\s]+)*|snapshot|timestamp}.{checksum:[a-fA-F0-9]{64}|[a-fA-F0-9]{96}|[a-fA-F0-9]{128}}.json").Handler(
prometheus.InstrumentHandlerWithOpts( prometheus.InstrumentHandlerWithOpts(
prometheusOpts("GetRoleByHash"), prometheusOpts("GetRoleByHash"),
hand(handlers.GetHandler, "pull"))) utils.WrapWithCacheHandler(consistent, hand(handlers.GetHandler, "pull"))))
r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/{tufRole:root|targets(?:/[^/\\s]+)*|snapshot|timestamp}.json").Handler( r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/{tufRole:root|targets(?:/[^/\\s]+)*|snapshot|timestamp}.json").Handler(
prometheus.InstrumentHandlerWithOpts( prometheus.InstrumentHandlerWithOpts(
prometheusOpts("GetRole"), prometheusOpts("GetRole"),
hand(handlers.GetHandler, "pull"))) utils.WrapWithCacheHandler(current, hand(handlers.GetHandler, "pull"))))
r.Methods("GET").Path( r.Methods("GET").Path(
"/v2/{imageName:.*}/_trust/tuf/{tufRole:snapshot|timestamp}.key").Handler( "/v2/{imageName:.*}/_trust/tuf/{tufRole:snapshot|timestamp}.key").Handler(
prometheus.InstrumentHandlerWithOpts( prometheus.InstrumentHandlerWithOpts(

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -16,6 +17,8 @@ import (
"github.com/docker/notary/server/storage" "github.com/docker/notary/server/storage"
"github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/data"
"github.com/docker/notary/tuf/signed" "github.com/docker/notary/tuf/signed"
"github.com/docker/notary/tuf/testutils"
"github.com/docker/notary/utils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -23,11 +26,10 @@ import (
func TestRunBadAddr(t *testing.T) { func TestRunBadAddr(t *testing.T) {
err := Run( err := Run(
context.Background(), context.Background(),
"testAddr", Config{
nil, Addr: "testAddr",
signed.NewEd25519(), Trust: signed.NewEd25519(),
"", },
nil,
) )
assert.Error(t, err, "Passed bad addr, Run should have failed") assert.Error(t, err, "Passed bad addr, Run should have failed")
} }
@ -37,11 +39,10 @@ func TestRunReservedPort(t *testing.T) {
err := Run( err := Run(
ctx, ctx,
"localhost:80", Config{
nil, Addr: "localhost:80",
signed.NewEd25519(), Trust: signed.NewEd25519(),
"", },
nil,
) )
assert.Error(t, err) assert.Error(t, err)
@ -55,7 +56,8 @@ func TestRunReservedPort(t *testing.T) {
} }
func TestMetricsEndpoint(t *testing.T) { func TestMetricsEndpoint(t *testing.T) {
handler := RootHandler(nil, context.Background(), signed.NewEd25519()) handler := RootHandler(nil, context.Background(), signed.NewEd25519(),
nil, nil)
ts := httptest.NewServer(handler) ts := httptest.NewServer(handler)
defer ts.Close() defer ts.Close()
@ -70,7 +72,7 @@ func TestGetKeysEndpoint(t *testing.T) {
context.Background(), "metaStore", storage.NewMemStorage()) context.Background(), "metaStore", storage.NewMemStorage())
ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key)
handler := RootHandler(nil, ctx, signed.NewEd25519()) handler := RootHandler(nil, ctx, signed.NewEd25519(), nil, nil)
ts := httptest.NewServer(handler) ts := httptest.NewServer(handler)
defer ts.Close() defer ts.Close()
@ -90,7 +92,7 @@ func TestGetKeysEndpoint(t *testing.T) {
} }
} }
// This just checks the URL routing is working correctly. // This just checks the URL routing is working correctly and cache headers are set correctly.
// More detailed tests for this path including negative // More detailed tests for this path including negative
// tests are located in /server/handlers/ // tests are located in /server/handlers/
func TestGetRoleByHash(t *testing.T) { func TestGetRoleByHash(t *testing.T) {
@ -99,48 +101,46 @@ func TestGetRoleByHash(t *testing.T) {
ts := data.SignedTimestamp{ ts := data.SignedTimestamp{
Signatures: make([]data.Signature, 0), Signatures: make([]data.Signature, 0),
Signed: data.Timestamp{ Signed: data.Timestamp{
Type: data.TUFTypes["timestamp"], Type: data.TUFTypes[data.CanonicalTimestampRole],
Version: 1, Version: 1,
Expires: data.DefaultExpires("timestamp"), Expires: data.DefaultExpires(data.CanonicalTimestampRole),
}, },
} }
j, err := json.Marshal(&ts) j, err := json.Marshal(&ts)
assert.NoError(t, err) assert.NoError(t, err)
update := storage.MetaUpdate{ store.UpdateCurrent("gun", storage.MetaUpdate{
Role: data.CanonicalTimestampRole, Role: data.CanonicalTimestampRole,
Version: 1, Version: 1,
Data: j, Data: j,
} })
checksumBytes := sha256.Sum256(j) checksumBytes := sha256.Sum256(j)
checksum := hex.EncodeToString(checksumBytes[:]) checksum := hex.EncodeToString(checksumBytes[:])
store.UpdateCurrent("gun", update)
// create and add a newer timestamp. We're going to try and request // create and add a newer timestamp. We're going to try and request
// the older version we created above. // the older version we created above.
ts = data.SignedTimestamp{ ts = data.SignedTimestamp{
Signatures: make([]data.Signature, 0), Signatures: make([]data.Signature, 0),
Signed: data.Timestamp{ Signed: data.Timestamp{
Type: data.TUFTypes["timestamp"], Type: data.TUFTypes[data.CanonicalTimestampRole],
Version: 2, Version: 2,
Expires: data.DefaultExpires("timestamp"), Expires: data.DefaultExpires(data.CanonicalTimestampRole),
}, },
} }
newJ, err := json.Marshal(&ts) newTS, err := json.Marshal(&ts)
assert.NoError(t, err) assert.NoError(t, err)
update = storage.MetaUpdate{ store.UpdateCurrent("gun", storage.MetaUpdate{
Role: data.CanonicalTimestampRole, Role: data.CanonicalTimestampRole,
Version: 2, Version: 1,
Data: newJ, Data: newTS,
} })
store.UpdateCurrent("gun", update)
ctx := context.WithValue( ctx := context.WithValue(
context.Background(), "metaStore", store) context.Background(), "metaStore", store)
ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key) ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key)
handler := RootHandler(nil, ctx, signed.NewEd25519()) ccc := utils.NewCacheControlConfig(10, false)
handler := RootHandler(nil, ctx, signed.NewEd25519(), ccc, ccc)
serv := httptest.NewServer(handler) serv := httptest.NewServer(handler)
defer serv.Close() defer serv.Close()
@ -152,10 +152,59 @@ func TestGetRoleByHash(t *testing.T) {
)) ))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
defer res.Body.Close()
// if content is equal, checksums are guaranteed to be equal // if content is equal, checksums are guaranteed to be equal
assert.EqualValues(t, j, body) verifyGetResponse(t, res, j)
}
// This just checks the URL routing is working correctly and cache headers are set correctly.
// More detailed tests for this path including negative
// tests are located in /server/handlers/
func TestGetCurrentRole(t *testing.T) {
store := storage.NewMemStorage()
metadata, _, err := testutils.NewRepoMetadata("gun")
assert.NoError(t, err)
// need both the snapshot and the timestamp, because when getting the current
// timestamp the server checks to see if it's out of date (there's a new snapshot)
// and if so, generates a new one
store.UpdateCurrent("gun", storage.MetaUpdate{
Role: data.CanonicalSnapshotRole,
Version: 1,
Data: metadata[data.CanonicalSnapshotRole],
})
store.UpdateCurrent("gun", storage.MetaUpdate{
Role: data.CanonicalTimestampRole,
Version: 1,
Data: metadata[data.CanonicalTimestampRole],
})
ctx := context.WithValue(
context.Background(), "metaStore", store)
ctx = context.WithValue(ctx, "keyAlgorithm", data.ED25519Key)
ccc := utils.NewCacheControlConfig(10, false)
handler := RootHandler(nil, ctx, signed.NewEd25519(), ccc, ccc)
serv := httptest.NewServer(handler)
defer serv.Close()
res, err := http.Get(fmt.Sprintf(
"%s/v2/gun/_trust/tuf/%s.json",
serv.URL,
data.CanonicalTimestampRole,
))
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
verifyGetResponse(t, res, metadata[data.CanonicalTimestampRole])
}
// Verifies that the body is as expected and that there are cache control headers
func verifyGetResponse(t *testing.T, r *http.Response, expectedBytes []byte) {
body, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.True(t, bytes.Equal(expectedBytes, body))
assert.NotEqual(t, "", r.Header.Get("Cache-Control"))
assert.NotEqual(t, "", r.Header.Get("Last-Modified"))
assert.Equal(t, "", r.Header.Get("Pragma"))
} }

View File

@ -2,6 +2,7 @@ package snapshot
import ( import (
"encoding/json" "encoding/json"
"time"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
@ -45,10 +46,12 @@ func GetOrCreateSnapshotKey(gun string, store storage.KeyStore, crypto signed.Cr
// GetOrCreateSnapshot either returns the exisiting latest snapshot, or uses // GetOrCreateSnapshot either returns the exisiting latest snapshot, or uses
// whatever the most recent snapshot is to create the next one, only updating // whatever the most recent snapshot is to create the next one, only updating
// the expiry time and version. // the expiry time and version.
func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ([]byte, error) { func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService signed.CryptoService) (
d, err := store.GetCurrent(gun, "snapshot") *time.Time, []byte, error) {
lastModified, d, err := store.GetCurrent(gun, data.CanonicalSnapshotRole)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
sn := &data.SignedSnapshot{} sn := &data.SignedSnapshot{}
@ -56,29 +59,30 @@ func GetOrCreateSnapshot(gun string, store storage.MetaStore, cryptoService sign
err := json.Unmarshal(d, sn) err := json.Unmarshal(d, sn)
if err != nil { if err != nil {
logrus.Error("Failed to unmarshal existing snapshot") logrus.Error("Failed to unmarshal existing snapshot")
return nil, err return nil, nil, err
} }
if !snapshotExpired(sn) { if !snapshotExpired(sn) {
return d, nil return lastModified, d, nil
} }
} }
sgnd, version, err := createSnapshot(gun, sn, store, cryptoService) sgnd, version, err := createSnapshot(gun, sn, store, cryptoService)
if err != nil { if err != nil {
logrus.Error("Failed to create a new snapshot") logrus.Error("Failed to create a new snapshot")
return nil, err return nil, nil, err
} }
out, err := json.Marshal(sgnd) out, err := json.Marshal(sgnd)
if err != nil { if err != nil {
logrus.Error("Failed to marshal new snapshot") logrus.Error("Failed to marshal new snapshot")
return nil, err return nil, nil, err
} }
err = store.UpdateCurrent(gun, storage.MetaUpdate{Role: "snapshot", Version: version, Data: out}) err = store.UpdateCurrent(gun, storage.MetaUpdate{Role: "snapshot", Version: version, Data: out})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return out, nil c := time.Now()
return &c, out, nil
} }
// snapshotExpired simply checks if the snapshot is past its expiry time // snapshotExpired simply checks if the snapshot is past its expiry time

View File

@ -118,7 +118,7 @@ func TestGetSnapshotNotExists(t *testing.T) {
store := storage.NewMemStorage() store := storage.NewMemStorage()
crypto := signed.NewEd25519() crypto := signed.NewEd25519()
_, err := GetOrCreateSnapshot("gun", store, crypto) _, _, err := GetOrCreateSnapshot("gun", store, crypto)
assert.Error(t, err) assert.Error(t, err)
} }
@ -144,18 +144,23 @@ func TestGetSnapshotCurrValid(t *testing.T) {
// test when db is missing the role data // test when db is missing the role data
store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON}) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON})
_, err = GetOrCreateSnapshot("gun", store, crypto) c1, result, err := GetOrCreateSnapshot("gun", store, crypto)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, bytes.Equal(snapJSON, result))
// test when db has the role data // test when db has the role data
store.UpdateCurrent("gun", storage.MetaUpdate{Role: "root", Version: 0, Data: newData}) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "root", Version: 0, Data: newData})
_, err = GetOrCreateSnapshot("gun", store, crypto) c2, result, err := GetOrCreateSnapshot("gun", store, crypto)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, bytes.Equal(snapJSON, result))
assert.True(t, c1.Equal(*c2))
// test when db role data is expired // test when db role data is corrupt
store.UpdateCurrent("gun", storage.MetaUpdate{Role: "root", Version: 1, Data: []byte{3}}) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "root", Version: 1, Data: []byte{3}})
_, err = GetOrCreateSnapshot("gun", store, crypto) c2, result, err = GetOrCreateSnapshot("gun", store, crypto)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, bytes.Equal(snapJSON, result))
assert.True(t, c1.Equal(*c2))
} }
func TestGetSnapshotCurrExpired(t *testing.T) { func TestGetSnapshotCurrExpired(t *testing.T) {
@ -168,8 +173,10 @@ func TestGetSnapshotCurrExpired(t *testing.T) {
snapJSON, _ := json.Marshal(snapshot) snapJSON, _ := json.Marshal(snapshot)
store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON}) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON})
_, err = GetOrCreateSnapshot("gun", store, crypto) c1, newJSON, err := GetOrCreateSnapshot("gun", store, crypto)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, bytes.Equal(snapJSON, newJSON))
assert.True(t, c1.After(time.Now().Add(-1*time.Minute)))
} }
func TestGetSnapshotCurrCorrupt(t *testing.T) { func TestGetSnapshotCurrCorrupt(t *testing.T) {
@ -182,7 +189,7 @@ func TestGetSnapshotCurrCorrupt(t *testing.T) {
snapJSON, _ := json.Marshal(snapshot) snapJSON, _ := json.Marshal(snapshot)
store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON[1:]}) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON[1:]})
_, err = GetOrCreateSnapshot("gun", store, crypto) _, _, err = GetOrCreateSnapshot("gun", store, crypto)
assert.Error(t, err) assert.Error(t, err)
} }

View File

@ -4,6 +4,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"time"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
@ -117,32 +118,39 @@ func (db *SQLStorage) UpdateMany(gun string, updates []MetaUpdate) error {
} }
// GetCurrent gets a specific TUF record // GetCurrent gets a specific TUF record
func (db *SQLStorage) GetCurrent(gun, tufRole string) ([]byte, error) { func (db *SQLStorage) GetCurrent(gun, tufRole string) (*time.Time, []byte, error) {
var row TUFFile var row TUFFile
q := db.Select("data").Where(&TUFFile{Gun: gun, Role: tufRole}).Order("version desc").Limit(1).First(&row) q := db.Select("updated_at, data").Where(
return returnRead(q, row) &TUFFile{Gun: gun, Role: tufRole}).Order("version desc").Limit(1).First(&row)
if err := isReadErr(q, row); err != nil {
return nil, nil, err
}
return &(row.UpdatedAt), row.Data, nil
} }
// GetChecksum gets a specific TUF record by its hex checksum // GetChecksum gets a specific TUF record by its hex checksum
func (db *SQLStorage) GetChecksum(gun, tufRole, checksum string) ([]byte, error) { func (db *SQLStorage) GetChecksum(gun, tufRole, checksum string) (*time.Time, []byte, error) {
var row TUFFile var row TUFFile
q := db.Select("data").Where( q := db.Select("created_at, data").Where(
&TUFFile{ &TUFFile{
Gun: gun, Gun: gun,
Role: tufRole, Role: tufRole,
Sha256: checksum, Sha256: checksum,
}, },
).First(&row) ).First(&row)
return returnRead(q, row) if err := isReadErr(q, row); err != nil {
return nil, nil, err
}
return &(row.CreatedAt), row.Data, nil
} }
func returnRead(q *gorm.DB, row TUFFile) ([]byte, error) { func isReadErr(q *gorm.DB, row TUFFile) error {
if q.RecordNotFound() { if q.RecordNotFound() {
return nil, ErrNotFound{} return ErrNotFound{}
} else if q.Error != nil { } else if q.Error != nil {
return nil, q.Error return q.Error
} }
return row.Data, nil return nil
} }
// Delete deletes all the records for a specific GUN // Delete deletes all the records for a specific GUN

View File

@ -4,9 +4,11 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time"
"github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/data"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -231,7 +233,7 @@ func TestSQLGetCurrent(t *testing.T) {
gormDB, dbStore := SetUpSQLite(t, tempBaseDir) gormDB, dbStore := SetUpSQLite(t, tempBaseDir)
defer os.RemoveAll(tempBaseDir) defer os.RemoveAll(tempBaseDir)
byt, err := dbStore.GetCurrent("testGUN", "root") _, byt, err := dbStore.GetCurrent("testGUN", "root")
require.Nil(t, byt) require.Nil(t, byt)
require.Error(t, err, "There should be an error Getting an empty table") require.Error(t, err, "There should be an error Getting an empty table")
require.IsType(t, ErrNotFound{}, err, "Should get a not found error") require.IsType(t, ErrNotFound{}, err, "Should get a not found error")
@ -240,9 +242,13 @@ func TestSQLGetCurrent(t *testing.T) {
query := gormDB.Create(&tuf) query := gormDB.Create(&tuf)
require.NoError(t, query.Error, "Creating a row in an empty DB failed.") require.NoError(t, query.Error, "Creating a row in an empty DB failed.")
byt, err = dbStore.GetCurrent("testGUN", "root") cDate, byt, err := dbStore.GetCurrent("testGUN", "root")
require.NoError(t, err, "There should not be any errors getting.") require.NoError(t, err, "There should not be any errors getting.")
require.Equal(t, []byte("1"), byt, "Returned data was incorrect") require.Equal(t, []byte("1"), byt, "Returned data was incorrect")
// the update date was sometime wthin the last minute
fmt.Println(cDate)
require.True(t, cDate.After(time.Now().Add(-1*time.Minute)))
require.True(t, cDate.Before(time.Now().Add(5*time.Second)))
dbStore.DB.Close() dbStore.DB.Close()
} }
@ -487,9 +493,12 @@ func TestDBGetChecksum(t *testing.T) {
store.UpdateCurrent("gun", update) store.UpdateCurrent("gun", update)
data, err := store.GetChecksum("gun", data.CanonicalTimestampRole, checksum) cDate, data, err := store.GetChecksum("gun", data.CanonicalTimestampRole, checksum)
require.NoError(t, err) require.NoError(t, err)
require.EqualValues(t, j, data) require.EqualValues(t, j, data)
// the creation date was sometime wthin the last minute
require.True(t, cDate.After(time.Now().Add(-1*time.Minute)))
require.True(t, cDate.Before(time.Now().Add(5*time.Second)))
} }
func TestDBGetChecksumNotFound(t *testing.T) { func TestDBGetChecksumNotFound(t *testing.T) {
@ -497,7 +506,7 @@ func TestDBGetChecksumNotFound(t *testing.T) {
_, store := SetUpSQLite(t, tempBaseDir) _, store := SetUpSQLite(t, tempBaseDir)
defer os.RemoveAll(tempBaseDir) defer os.RemoveAll(tempBaseDir)
_, err = store.GetChecksum("gun", data.CanonicalTimestampRole, "12345") _, _, err = store.GetChecksum("gun", data.CanonicalTimestampRole, "12345")
require.Error(t, err) require.Error(t, err)
require.IsType(t, ErrNotFound{}, err) require.IsType(t, ErrNotFound{}, err)
} }

View File

@ -1,5 +1,7 @@
package storage package storage
import "time"
// KeyStore provides a minimal interface for managing key persistence // KeyStore provides a minimal interface for managing key persistence
type KeyStore interface { type KeyStore interface {
// GetKey returns the algorithm and public key for the given GUN and role. // GetKey returns the algorithm and public key for the given GUN and role.
@ -24,15 +26,15 @@ type MetaStore interface {
// none of the metadata is added, and an error is be returned. // none of the metadata is added, and an error is be returned.
UpdateMany(gun string, updates []MetaUpdate) error UpdateMany(gun string, updates []MetaUpdate) error
// GetCurrent returns the data part of the metadata for the latest version // GetCurrent returns the modification date and data part of the metadata for
// of the given GUN and role. If there is no data for the given GUN and // the latest version of the given GUN and role. If there is no data for
// role, an error is returned. // the given GUN and role, an error is returned.
GetCurrent(gun, tufRole string) (data []byte, err error) GetCurrent(gun, tufRole string) (created *time.Time, data []byte, err error)
// GetChecksum return the given tuf role file for the GUN with the // GetChecksum returns the given TUF role file and creation date for the
// provided checksum. If the given (gun, role, checksum) are not // GUN with the provided checksum. If the given (gun, role, checksum) are
// found, it returns storage.ErrNotFound // not found, it returns storage.ErrNotFound
GetChecksum(gun, tufRole, checksum string) (data []byte, err error) GetChecksum(gun, tufRole, checksum string) (created *time.Time, data []byte, err error)
// Delete removes all metadata for a given GUN. It does not return an // Delete removes all metadata for a given GUN. It does not return an
// error if no metadata exists for the given GUN. // error if no metadata exists for the given GUN.

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
"time"
) )
type key struct { type key struct {
@ -16,6 +17,7 @@ type key struct {
type ver struct { type ver struct {
version int version int
data []byte data []byte
createupdate time.Time
} }
// MemStorage is really just designed for dev and testing. It is very // MemStorage is really just designed for dev and testing. It is very
@ -24,7 +26,7 @@ type MemStorage struct {
lock sync.Mutex lock sync.Mutex
tufMeta map[string][]*ver tufMeta map[string][]*ver
keys map[string]map[string]*key keys map[string]map[string]*key
checksums map[string]map[string][]byte checksums map[string]map[string]ver
} }
// NewMemStorage instantiates a memStorage instance // NewMemStorage instantiates a memStorage instance
@ -32,7 +34,7 @@ func NewMemStorage() *MemStorage {
return &MemStorage{ return &MemStorage{
tufMeta: make(map[string][]*ver), tufMeta: make(map[string][]*ver),
keys: make(map[string]map[string]*key), keys: make(map[string]map[string]*key),
checksums: make(map[string]map[string][]byte), checksums: make(map[string]map[string]ver),
} }
} }
@ -48,15 +50,16 @@ func (st *MemStorage) UpdateCurrent(gun string, update MetaUpdate) error {
} }
} }
} }
st.tufMeta[id] = append(st.tufMeta[id], &ver{version: update.Version, data: update.Data}) version := ver{version: update.Version, data: update.Data, createupdate: time.Now()}
st.tufMeta[id] = append(st.tufMeta[id], &version)
checksumBytes := sha256.Sum256(update.Data) checksumBytes := sha256.Sum256(update.Data)
checksum := hex.EncodeToString(checksumBytes[:]) checksum := hex.EncodeToString(checksumBytes[:])
_, ok := st.checksums[gun] _, ok := st.checksums[gun]
if !ok { if !ok {
st.checksums[gun] = make(map[string][]byte) st.checksums[gun] = make(map[string]ver)
} }
st.checksums[gun][checksum] = update.Data st.checksums[gun][checksum] = version
return nil return nil
} }
@ -68,27 +71,27 @@ func (st *MemStorage) UpdateMany(gun string, updates []MetaUpdate) error {
return nil return nil
} }
// GetCurrent returns the metadata for a given role, under a GUN // GetCurrent returns the createupdate date metadata for a given role, under a GUN.
func (st *MemStorage) GetCurrent(gun, role string) (data []byte, err error) { func (st *MemStorage) GetCurrent(gun, role string) (*time.Time, []byte, error) {
id := entryKey(gun, role) id := entryKey(gun, role)
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
space, ok := st.tufMeta[id] space, ok := st.tufMeta[id]
if !ok || len(space) == 0 { if !ok || len(space) == 0 {
return nil, ErrNotFound{} return nil, nil, ErrNotFound{}
} }
return space[len(space)-1].data, nil return &(space[len(space)-1].createupdate), space[len(space)-1].data, nil
} }
// GetChecksum returns the metadata for a given role, under a GUN // GetChecksum returns the createupdate date and metadata for a given role, under a GUN.
func (st *MemStorage) GetChecksum(gun, role, checksum string) (data []byte, err error) { func (st *MemStorage) GetChecksum(gun, role, checksum string) (*time.Time, []byte, error) {
st.lock.Lock() st.lock.Lock()
defer st.lock.Unlock() defer st.lock.Unlock()
data, ok := st.checksums[gun][checksum] space, ok := st.checksums[gun][checksum]
if !ok || len(data) == 0 { if !ok || len(space.data) == 0 {
return nil, ErrNotFound{} return nil, nil, ErrNotFound{}
} }
return data, nil return &(space.createupdate), space.data, nil
} }
// Delete deletes all the metadata for a given GUN // Delete deletes all the metadata for a given GUN

View File

@ -22,11 +22,11 @@ func TestUpdateCurrent(t *testing.T) {
func TestGetCurrent(t *testing.T) { func TestGetCurrent(t *testing.T) {
s := NewMemStorage() s := NewMemStorage()
_, err := s.GetCurrent("gun", "role") _, _, err := s.GetCurrent("gun", "role")
assert.IsType(t, ErrNotFound{}, err, "Expected error to be ErrNotFound") assert.IsType(t, ErrNotFound{}, err, "Expected error to be ErrNotFound")
s.UpdateCurrent("gun", MetaUpdate{"role", 1, []byte("test")}) s.UpdateCurrent("gun", MetaUpdate{"role", 1, []byte("test")})
d, err := s.GetCurrent("gun", "role") _, d, err := s.GetCurrent("gun", "role")
assert.Nil(t, err, "Expected error to be nil") assert.Nil(t, err, "Expected error to be nil")
assert.Equal(t, []byte("test"), d, "Data was incorrect") assert.Equal(t, []byte("test"), d, "Data was incorrect")
} }
@ -97,7 +97,7 @@ func TestSetKeySameRoleGun(t *testing.T) {
func TestGetChecksumNotFound(t *testing.T) { func TestGetChecksumNotFound(t *testing.T) {
s := NewMemStorage() s := NewMemStorage()
_, err := s.GetChecksum("gun", "root", "12345") _, _, err := s.GetChecksum("gun", "root", "12345")
assert.Error(t, err) assert.Error(t, err)
assert.IsType(t, ErrNotFound{}, err) assert.IsType(t, ErrNotFound{}, err)
} }

View File

@ -1,6 +1,8 @@
package timestamp package timestamp
import ( import (
"time"
"github.com/docker/go/canonical/json" "github.com/docker/go/canonical/json"
"github.com/docker/notary/tuf/data" "github.com/docker/notary/tuf/data"
"github.com/docker/notary/tuf/signed" "github.com/docker/notary/tuf/signed"
@ -47,16 +49,18 @@ func GetOrCreateTimestampKey(gun string, store storage.MetaStore, crypto signed.
// GetOrCreateTimestamp returns the current timestamp for the gun. This may mean // GetOrCreateTimestamp returns the current timestamp for the gun. This may mean
// a new timestamp is generated either because none exists, or because the current // a new timestamp is generated either because none exists, or because the current
// one has expired. Once generated, the timestamp is saved in the store. // one has expired. Once generated, the timestamp is saved in the store.
func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService signed.CryptoService) ([]byte, error) { func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService signed.CryptoService) (
snapshot, err := snapshot.GetOrCreateSnapshot(gun, store, cryptoService) *time.Time, []byte, error) {
_, snapshot, err := snapshot.GetOrCreateSnapshot(gun, store, cryptoService)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
d, err := store.GetCurrent(gun, "timestamp") lastModified, d, err := store.GetCurrent(gun, data.CanonicalTimestampRole)
if err != nil { if err != nil {
if _, ok := err.(storage.ErrNotFound); !ok { if _, ok := err.(storage.ErrNotFound); !ok {
logrus.Error("error retrieving timestamp: ", err.Error()) logrus.Error("error retrieving timestamp: ", err.Error())
return nil, err return nil, nil, err
} }
logrus.Debug("No timestamp found, will proceed to create first timestamp") logrus.Debug("No timestamp found, will proceed to create first timestamp")
} }
@ -65,27 +69,28 @@ func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService sig
err := json.Unmarshal(d, ts) err := json.Unmarshal(d, ts)
if err != nil { if err != nil {
logrus.Error("Failed to unmarshal existing timestamp") logrus.Error("Failed to unmarshal existing timestamp")
return nil, err return nil, nil, err
} }
if !timestampExpired(ts) && !snapshotExpired(ts, snapshot) { if !timestampExpired(ts) && !snapshotExpired(ts, snapshot) {
return d, nil return lastModified, d, nil
} }
} }
sgnd, version, err := CreateTimestamp(gun, ts, snapshot, store, cryptoService) sgnd, version, err := CreateTimestamp(gun, ts, snapshot, store, cryptoService)
if err != nil { if err != nil {
logrus.Error("Failed to create a new timestamp") logrus.Error("Failed to create a new timestamp")
return nil, err return nil, nil, err
} }
out, err := json.Marshal(sgnd) out, err := json.Marshal(sgnd)
if err != nil { if err != nil {
logrus.Error("Failed to marshal new timestamp") logrus.Error("Failed to marshal new timestamp")
return nil, err return nil, nil, err
} }
err = store.UpdateCurrent(gun, storage.MetaUpdate{Role: "timestamp", Version: version, Data: out}) err = store.UpdateCurrent(gun, storage.MetaUpdate{Role: "timestamp", Version: version, Data: out})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return out, nil c := time.Now()
return &c, out, nil
} }
// timestampExpired compares the current time to the expiry time of the timestamp // timestampExpired compares the current time to the expiry time of the timestamp

View File

@ -64,7 +64,7 @@ func TestGetTimestamp(t *testing.T) {
_, err := GetOrCreateTimestampKey("gun", store, crypto, data.ED25519Key) _, err := GetOrCreateTimestampKey("gun", store, crypto, data.ED25519Key)
assert.Nil(t, err, "GetKey errored") assert.Nil(t, err, "GetKey errored")
_, err = GetOrCreateTimestamp("gun", store, crypto) _, _, err = GetOrCreateTimestamp("gun", store, crypto)
assert.Nil(t, err, "GetTimestamp errored") assert.Nil(t, err, "GetTimestamp errored")
} }
@ -85,7 +85,7 @@ func TestGetTimestampNewSnapshot(t *testing.T) {
_, err := GetOrCreateTimestampKey("gun", store, crypto, data.ED25519Key) _, err := GetOrCreateTimestampKey("gun", store, crypto, data.ED25519Key)
assert.Nil(t, err, "GetKey errored") assert.Nil(t, err, "GetKey errored")
ts1, err := GetOrCreateTimestamp("gun", store, crypto) c1, ts1, err := GetOrCreateTimestamp("gun", store, crypto)
assert.Nil(t, err, "GetTimestamp errored") assert.Nil(t, err, "GetTimestamp errored")
snapshot = &data.SignedSnapshot{ snapshot = &data.SignedSnapshot{
@ -98,8 +98,8 @@ func TestGetTimestampNewSnapshot(t *testing.T) {
store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 1, Data: snapJSON}) store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 1, Data: snapJSON})
ts2, err := GetOrCreateTimestamp("gun", store, crypto) c2, ts2, err := GetOrCreateTimestamp("gun", store, crypto)
assert.NoError(t, err, "GetTimestamp errored") assert.NoError(t, err, "GetTimestamp errored")
assert.NotEqual(t, ts1, ts2, "Timestamp was not regenerated when snapshot changed") assert.NotEqual(t, ts1, ts2, "Timestamp was not regenerated when snapshot changed")
assert.True(t, c1.Before(*c2), "Timestamp modification time incorrect")
} }

View File

@ -1,7 +1,9 @@
package utils package utils
import ( import (
"fmt"
"net/http" "net/http"
"time"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
ctxu "github.com/docker/distribution/context" ctxu "github.com/docker/distribution/context"
@ -94,3 +96,102 @@ func buildAccessRecords(repo string, actions ...string) []auth.Access {
} }
return requiredAccess return requiredAccess
} }
// CacheControlConfig is an interface for something that knows how to set cache
// control headers
type CacheControlConfig interface {
// SetHeaders will actually set the cache control headers on a Headers object
SetHeaders(headers http.Header)
}
// NewCacheControlConfig returns CacheControlConfig interface for either setting
// cache control or disabling cache control entirely
func NewCacheControlConfig(maxAgeInSeconds int, mustRevalidate bool) CacheControlConfig {
if maxAgeInSeconds > 0 {
return PublicCacheControl{MustReValidate: mustRevalidate, MaxAgeInSeconds: maxAgeInSeconds}
}
return NoCacheControl{}
}
// PublicCacheControl is a set of options that we will set to enable cache control
type PublicCacheControl struct {
MustReValidate bool
MaxAgeInSeconds int
}
// SetHeaders sets the public headers with an optional must-revalidate header
func (p PublicCacheControl) SetHeaders(headers http.Header) {
cacheControlValue := fmt.Sprintf("public, max-age=%v, s-maxage=%v",
p.MaxAgeInSeconds, p.MaxAgeInSeconds)
if p.MustReValidate {
cacheControlValue = fmt.Sprintf("%s, must-revalidate", cacheControlValue)
}
headers.Set("Cache-Control", cacheControlValue)
// delete the Pragma directive, because the only valid value in HTTP is
// "no-cache"
headers.Del("Pragma")
if headers.Get("Last-Modified") == "" {
SetLastModifiedHeader(headers, time.Time{})
}
}
// NoCacheControl is an object which represents a directive to cache nothing
type NoCacheControl struct{}
// SetHeaders sets the public headers cache-control headers and pragma to no-cache
func (n NoCacheControl) SetHeaders(headers http.Header) {
headers.Set("Cache-Control", "max-age=0, no-cache, no-store")
headers.Set("Pragma", "no-cache")
}
// cacheControlResponseWriter wraps an existing response writer, and if Write is
// called, will try to set the cache control headers if it can
type cacheControlResponseWriter struct {
http.ResponseWriter
config CacheControlConfig
statusCode int
}
// WriteHeader stores the header before writing it, so we can tell if it's been set
// to a non-200 status code
func (c *cacheControlResponseWriter) WriteHeader(statusCode int) {
c.statusCode = statusCode
c.ResponseWriter.WriteHeader(statusCode)
}
// Write will set the cache headers if they haven't already been set and if the status
// code has either not been set or set to 200
func (c *cacheControlResponseWriter) Write(data []byte) (int, error) {
if c.statusCode == http.StatusOK || c.statusCode == 0 {
headers := c.ResponseWriter.Header()
if headers.Get("Cache-Control") == "" {
c.config.SetHeaders(headers)
}
}
return c.ResponseWriter.Write(data)
}
type cacheControlHandler struct {
http.Handler
config CacheControlConfig
}
func (c cacheControlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.Handler.ServeHTTP(&cacheControlResponseWriter{ResponseWriter: w, config: c.config}, r)
}
// WrapWithCacheHandler wraps another handler in one that can add cache control headers
// given a 200 response
func WrapWithCacheHandler(ccc CacheControlConfig, handler http.Handler) http.Handler {
if ccc != nil {
return cacheControlHandler{Handler: handler, config: ccc}
}
return handler
}
// SetLastModifiedHeader takes a time and uses it to set the LastModified header using
// the right date format
func SetLastModifiedHeader(headers http.Header, lmt time.Time) {
headers.Set("Last-Modified", lmt.Format(time.RFC1123))
}

View File

@ -1,14 +1,18 @@
package utils package utils
import ( import (
"bytes"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/docker/distribution/registry/api/errcode" "github.com/docker/distribution/registry/api/errcode"
"github.com/docker/notary/tuf/signed" "github.com/docker/notary/tuf/signed"
"github.com/stretchr/testify/assert"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -39,22 +43,6 @@ func TestRootHandlerFactory(t *testing.T) {
} }
} }
//func TestRootHandlerUnauthorized(t *testing.T) {
// hand := RootHandlerFactory(nil, context.Background(), &signed.Ed25519{})
// handler := hand(MockContextHandler)
//
// ts := httptest.NewServer(handler)
// defer ts.Close()
//
// res, err := http.Get(ts.URL)
// if err != nil {
// t.Fatal(err)
// }
// if res.StatusCode != http.StatusUnauthorized {
// t.Fatalf("Expected 401, received %d", res.StatusCode)
// }
//}
func TestRootHandlerError(t *testing.T) { func TestRootHandlerError(t *testing.T) {
hand := RootHandlerFactory(nil, context.Background(), &signed.Ed25519{}) hand := RootHandlerFactory(nil, context.Background(), &signed.Ed25519{})
handler := hand(MockBetterErrorHandler) handler := hand(MockBetterErrorHandler)
@ -75,3 +63,192 @@ func TestRootHandlerError(t *testing.T) {
t.Fatalf("Error Body Incorrect: `%s`", content) t.Fatalf("Error Body Incorrect: `%s`", content)
} }
} }
// If no CacheControlConfig is passed, wrapping the handler just returns the handler
func TestWrapWithCacheHeaderNilCacheControlConfig(t *testing.T) {
mux := http.NewServeMux()
wrapped := WrapWithCacheHandler(nil, mux)
assert.Equal(t, mux, wrapped)
}
// If the wrapped handler returns a non-200, no matter which CacheControlConfig is
// used, the Cache-Control header not set.
func TestWrapWithCacheHeaderNon200Response(t *testing.T) {
mux := http.NewServeMux()
configs := []CacheControlConfig{NewCacheControlConfig(10, true), NewCacheControlConfig(0, true)}
for _, conf := range configs {
req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))}
wrapped := WrapWithCacheHandler(conf, mux)
assert.NotEqual(t, mux, wrapped)
rw := httptest.NewRecorder()
wrapped.ServeHTTP(rw, req)
assert.Equal(t, "", rw.HeaderMap.Get("Cache-Control"))
assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified"))
assert.Equal(t, "", rw.HeaderMap.Get("Pragma"))
}
}
// If the wrapped handler writes no cache headers whatsoever, and a PublicCacheControl
// is used, the Cache-Control header is set with the given maxAge and re-validate value.
// The Last-Modified header is also set to the beginning of (computer) time. If a
// Pragma header is written is deleted
func TestWrapWithCacheHeaderPublicCacheControlNoCacheHeaders(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello!"))
})
mux.HandleFunc("/a", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Pragma", "no-cache")
w.Write([]byte("hello!"))
})
for _, path := range []string{"/", "/a"} {
req := &http.Request{URL: &url.URL{Path: path}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))}
// must-revalidate is set if revalidate is set to true, and not if revalidate is set to false
for _, revalidate := range []bool{true, false} {
wrapped := WrapWithCacheHandler(NewCacheControlConfig(10, revalidate), mux)
assert.NotEqual(t, mux, wrapped)
rw := httptest.NewRecorder()
wrapped.ServeHTTP(rw, req)
cacheControl := "public, max-age=10, s-maxage=10"
if revalidate {
cacheControl = cacheControl + ", must-revalidate"
}
assert.Equal(t, cacheControl, rw.HeaderMap.Get("Cache-Control"))
lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified"))
assert.NoError(t, err)
assert.True(t, lastModified.Equal(time.Time{}))
assert.Equal(t, "", rw.HeaderMap.Get("Pragma"))
}
}
}
// If the wrapped handler writes a last modified header, and a PublicCacheControl
// is used, the Cache-Control header is set with the given maxAge and re-validate value.
// The Last-Modified header is not replaced. The Pragma header is deleted though.
func TestWrapWithCacheHeaderPublicCacheControlLastModifiedHeader(t *testing.T) {
now := time.Now()
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
SetLastModifiedHeader(w.Header(), now)
w.Header().Set("Pragma", "no-cache")
w.Write([]byte("hello!"))
})
req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))}
wrapped := WrapWithCacheHandler(NewCacheControlConfig(10, true), mux)
assert.NotEqual(t, mux, wrapped)
rw := httptest.NewRecorder()
wrapped.ServeHTTP(rw, req)
assert.Equal(t, "public, max-age=10, s-maxage=10, must-revalidate", rw.HeaderMap.Get("Cache-Control"))
lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified"))
assert.NoError(t, err)
// RFC1123 does not include nanoseconds
nowToNearestSecond := now.Add(time.Duration(-1 * now.Nanosecond()))
assert.True(t, lastModified.Equal(nowToNearestSecond))
assert.Equal(t, "", rw.HeaderMap.Get("Pragma"))
}
// If the wrapped handler writes a Cache-Control header, even if the last modified
// header is not written, then the Cache-Control header is not written, nor is a
// Last-Modified header written. The Pragma header is not deleted.
func TestWrapWithCacheHeaderPublicCacheControlCacheControlHeader(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "some invalid cache control value")
w.Header().Set("Pragma", "invalid value")
w.Write([]byte("hello!"))
})
req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))}
wrapped := WrapWithCacheHandler(NewCacheControlConfig(10, true), mux)
assert.NotEqual(t, mux, wrapped)
rw := httptest.NewRecorder()
wrapped.ServeHTTP(rw, req)
assert.Equal(t, "some invalid cache control value", rw.HeaderMap.Get("Cache-Control"))
assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified"))
assert.Equal(t, "invalid value", rw.HeaderMap.Get("Pragma"))
}
// If the wrapped handler writes no cache headers whatsoever, and NoCacheControl
// is used, the Cache-Control and Pragma headers are set with no-cache.
func TestWrapWithCacheHeaderNoCacheControlNoCacheHeaders(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Pragma", "invalid value")
w.Write([]byte("hello!"))
})
req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))}
wrapped := WrapWithCacheHandler(NewCacheControlConfig(0, false), mux)
assert.NotEqual(t, mux, wrapped)
rw := httptest.NewRecorder()
wrapped.ServeHTTP(rw, req)
assert.Equal(t, "max-age=0, no-cache, no-store", rw.HeaderMap.Get("Cache-Control"))
assert.Equal(t, "", rw.HeaderMap.Get("Last-Modified"))
assert.Equal(t, "no-cache", rw.HeaderMap.Get("Pragma"))
}
// If the wrapped handler writes a last modified header, and NoCacheControl
// is used, the Cache-Control and Pragma headers are set with no-cache without
// messing with the Last-Modified header.
func TestWrapWithCacheHeaderNoCacheControlLastModifiedHeader(t *testing.T) {
now := time.Now()
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
SetLastModifiedHeader(w.Header(), now)
w.Write([]byte("hello!"))
})
req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))}
wrapped := WrapWithCacheHandler(NewCacheControlConfig(0, true), mux)
assert.NotEqual(t, mux, wrapped)
rw := httptest.NewRecorder()
wrapped.ServeHTTP(rw, req)
assert.Equal(t, "max-age=0, no-cache, no-store", rw.HeaderMap.Get("Cache-Control"))
assert.Equal(t, "no-cache", rw.HeaderMap.Get("Pragma"))
lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified"))
assert.NoError(t, err)
// RFC1123 does not include nanoseconds
nowToNearestSecond := now.Add(time.Duration(-1 * now.Nanosecond()))
assert.True(t, lastModified.Equal(nowToNearestSecond))
}
// If the wrapped handler writes a Cache-Control header, even if the last modified
// header is not written, then the Cache-Control header is not written, nor is a
// Pragma added. The Last-Modified header is untouched.
func TestWrapWithCacheHeaderNoCacheControlCacheControlHeader(t *testing.T) {
now := time.Now()
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "some invalid cache control value")
SetLastModifiedHeader(w.Header(), now)
w.Write([]byte("hello!"))
})
req := &http.Request{URL: &url.URL{Path: "/"}, Body: ioutil.NopCloser(bytes.NewBuffer(nil))}
wrapped := WrapWithCacheHandler(NewCacheControlConfig(0, true), mux)
assert.NotEqual(t, mux, wrapped)
rw := httptest.NewRecorder()
wrapped.ServeHTTP(rw, req)
assert.Equal(t, "some invalid cache control value", rw.HeaderMap.Get("Cache-Control"))
assert.Equal(t, "", rw.HeaderMap.Get("Pragma"))
lastModified, err := time.Parse(time.RFC1123, rw.HeaderMap.Get("Last-Modified"))
assert.NoError(t, err)
// RFC1123 does not include nanoseconds
nowToNearestSecond := now.Add(time.Duration(-1 * now.Nanosecond()))
assert.True(t, lastModified.Equal(nowToNearestSecond))
}