Merge pull request #20253 from aaronlehmann/smarter-tls-fallback

Smarter push/pull fallback from TLS to plaintext
This commit is contained in:
Tibor Vass 2016-02-19 13:16:53 -05:00
commit aa1fdf42da
16 changed files with 177 additions and 54 deletions

View File

@ -31,6 +31,10 @@ type fallbackError struct {
// supports the v2 protocol. This is used to limit fallbacks to the v1 // supports the v2 protocol. This is used to limit fallbacks to the v1
// protocol. // protocol.
confirmedV2 bool confirmedV2 bool
// transportOK is set to true if we managed to speak HTTP with the
// registry. This confirms that we're using appropriate TLS settings
// (or lack of TLS).
transportOK bool
} }
// Error renders the FallbackError as a string. // Error renders the FallbackError as a string.

View File

@ -109,12 +109,25 @@ func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullCo
// confirm that it was talking to a v2 registry. This will // confirm that it was talking to a v2 registry. This will
// prevent fallback to the v1 protocol. // prevent fallback to the v1 protocol.
confirmedV2 bool confirmedV2 bool
// confirmedTLSRegistries is a map indicating which registries
// are known to be using TLS. There should never be a plaintext
// retry for any of these.
confirmedTLSRegistries = make(map[string]struct{})
) )
for _, endpoint := range endpoints { for _, endpoint := range endpoints {
if confirmedV2 && endpoint.Version == registry.APIVersion1 { if confirmedV2 && endpoint.Version == registry.APIVersion1 {
logrus.Debugf("Skipping v1 endpoint %s because v2 registry was detected", endpoint.URL) logrus.Debugf("Skipping v1 endpoint %s because v2 registry was detected", endpoint.URL)
continue continue
} }
if endpoint.URL.Scheme != "https" {
if _, confirmedTLS := confirmedTLSRegistries[endpoint.URL.Host]; confirmedTLS {
logrus.Debugf("Skipping non-TLS endpoint %s for host/port that appears to use TLS", endpoint.URL)
continue
}
}
logrus.Debugf("Trying to pull %s from %s %s", repoInfo.Name(), endpoint.URL, endpoint.Version) logrus.Debugf("Trying to pull %s from %s %s", repoInfo.Name(), endpoint.URL, endpoint.Version)
puller, err := newPuller(endpoint, repoInfo, imagePullConfig) puller, err := newPuller(endpoint, repoInfo, imagePullConfig)
@ -132,6 +145,9 @@ func Pull(ctx context.Context, ref reference.Named, imagePullConfig *ImagePullCo
if fallbackErr, ok := err.(fallbackError); ok { if fallbackErr, ok := err.(fallbackError); ok {
fallback = true fallback = true
confirmedV2 = confirmedV2 || fallbackErr.confirmedV2 confirmedV2 = confirmedV2 || fallbackErr.confirmedV2
if fallbackErr.transportOK && endpoint.URL.Scheme == "https" {
confirmedTLSRegistries[endpoint.URL.Host] = struct{}{}
}
err = fallbackErr.err err = fallbackErr.err
} }
} }

View File

@ -62,7 +62,7 @@ func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (err error) {
p.repo, p.confirmedV2, err = NewV2Repository(ctx, p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull") p.repo, p.confirmedV2, err = NewV2Repository(ctx, p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "pull")
if err != nil { if err != nil {
logrus.Warnf("Error getting v2 registry: %v", err) logrus.Warnf("Error getting v2 registry: %v", err)
return fallbackError{err: err, confirmedV2: p.confirmedV2} return err
} }
if err = p.pullV2Repository(ctx, ref); err != nil { if err = p.pullV2Repository(ctx, ref); err != nil {
@ -71,7 +71,11 @@ func (p *v2Puller) Pull(ctx context.Context, ref reference.Named) (err error) {
} }
if continueOnError(err) { if continueOnError(err) {
logrus.Errorf("Error trying v2 registry: %v", err) logrus.Errorf("Error trying v2 registry: %v", err)
return fallbackError{err: err, confirmedV2: p.confirmedV2} return fallbackError{
err: err,
confirmedV2: p.confirmedV2,
transportOK: true,
}
} }
} }
return err return err
@ -716,12 +720,20 @@ func allowV1Fallback(err error) error {
case errcode.Errors: case errcode.Errors:
if len(v) != 0 { if len(v) != 0 {
if v0, ok := v[0].(errcode.Error); ok && shouldV2Fallback(v0) { if v0, ok := v[0].(errcode.Error); ok && shouldV2Fallback(v0) {
return fallbackError{err: err, confirmedV2: false} return fallbackError{
err: err,
confirmedV2: false,
transportOK: true,
}
} }
} }
case errcode.Error: case errcode.Error:
if shouldV2Fallback(v) { if shouldV2Fallback(v) {
return fallbackError{err: err, confirmedV2: false} return fallbackError{
err: err,
confirmedV2: false,
transportOK: true,
}
} }
case *url.Error: case *url.Error:
if v.Err == auth.ErrNoBasicAuthCredentials { if v.Err == auth.ErrNoBasicAuthCredentials {

View File

@ -119,6 +119,11 @@ func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushCo
// confirm that it was talking to a v2 registry. This will // confirm that it was talking to a v2 registry. This will
// prevent fallback to the v1 protocol. // prevent fallback to the v1 protocol.
confirmedV2 bool confirmedV2 bool
// confirmedTLSRegistries is a map indicating which registries
// are known to be using TLS. There should never be a plaintext
// retry for any of these.
confirmedTLSRegistries = make(map[string]struct{})
) )
for _, endpoint := range endpoints { for _, endpoint := range endpoints {
@ -127,6 +132,13 @@ func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushCo
continue continue
} }
if endpoint.URL.Scheme != "https" {
if _, confirmedTLS := confirmedTLSRegistries[endpoint.URL.Host]; confirmedTLS {
logrus.Debugf("Skipping non-TLS endpoint %s for host/port that appears to use TLS", endpoint.URL)
continue
}
}
logrus.Debugf("Trying to push %s to %s %s", repoInfo.FullName(), endpoint.URL, endpoint.Version) logrus.Debugf("Trying to push %s to %s %s", repoInfo.FullName(), endpoint.URL, endpoint.Version)
pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig) pusher, err := NewPusher(ref, endpoint, repoInfo, imagePushConfig)
@ -142,6 +154,9 @@ func Push(ctx context.Context, ref reference.Named, imagePushConfig *ImagePushCo
default: default:
if fallbackErr, ok := err.(fallbackError); ok { if fallbackErr, ok := err.(fallbackError); ok {
confirmedV2 = confirmedV2 || fallbackErr.confirmedV2 confirmedV2 = confirmedV2 || fallbackErr.confirmedV2
if fallbackErr.transportOK && endpoint.URL.Scheme == "https" {
confirmedTLSRegistries[endpoint.URL.Host] = struct{}{}
}
err = fallbackErr.err err = fallbackErr.err
lastErr = err lastErr = err
logrus.Errorf("Attempting next endpoint for push after error: %v", err) logrus.Errorf("Attempting next endpoint for push after error: %v", err)

View File

@ -64,12 +64,16 @@ func (p *v2Pusher) Push(ctx context.Context) (err error) {
p.repo, p.pushState.confirmedV2, err = NewV2Repository(ctx, p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull") p.repo, p.pushState.confirmedV2, err = NewV2Repository(ctx, p.repoInfo, p.endpoint, p.config.MetaHeaders, p.config.AuthConfig, "push", "pull")
if err != nil { if err != nil {
logrus.Debugf("Error getting v2 registry: %v", err) logrus.Debugf("Error getting v2 registry: %v", err)
return fallbackError{err: err, confirmedV2: p.pushState.confirmedV2} return err
} }
if err = p.pushV2Repository(ctx); err != nil { if err = p.pushV2Repository(ctx); err != nil {
if continueOnError(err) { if continueOnError(err) {
return fallbackError{err: err, confirmedV2: p.pushState.confirmedV2} return fallbackError{
err: err,
confirmedV2: p.pushState.confirmedV2,
transportOK: true,
}
} }
} }
return err return err

View File

@ -57,17 +57,21 @@ func NewV2Repository(ctx context.Context, repoInfo *registry.RepositoryInfo, end
Transport: authTransport, Transport: authTransport,
Timeout: 15 * time.Second, Timeout: 15 * time.Second,
} }
endpointStr := strings.TrimRight(endpoint.URL, "/") + "/v2/" endpointStr := strings.TrimRight(endpoint.URL.String(), "/") + "/v2/"
req, err := http.NewRequest("GET", endpointStr, nil) req, err := http.NewRequest("GET", endpointStr, nil)
if err != nil { if err != nil {
return nil, false, err return nil, false, fallbackError{err: err}
} }
resp, err := pingClient.Do(req) resp, err := pingClient.Do(req)
if err != nil { if err != nil {
return nil, false, err return nil, false, fallbackError{err: err}
} }
defer resp.Body.Close() defer resp.Body.Close()
// We got a HTTP request through, so we're using the right TLS settings.
// From this point forward, set transportOK to true in any fallbackError
// we return.
v2Version := auth.APIVersion{ v2Version := auth.APIVersion{
Type: "registry", Type: "registry",
Version: "2.0", Version: "2.0",
@ -87,7 +91,11 @@ func NewV2Repository(ctx context.Context, repoInfo *registry.RepositoryInfo, end
challengeManager := auth.NewSimpleChallengeManager() challengeManager := auth.NewSimpleChallengeManager()
if err := challengeManager.AddResponse(resp); err != nil { if err := challengeManager.AddResponse(resp); err != nil {
return nil, foundVersion, err return nil, foundVersion, fallbackError{
err: err,
confirmedV2: foundVersion,
transportOK: true,
}
} }
if authConfig.RegistryToken != "" { if authConfig.RegistryToken != "" {
@ -103,11 +111,22 @@ func NewV2Repository(ctx context.Context, repoInfo *registry.RepositoryInfo, end
repoNameRef, err := distreference.ParseNamed(repoName) repoNameRef, err := distreference.ParseNamed(repoName)
if err != nil { if err != nil {
return nil, foundVersion, err return nil, foundVersion, fallbackError{
err: err,
confirmedV2: foundVersion,
transportOK: true,
}
} }
repo, err = client.NewRepository(ctx, repoNameRef, endpoint.URL, tr) repo, err = client.NewRepository(ctx, repoNameRef, endpoint.URL.String(), tr)
return repo, foundVersion, err if err != nil {
err = fallbackError{
err: err,
confirmedV2: foundVersion,
transportOK: true,
}
}
return
} }
type existingTokenHandler struct { type existingTokenHandler struct {

View File

@ -3,6 +3,7 @@ package distribution
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -43,9 +44,14 @@ func testTokenPassThru(t *testing.T, ts *httptest.Server) {
} }
defer os.RemoveAll(tmp) defer os.RemoveAll(tmp)
uri, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("could not parse url from test server: %v", err)
}
endpoint := registry.APIEndpoint{ endpoint := registry.APIEndpoint{
Mirror: false, Mirror: false,
URL: ts.URL, URL: uri,
Version: 2, Version: 2,
Official: false, Official: false,
TrimHostname: false, TrimHostname: false,

View File

@ -19,7 +19,7 @@ type Options struct {
InsecureRegistries opts.ListOpts InsecureRegistries opts.ListOpts
} }
const ( var (
// DefaultNamespace is the default namespace // DefaultNamespace is the default namespace
DefaultNamespace = "docker.io" DefaultNamespace = "docker.io"
// DefaultRegistryVersionHeader is the name of the default HTTP header // DefaultRegistryVersionHeader is the name of the default HTTP header
@ -27,7 +27,7 @@ const (
DefaultRegistryVersionHeader = "Docker-Distribution-Api-Version" DefaultRegistryVersionHeader = "Docker-Distribution-Api-Version"
// IndexServer is the v1 registry server used for user auth + account creation // IndexServer is the v1 registry server used for user auth + account creation
IndexServer = DefaultV1Registry + "/v1/" IndexServer = DefaultV1Registry.String() + "/v1/"
// IndexName is the name of the index // IndexName is the name of the index
IndexName = "docker.io" IndexName = "docker.io"

View File

@ -2,12 +2,22 @@
package registry package registry
const ( import (
"net/url"
)
var (
// DefaultV1Registry is the URI of the default v1 registry // DefaultV1Registry is the URI of the default v1 registry
DefaultV1Registry = "https://index.docker.io" DefaultV1Registry = &url.URL{
Scheme: "https",
Host: "index.docker.io",
}
// DefaultV2Registry is the URI of the default v2 registry // DefaultV2Registry is the URI of the default v2 registry
DefaultV2Registry = "https://registry-1.docker.io" DefaultV2Registry = &url.URL{
Scheme: "https",
Host: "registry-1.docker.io",
}
) )
var ( var (

View File

@ -1,21 +1,28 @@
package registry package registry
import ( import (
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
) )
const ( var (
// DefaultV1Registry is the URI of the default v1 registry // DefaultV1Registry is the URI of the default v1 registry
DefaultV1Registry = "https://registry-win-tp3.docker.io" DefaultV1Registry = &url.URL{
Scheme: "https",
Host: "registry-win-tp3.docker.io",
}
// DefaultV2Registry is the URI of the default (official) v2 registry. // DefaultV2Registry is the URI of the default (official) v2 registry.
// This is the windows-specific endpoint. // This is the windows-specific endpoint.
// //
// Currently it is a TEMPORARY link that allows Microsoft to continue // Currently it is a TEMPORARY link that allows Microsoft to continue
// development of Docker Engine for Windows. // development of Docker Engine for Windows.
DefaultV2Registry = "https://registry-win-tp3.docker.io" DefaultV2Registry = &url.URL{
Scheme: "https",
Host: "registry-win-tp3.docker.io",
}
) )
// CertsDir is the directory where certificates are stored // CertsDir is the directory where certificates are stored

View File

@ -50,10 +50,12 @@ func NewEndpoint(index *registrytypes.IndexInfo, userAgent string, metaHeaders h
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoint, err := newEndpoint(GetAuthConfigKey(index), tlsConfig, userAgent, metaHeaders)
endpoint, err := newEndpointFromStr(GetAuthConfigKey(index), tlsConfig, userAgent, metaHeaders)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if v != APIVersionUnknown { if v != APIVersionUnknown {
endpoint.Version = v endpoint.Version = v
} }
@ -91,24 +93,14 @@ func validateEndpoint(endpoint *Endpoint) error {
return nil return nil
} }
func newEndpoint(address string, tlsConfig *tls.Config, userAgent string, metaHeaders http.Header) (*Endpoint, error) { func newEndpoint(address url.URL, tlsConfig *tls.Config, userAgent string, metaHeaders http.Header) (*Endpoint, error) {
var ( endpoint := &Endpoint{
endpoint = new(Endpoint) IsSecure: (tlsConfig == nil || !tlsConfig.InsecureSkipVerify),
trimmedAddress string URL: new(url.URL),
err error Version: APIVersionUnknown,
)
if !strings.HasPrefix(address, "http") {
address = "https://" + address
} }
endpoint.IsSecure = (tlsConfig == nil || !tlsConfig.InsecureSkipVerify) *endpoint.URL = address
trimmedAddress, endpoint.Version = scanForAPIVersion(address)
if endpoint.URL, err = url.Parse(trimmedAddress); err != nil {
return nil, err
}
// TODO(tiborvass): make sure a ConnectTimeout transport is used // TODO(tiborvass): make sure a ConnectTimeout transport is used
tr := NewTransport(tlsConfig) tr := NewTransport(tlsConfig)
@ -116,6 +108,27 @@ func newEndpoint(address string, tlsConfig *tls.Config, userAgent string, metaHe
return endpoint, nil return endpoint, nil
} }
func newEndpointFromStr(address string, tlsConfig *tls.Config, userAgent string, metaHeaders http.Header) (*Endpoint, error) {
if !strings.HasPrefix(address, "http://") && !strings.HasPrefix(address, "https://") {
address = "https://" + address
}
trimmedAddress, detectedVersion := scanForAPIVersion(address)
uri, err := url.Parse(trimmedAddress)
if err != nil {
return nil, err
}
endpoint, err := newEndpoint(*uri, tlsConfig, userAgent, metaHeaders)
if err != nil {
return nil, err
}
endpoint.Version = detectedVersion
return endpoint, nil
}
// Endpoint stores basic information about a registry endpoint. // Endpoint stores basic information about a registry endpoint.
type Endpoint struct { type Endpoint struct {
client *http.Client client *http.Client

View File

@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) {
{"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"}, {"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"},
} }
for _, td := range testData { for _, td := range testData {
e, err := newEndpoint(td.str, nil, "", nil) e, err := newEndpointFromStr(td.str, nil, "", nil)
if err != nil { if err != nil {
t.Errorf("%q: %s", td.str, err) t.Errorf("%q: %s", td.str, err)
} }

View File

@ -673,7 +673,7 @@ func TestNewIndexInfo(t *testing.T) {
func TestMirrorEndpointLookup(t *testing.T) { func TestMirrorEndpointLookup(t *testing.T) {
containsMirror := func(endpoints []APIEndpoint) bool { containsMirror := func(endpoints []APIEndpoint) bool {
for _, pe := range endpoints { for _, pe := range endpoints {
if pe.URL == "my.mirror" { if pe.URL.Host == "my.mirror" {
return true return true
} }
} }

View File

@ -121,7 +121,7 @@ func (s *Service) ResolveIndex(name string) (*registrytypes.IndexInfo, error) {
// APIEndpoint represents a remote API endpoint // APIEndpoint represents a remote API endpoint
type APIEndpoint struct { type APIEndpoint struct {
Mirror bool Mirror bool
URL string URL *url.URL
Version APIVersion Version APIVersion
Official bool Official bool
TrimHostname bool TrimHostname bool
@ -130,7 +130,7 @@ type APIEndpoint struct {
// ToV1Endpoint returns a V1 API endpoint based on the APIEndpoint // ToV1Endpoint returns a V1 API endpoint based on the APIEndpoint
func (e APIEndpoint) ToV1Endpoint(userAgent string, metaHeaders http.Header) (*Endpoint, error) { func (e APIEndpoint) ToV1Endpoint(userAgent string, metaHeaders http.Header) (*Endpoint, error) {
return newEndpoint(e.URL, e.TLSConfig, userAgent, metaHeaders) return newEndpoint(*e.URL, e.TLSConfig, userAgent, metaHeaders)
} }
// TLSConfig constructs a client TLS configuration based on server defaults // TLSConfig constructs a client TLS configuration based on server defaults
@ -138,11 +138,7 @@ func (s *Service) TLSConfig(hostname string) (*tls.Config, error) {
return newTLSConfig(hostname, isSecureIndex(s.Config, hostname)) return newTLSConfig(hostname, isSecureIndex(s.Config, hostname))
} }
func (s *Service) tlsConfigForMirror(mirror string) (*tls.Config, error) { func (s *Service) tlsConfigForMirror(mirrorURL *url.URL) (*tls.Config, error) {
mirrorURL, err := url.Parse(mirror)
if err != nil {
return nil, err
}
return s.TLSConfig(mirrorURL.Host) return s.TLSConfig(mirrorURL.Host)
} }

View File

@ -2,6 +2,7 @@ package registry
import ( import (
"fmt" "fmt"
"net/url"
"strings" "strings"
"github.com/docker/docker/reference" "github.com/docker/docker/reference"
@ -36,7 +37,10 @@ func (s *Service) lookupV1Endpoints(repoName reference.Named) (endpoints []APIEn
endpoints = []APIEndpoint{ endpoints = []APIEndpoint{
{ {
URL: "https://" + hostname, URL: &url.URL{
Scheme: "https",
Host: hostname,
},
Version: APIVersion1, Version: APIVersion1,
TrimHostname: true, TrimHostname: true,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
@ -45,7 +49,10 @@ func (s *Service) lookupV1Endpoints(repoName reference.Named) (endpoints []APIEn
if tlsConfig.InsecureSkipVerify { if tlsConfig.InsecureSkipVerify {
endpoints = append(endpoints, APIEndpoint{ // or this endpoints = append(endpoints, APIEndpoint{ // or this
URL: "http://" + hostname, URL: &url.URL{
Scheme: "http",
Host: hostname,
},
Version: APIVersion1, Version: APIVersion1,
TrimHostname: true, TrimHostname: true,
// used to check if supposed to be secure via InsecureSkipVerify // used to check if supposed to be secure via InsecureSkipVerify

View File

@ -2,6 +2,7 @@ package registry
import ( import (
"fmt" "fmt"
"net/url"
"strings" "strings"
"github.com/docker/docker/reference" "github.com/docker/docker/reference"
@ -15,12 +16,19 @@ func (s *Service) lookupV2Endpoints(repoName reference.Named) (endpoints []APIEn
if strings.HasPrefix(nameString, DefaultNamespace+"/") { if strings.HasPrefix(nameString, DefaultNamespace+"/") {
// v2 mirrors // v2 mirrors
for _, mirror := range s.Config.Mirrors { for _, mirror := range s.Config.Mirrors {
mirrorTLSConfig, err := s.tlsConfigForMirror(mirror) if !strings.HasPrefix(mirror, "http://") && !strings.HasPrefix(mirror, "https://") {
mirror = "https://" + mirror
}
mirrorURL, err := url.Parse(mirror)
if err != nil {
return nil, err
}
mirrorTLSConfig, err := s.tlsConfigForMirror(mirrorURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoints = append(endpoints, APIEndpoint{ endpoints = append(endpoints, APIEndpoint{
URL: mirror, URL: mirrorURL,
// guess mirrors are v2 // guess mirrors are v2
Version: APIVersion2, Version: APIVersion2,
Mirror: true, Mirror: true,
@ -53,7 +61,10 @@ func (s *Service) lookupV2Endpoints(repoName reference.Named) (endpoints []APIEn
endpoints = []APIEndpoint{ endpoints = []APIEndpoint{
{ {
URL: "https://" + hostname, URL: &url.URL{
Scheme: "https",
Host: hostname,
},
Version: APIVersion2, Version: APIVersion2,
TrimHostname: true, TrimHostname: true,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
@ -62,7 +73,10 @@ func (s *Service) lookupV2Endpoints(repoName reference.Named) (endpoints []APIEn
if tlsConfig.InsecureSkipVerify { if tlsConfig.InsecureSkipVerify {
endpoints = append(endpoints, APIEndpoint{ endpoints = append(endpoints, APIEndpoint{
URL: "http://" + hostname, URL: &url.URL{
Scheme: "http",
Host: hostname,
},
Version: APIVersion2, Version: APIVersion2,
TrimHostname: true, TrimHostname: true,
// used to check if supposed to be secure via InsecureSkipVerify // used to check if supposed to be secure via InsecureSkipVerify