diff --git a/graph/pull.go b/graph/pull.go index 5fd837b42..0c33e313a 100644 --- a/graph/pull.go +++ b/graph/pull.go @@ -17,6 +17,7 @@ import ( "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/registry" "github.com/docker/docker/utils" ) @@ -55,12 +56,19 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf defer s.poolRemove("pull", utils.ImageReference(repoInfo.LocalName, tag)) logrus.Debugf("pulling image from host %q with remote name %q", repoInfo.Index.Name, repoInfo.RemoteName) - endpoint, err := repoInfo.GetEndpoint() + + endpoint, err := repoInfo.GetEndpoint(imagePullConfig.MetaHeaders) if err != nil { return err } - - r, err := registry.NewSession(imagePullConfig.AuthConfig, registry.HTTPRequestFactory(imagePullConfig.MetaHeaders), endpoint, true) + // TODO(tiborvass): reuse client from endpoint? + // Adds Docker-specific headers as well as user-specified headers (metaHeaders) + tr := transport.NewTransport( + registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure), + registry.DockerHeaders(imagePullConfig.MetaHeaders)..., + ) + client := registry.HTTPClient(tr) + r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint) if err != nil { return err } @@ -109,7 +117,7 @@ func (s *TagStore) pullRepository(r *registry.Session, out io.Writer, repoInfo * } logrus.Debugf("Retrieving the tag list") - tagsList, err := r.GetRemoteTags(repoData.Endpoints, repoInfo.RemoteName, repoData.Tokens) + tagsList, err := r.GetRemoteTags(repoData.Endpoints, repoInfo.RemoteName) if err != nil { logrus.Errorf("unable to get remote tags: %s", err) return err @@ -240,7 +248,7 @@ func (s *TagStore) pullRepository(r *registry.Session, out io.Writer, repoInfo * } func (s *TagStore) pullImage(r *registry.Session, out io.Writer, imgID, endpoint string, token []string, sf *streamformatter.StreamFormatter) (bool, error) { - history, err := r.GetRemoteHistory(imgID, endpoint, token) + history, err := r.GetRemoteHistory(imgID, endpoint) if err != nil { return false, err } @@ -269,7 +277,7 @@ func (s *TagStore) pullImage(r *registry.Session, out io.Writer, imgID, endpoint ) retries := 5 for j := 1; j <= retries; j++ { - imgJSON, imgSize, err = r.GetRemoteImageJSON(id, endpoint, token) + imgJSON, imgSize, err = r.GetRemoteImageJSON(id, endpoint) if err != nil && j == retries { out.Write(sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, err @@ -297,7 +305,7 @@ func (s *TagStore) pullImage(r *registry.Session, out io.Writer, imgID, endpoint status = fmt.Sprintf("Pulling fs layer [retries: %d]", j) } out.Write(sf.FormatProgress(stringid.TruncateID(id), status, nil)) - layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token, int64(imgSize)) + layer, err := r.GetRemoteImageLayer(img.ID, endpoint, int64(imgSize)) if uerr, ok := err.(*url.Error); ok { err = uerr.Err } diff --git a/graph/push.go b/graph/push.go index 1ae1fc1d0..817ef707f 100644 --- a/graph/push.go +++ b/graph/push.go @@ -18,6 +18,7 @@ import ( "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/registry" "github.com/docker/docker/runconfig" "github.com/docker/docker/utils" @@ -141,7 +142,7 @@ func lookupImageOnEndpoint(wg *sync.WaitGroup, r *registry.Session, out io.Write images chan imagePushData, imagesToPush chan string) { defer wg.Done() for image := range images { - if err := r.LookupRemoteImage(image.id, image.endpoint, image.tokens); err != nil { + if err := r.LookupRemoteImage(image.id, image.endpoint); err != nil { logrus.Errorf("Error in LookupRemoteImage: %s", err) imagesToPush <- image.id continue @@ -199,7 +200,7 @@ func (s *TagStore) pushImageToEndpoint(endpoint string, out io.Writer, remoteNam } for _, tag := range tags[id] { out.Write(sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(id), endpoint+"repositories/"+remoteName+"/tags/"+tag)) - if err := r.PushRegistryTag(remoteName, id, tag, endpoint, repo.Tokens); err != nil { + if err := r.PushRegistryTag(remoteName, id, tag, endpoint); err != nil { return err } } @@ -258,7 +259,7 @@ func (s *TagStore) pushImage(r *registry.Session, out io.Writer, imgID, ep strin } // Send the json - if err := r.PushImageJSONRegistry(imgData, jsonRaw, ep, token); err != nil { + if err := r.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil { if err == registry.ErrAlreadyExists { out.Write(sf.FormatProgress(stringid.TruncateID(imgData.ID), "Image already pushed, skipping", nil)) return "", nil @@ -284,14 +285,14 @@ func (s *TagStore) pushImage(r *registry.Session, out io.Writer, imgID, ep strin NewLines: false, ID: stringid.TruncateID(imgData.ID), Action: "Pushing", - }), ep, token, jsonRaw) + }), ep, jsonRaw) if err != nil { return "", err } imgData.Checksum = checksum imgData.ChecksumPayload = checksumPayload // Send the checksum - if err := r.PushImageChecksumRegistry(imgData, ep, token); err != nil { + if err := r.PushImageChecksumRegistry(imgData, ep); err != nil { return "", err } @@ -509,12 +510,18 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro } defer s.poolRemove("push", repoInfo.LocalName) - endpoint, err := repoInfo.GetEndpoint() + endpoint, err := repoInfo.GetEndpoint(imagePushConfig.MetaHeaders) if err != nil { return err } - - r, err := registry.NewSession(imagePushConfig.AuthConfig, registry.HTTPRequestFactory(imagePushConfig.MetaHeaders), endpoint, false) + // TODO(tiborvass): reuse client from endpoint? + // Adds Docker-specific headers as well as user-specified headers (metaHeaders) + tr := transport.NewTransport( + registry.NewTransport(registry.NoTimeout, endpoint.IsSecure), + registry.DockerHeaders(imagePushConfig.MetaHeaders)..., + ) + client := registry.HTTPClient(tr) + r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint) if err != nil { return err } diff --git a/pkg/requestdecorator/README.md b/pkg/requestdecorator/README.md deleted file mode 100644 index 76f8ca798..000000000 --- a/pkg/requestdecorator/README.md +++ /dev/null @@ -1,2 +0,0 @@ -This package provides helper functions for decorating a request with user agent -versions, auth, meta headers. diff --git a/pkg/requestdecorator/requestdecorator.go b/pkg/requestdecorator/requestdecorator.go deleted file mode 100644 index c236e3fe3..000000000 --- a/pkg/requestdecorator/requestdecorator.go +++ /dev/null @@ -1,172 +0,0 @@ -// Package requestdecorator provides helper functions to decorate a request with -// user agent versions, auth, meta headers. -package requestdecorator - -import ( - "errors" - "io" - "net/http" - "strings" - - "github.com/Sirupsen/logrus" -) - -var ( - ErrNilRequest = errors.New("request cannot be nil") -) - -// UAVersionInfo is used to model UserAgent versions. -type UAVersionInfo struct { - Name string - Version string -} - -func NewUAVersionInfo(name, version string) UAVersionInfo { - return UAVersionInfo{ - Name: name, - Version: version, - } -} - -func (vi *UAVersionInfo) isValid() bool { - const stopChars = " \t\r\n/" - name := vi.Name - vers := vi.Version - if len(name) == 0 || strings.ContainsAny(name, stopChars) { - return false - } - if len(vers) == 0 || strings.ContainsAny(vers, stopChars) { - return false - } - return true -} - -// Convert versions to a string and append the string to the string base. -// -// Each UAVersionInfo will be converted to a string in the format of -// "product/version", where the "product" is get from the name field, while -// version is get from the version field. Several pieces of verson information -// will be concatinated and separated by space. -func appendVersions(base string, versions ...UAVersionInfo) string { - if len(versions) == 0 { - return base - } - - verstrs := make([]string, 0, 1+len(versions)) - if len(base) > 0 { - verstrs = append(verstrs, base) - } - - for _, v := range versions { - if !v.isValid() { - continue - } - verstrs = append(verstrs, v.Name+"/"+v.Version) - } - return strings.Join(verstrs, " ") -} - -// Decorator is used to change an instance of -// http.Request. It could be used to add more header fields, -// change body, etc. -type Decorator interface { - // ChangeRequest() changes the request accordingly. - // The changed request will be returned or err will be non-nil - // if an error occur. - ChangeRequest(req *http.Request) (newReq *http.Request, err error) -} - -// UserAgentDecorator appends the product/version to the user agent field -// of a request. -type UserAgentDecorator struct { - Versions []UAVersionInfo -} - -func (h *UserAgentDecorator) ChangeRequest(req *http.Request) (*http.Request, error) { - if req == nil { - return req, ErrNilRequest - } - - userAgent := appendVersions(req.UserAgent(), h.Versions...) - if len(userAgent) > 0 { - req.Header.Set("User-Agent", userAgent) - } - return req, nil -} - -type MetaHeadersDecorator struct { - Headers map[string][]string -} - -func (h *MetaHeadersDecorator) ChangeRequest(req *http.Request) (*http.Request, error) { - if h.Headers == nil { - return req, ErrNilRequest - } - for k, v := range h.Headers { - req.Header[k] = v - } - return req, nil -} - -type AuthDecorator struct { - login string - password string -} - -func NewAuthDecorator(login, password string) Decorator { - return &AuthDecorator{ - login: login, - password: password, - } -} - -func (self *AuthDecorator) ChangeRequest(req *http.Request) (*http.Request, error) { - if req == nil { - return req, ErrNilRequest - } - req.SetBasicAuth(self.login, self.password) - return req, nil -} - -// RequestFactory creates an HTTP request -// and applies a list of decorators on the request. -type RequestFactory struct { - decorators []Decorator -} - -func NewRequestFactory(d ...Decorator) *RequestFactory { - return &RequestFactory{ - decorators: d, - } -} - -func (f *RequestFactory) AddDecorator(d ...Decorator) { - f.decorators = append(f.decorators, d...) -} - -func (f *RequestFactory) GetDecorators() []Decorator { - return f.decorators -} - -// NewRequest() creates a new *http.Request, -// applies all decorators in the Factory on the request, -// then applies decorators provided by d on the request. -func (h *RequestFactory) NewRequest(method, urlStr string, body io.Reader, d ...Decorator) (*http.Request, error) { - req, err := http.NewRequest(method, urlStr, body) - if err != nil { - return nil, err - } - - // By default, a nil factory should work. - if h == nil { - return req, nil - } - for _, dec := range h.decorators { - req, _ = dec.ChangeRequest(req) - } - for _, dec := range d { - req, _ = dec.ChangeRequest(req) - } - logrus.Debugf("%v -- HEADERS: %v", req.URL, req.Header) - return req, err -} diff --git a/pkg/requestdecorator/requestdecorator_test.go b/pkg/requestdecorator/requestdecorator_test.go deleted file mode 100644 index ed6113546..000000000 --- a/pkg/requestdecorator/requestdecorator_test.go +++ /dev/null @@ -1,222 +0,0 @@ -package requestdecorator - -import ( - "net/http" - "strings" - "testing" -) - -func TestUAVersionInfo(t *testing.T) { - uavi := NewUAVersionInfo("foo", "bar") - if !uavi.isValid() { - t.Fatalf("UAVersionInfo should be valid") - } - uavi = NewUAVersionInfo("", "bar") - if uavi.isValid() { - t.Fatalf("Expected UAVersionInfo to be invalid") - } - uavi = NewUAVersionInfo("foo", "") - if uavi.isValid() { - t.Fatalf("Expected UAVersionInfo to be invalid") - } -} - -func TestUserAgentDecorator(t *testing.T) { - httpVersion := make([]UAVersionInfo, 2) - httpVersion = append(httpVersion, NewUAVersionInfo("testname", "testversion")) - httpVersion = append(httpVersion, NewUAVersionInfo("name", "version")) - uad := &UserAgentDecorator{ - Versions: httpVersion, - } - - req, err := http.NewRequest("GET", "/something", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - reqDecorated, err := uad.ChangeRequest(req) - if err != nil { - t.Fatal(err) - } - - if reqDecorated.Header.Get("User-Agent") != "testname/testversion name/version" { - t.Fatalf("Request should have User-Agent 'testname/testversion name/version'") - } -} - -func TestUserAgentDecoratorErr(t *testing.T) { - httpVersion := make([]UAVersionInfo, 0) - uad := &UserAgentDecorator{ - Versions: httpVersion, - } - - var req *http.Request - _, err := uad.ChangeRequest(req) - if err == nil { - t.Fatalf("Expected to get ErrNilRequest instead no error was returned") - } -} - -func TestMetaHeadersDecorator(t *testing.T) { - var headers = map[string][]string{ - "key1": {"value1"}, - "key2": {"value2"}, - } - mhd := &MetaHeadersDecorator{ - Headers: headers, - } - - req, err := http.NewRequest("GET", "/something", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - reqDecorated, err := mhd.ChangeRequest(req) - if err != nil { - t.Fatal(err) - } - - v, ok := reqDecorated.Header["key1"] - if !ok { - t.Fatalf("Expected to have header key1") - } - if v[0] != "value1" { - t.Fatalf("Expected value for key1 isn't value1") - } - - v, ok = reqDecorated.Header["key2"] - if !ok { - t.Fatalf("Expected to have header key2") - } - if v[0] != "value2" { - t.Fatalf("Expected value for key2 isn't value2") - } -} - -func TestMetaHeadersDecoratorErr(t *testing.T) { - mhd := &MetaHeadersDecorator{} - - var req *http.Request - _, err := mhd.ChangeRequest(req) - if err == nil { - t.Fatalf("Expected to get ErrNilRequest instead no error was returned") - } -} - -func TestAuthDecorator(t *testing.T) { - ad := NewAuthDecorator("test", "password") - - req, err := http.NewRequest("GET", "/something", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - reqDecorated, err := ad.ChangeRequest(req) - if err != nil { - t.Fatal(err) - } - - username, password, ok := reqDecorated.BasicAuth() - if !ok { - t.Fatalf("Cannot retrieve basic auth info from request") - } - if username != "test" { - t.Fatalf("Expected username to be test, got %s", username) - } - if password != "password" { - t.Fatalf("Expected password to be password, got %s", password) - } -} - -func TestAuthDecoratorErr(t *testing.T) { - ad := &AuthDecorator{} - - var req *http.Request - _, err := ad.ChangeRequest(req) - if err == nil { - t.Fatalf("Expected to get ErrNilRequest instead no error was returned") - } -} - -func TestRequestFactory(t *testing.T) { - ad := NewAuthDecorator("test", "password") - httpVersion := make([]UAVersionInfo, 2) - httpVersion = append(httpVersion, NewUAVersionInfo("testname", "testversion")) - httpVersion = append(httpVersion, NewUAVersionInfo("name", "version")) - uad := &UserAgentDecorator{ - Versions: httpVersion, - } - - requestFactory := NewRequestFactory(ad, uad) - - if l := len(requestFactory.GetDecorators()); l != 2 { - t.Fatalf("Expected to have two decorators, got %d", l) - } - - req, err := requestFactory.NewRequest("GET", "/test", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - - username, password, ok := req.BasicAuth() - if !ok { - t.Fatalf("Cannot retrieve basic auth info from request") - } - if username != "test" { - t.Fatalf("Expected username to be test, got %s", username) - } - if password != "password" { - t.Fatalf("Expected password to be password, got %s", password) - } - if req.Header.Get("User-Agent") != "testname/testversion name/version" { - t.Fatalf("Request should have User-Agent 'testname/testversion name/version'") - } -} - -func TestRequestFactoryNewRequestWithDecorators(t *testing.T) { - ad := NewAuthDecorator("test", "password") - - requestFactory := NewRequestFactory(ad) - - if l := len(requestFactory.GetDecorators()); l != 1 { - t.Fatalf("Expected to have one decorators, got %d", l) - } - - ad2 := NewAuthDecorator("test2", "password2") - - req, err := requestFactory.NewRequest("GET", "/test", strings.NewReader("test"), ad2) - if err != nil { - t.Fatal(err) - } - - username, password, ok := req.BasicAuth() - if !ok { - t.Fatalf("Cannot retrieve basic auth info from request") - } - if username != "test2" { - t.Fatalf("Expected username to be test, got %s", username) - } - if password != "password2" { - t.Fatalf("Expected password to be password, got %s", password) - } -} - -func TestRequestFactoryAddDecorator(t *testing.T) { - requestFactory := NewRequestFactory() - - if l := len(requestFactory.GetDecorators()); l != 0 { - t.Fatalf("Expected to have zero decorators, got %d", l) - } - - ad := NewAuthDecorator("test", "password") - requestFactory.AddDecorator(ad) - - if l := len(requestFactory.GetDecorators()); l != 1 { - t.Fatalf("Expected to have one decorators, got %d", l) - } -} - -func TestRequestFactoryNil(t *testing.T) { - var requestFactory RequestFactory - _, err := requestFactory.NewRequest("GET", "/test", strings.NewReader("test")) - if err != nil { - t.Fatalf("Expected not to get and error, got %s", err) - } -} diff --git a/pkg/transport/LICENSE b/pkg/transport/LICENSE new file mode 100644 index 000000000..d02f24fd5 --- /dev/null +++ b/pkg/transport/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The oauth2 Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go new file mode 100644 index 000000000..510d8b4bc --- /dev/null +++ b/pkg/transport/transport.go @@ -0,0 +1,148 @@ +package transport + +import ( + "io" + "net/http" + "sync" +) + +type RequestModifier interface { + ModifyRequest(*http.Request) error +} + +type headerModifier http.Header + +// NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers +// passed as an argument, with the HTTP headers of a request. +// +// If the same key is present in both, the modifying header values for that key, +// are appended to the values for that same key in the request header. +func NewHeaderRequestModifier(header http.Header) RequestModifier { + return headerModifier(header) +} + +func (h headerModifier) ModifyRequest(req *http.Request) error { + for k, s := range http.Header(h) { + req.Header[k] = append(req.Header[k], s...) + } + + return nil +} + +// NewTransport returns an http.RoundTripper that modifies requests according to +// the RequestModifiers passed in the arguments, before sending the requests to +// the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport). +func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper { + return &transport{ + Modifiers: modifiers, + Base: base, + } +} + +// transport is an http.RoundTripper that makes HTTP requests after +// copying and modifying the request +type transport struct { + Modifiers []RequestModifier + Base http.RoundTripper + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := CloneRequest(req) + for _, modifier := range t.Modifiers { + if err := modifier.ModifyRequest(req2); err != nil { + return nil, err + } + } + + t.setModReq(req, req2) + res, err := t.base().RoundTrip(req2) + if err != nil { + t.setModReq(req, nil) + return nil, err + } + res.Body = &OnEOFReader{ + Rc: res.Body, + Fn: func() { t.setModReq(req, nil) }, + } + return res, nil +} + +// CancelRequest cancels an in-flight request by closing its connection. +func (t *transport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := t.base().(canceler); ok { + t.mu.Lock() + modReq := t.modReq[req] + delete(t.modReq, req) + t.mu.Unlock() + cr.CancelRequest(modReq) + } +} + +func (t *transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func (t *transport) setModReq(orig, mod *http.Request) { + t.mu.Lock() + defer t.mu.Unlock() + if t.modReq == nil { + t.modReq = make(map[*http.Request]*http.Request) + } + if mod == nil { + delete(t.modReq, orig) + } else { + t.modReq[orig] = mod + } +} + +// CloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func CloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + + return r2 +} + +// OnEOFReader ensures a callback function is called +// on Close() and when the underlying Reader returns an io.EOF error +type OnEOFReader struct { + Rc io.ReadCloser + Fn func() +} + +func (r *OnEOFReader) Read(p []byte) (n int, err error) { + n, err = r.Rc.Read(p) + if err == io.EOF { + r.runFunc() + } + return +} + +func (r *OnEOFReader) Close() error { + err := r.Rc.Close() + r.runFunc() + return err +} + +func (r *OnEOFReader) runFunc() { + if fn := r.Fn; fn != nil { + fn() + r.Fn = nil + } +} diff --git a/pkg/useragent/README.md b/pkg/useragent/README.md new file mode 100644 index 000000000..d9cb367d1 --- /dev/null +++ b/pkg/useragent/README.md @@ -0,0 +1 @@ +This package provides helper functions to pack version information into a single User-Agent header. diff --git a/pkg/useragent/useragent.go b/pkg/useragent/useragent.go new file mode 100644 index 000000000..9e35d1c70 --- /dev/null +++ b/pkg/useragent/useragent.go @@ -0,0 +1,60 @@ +// Package useragent provides helper functions to pack +// version information into a single User-Agent header. +package useragent + +import ( + "errors" + "strings" +) + +var ( + ErrNilRequest = errors.New("request cannot be nil") +) + +// VersionInfo is used to model UserAgent versions. +type VersionInfo struct { + Name string + Version string +} + +func (vi *VersionInfo) isValid() bool { + const stopChars = " \t\r\n/" + name := vi.Name + vers := vi.Version + if len(name) == 0 || strings.ContainsAny(name, stopChars) { + return false + } + if len(vers) == 0 || strings.ContainsAny(vers, stopChars) { + return false + } + return true +} + +// Convert versions to a string and append the string to the string base. +// +// Each VersionInfo will be converted to a string in the format of +// "product/version", where the "product" is get from the name field, while +// version is get from the version field. Several pieces of verson information +// will be concatinated and separated by space. +// +// Example: +// AppendVersions("base", VersionInfo{"foo", "1.0"}, VersionInfo{"bar", "2.0"}) +// results in "base foo/1.0 bar/2.0". +func AppendVersions(base string, versions ...VersionInfo) string { + if len(versions) == 0 { + return base + } + + verstrs := make([]string, 0, 1+len(versions)) + if len(base) > 0 { + verstrs = append(verstrs, base) + } + + for _, v := range versions { + if !v.isValid() { + continue + } + verstrs = append(verstrs, v.Name+"/"+v.Version) + } + return strings.Join(verstrs, " ") +} diff --git a/pkg/useragent/useragent_test.go b/pkg/useragent/useragent_test.go new file mode 100644 index 000000000..0ad7243a6 --- /dev/null +++ b/pkg/useragent/useragent_test.go @@ -0,0 +1,31 @@ +package useragent + +import "testing" + +func TestVersionInfo(t *testing.T) { + vi := VersionInfo{"foo", "bar"} + if !vi.isValid() { + t.Fatalf("VersionInfo should be valid") + } + vi = VersionInfo{"", "bar"} + if vi.isValid() { + t.Fatalf("Expected VersionInfo to be invalid") + } + vi = VersionInfo{"foo", ""} + if vi.isValid() { + t.Fatalf("Expected VersionInfo to be invalid") + } +} + +func TestAppendVersions(t *testing.T) { + vis := []VersionInfo{ + {"foo", "1.0"}, + {"bar", "0.1"}, + {"pi", "3.1.4"}, + } + v := AppendVersions("base", vis...) + expect := "base foo/1.0 bar/0.1 pi/3.1.4" + if v != expect { + t.Fatalf("expected %q, got %q", expect, v) + } +} diff --git a/registry/auth.go b/registry/auth.go index 1ac1ca984..33f8fa068 100644 --- a/registry/auth.go +++ b/registry/auth.go @@ -11,7 +11,6 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/docker/cliconfig" - "github.com/docker/docker/pkg/requestdecorator" ) type RequestAuthorization struct { @@ -45,9 +44,6 @@ func (auth *RequestAuthorization) getToken() (string, error) { return auth.tokenCache, nil } - client := auth.registryEndpoint.HTTPClient() - factory := HTTPRequestFactory(nil) - for _, challenge := range auth.registryEndpoint.AuthChallenges { switch strings.ToLower(challenge.Scheme) { case "basic": @@ -59,7 +55,7 @@ func (auth *RequestAuthorization) getToken() (string, error) { params[k] = v } params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ",")) - token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client, factory) + token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint) if err != nil { return "", err } @@ -92,21 +88,20 @@ func (auth *RequestAuthorization) Authorize(req *http.Request) error { } // Login tries to register/login to the registry server. -func Login(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, factory *requestdecorator.RequestFactory) (string, error) { +func Login(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (string, error) { // Separates the v2 registry login logic from the v1 logic. if registryEndpoint.Version == APIVersion2 { - return loginV2(authConfig, registryEndpoint, factory) + return loginV2(authConfig, registryEndpoint) } - return loginV1(authConfig, registryEndpoint, factory) + return loginV1(authConfig, registryEndpoint) } // loginV1 tries to register/login to the v1 registry server. -func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, factory *requestdecorator.RequestFactory) (string, error) { +func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (string, error) { var ( status string reqBody []byte err error - client = registryEndpoint.HTTPClient() reqStatusCode = 0 serverAddress = authConfig.ServerAddress ) @@ -130,7 +125,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto // using `bytes.NewReader(jsonBody)` here causes the server to respond with a 411 status. b := strings.NewReader(string(jsonBody)) - req1, err := client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) + req1, err := registryEndpoint.client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) if err != nil { return "", fmt.Errorf("Server Error: %s", err) } @@ -151,9 +146,9 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto } } else if reqStatusCode == 400 { if string(reqBody) == "\"Username or email already exists\"" { - req, err := factory.NewRequest("GET", serverAddress+"users/", nil) + req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err } @@ -180,9 +175,9 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto } else if reqStatusCode == 401 { // This case would happen with private registries where /v1/users is // protected, so people can use `docker login` as an auth check. - req, err := factory.NewRequest("GET", serverAddress+"users/", nil) + req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err } @@ -214,12 +209,11 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto // now, users should create their account through other means like directly from a web page // served by the v2 registry service provider. Whether this will be supported in the future // is to be determined. -func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, factory *requestdecorator.RequestFactory) (string, error) { +func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (string, error) { logrus.Debugf("attempting v2 login to registry endpoint %s", registryEndpoint) var ( err error allErrors []error - client = registryEndpoint.HTTPClient() ) for _, challenge := range registryEndpoint.AuthChallenges { @@ -227,9 +221,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto switch strings.ToLower(challenge.Scheme) { case "basic": - err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client, factory) + err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint) case "bearer": - err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client, factory) + err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint) default: // Unsupported challenge types are explicitly skipped. err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme) @@ -247,15 +241,15 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors) } -func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client, factory *requestdecorator.RequestFactory) error { - req, err := factory.NewRequest("GET", registryEndpoint.Path(""), nil) +func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error { + req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil) if err != nil { return err } req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return err } @@ -268,20 +262,20 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str return nil } -func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client, factory *requestdecorator.RequestFactory) error { - token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client, factory) +func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error { + token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint) if err != nil { return err } - req, err := factory.NewRequest("GET", registryEndpoint.Path(""), nil) + req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil) if err != nil { return err } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return err } diff --git a/registry/endpoint.go b/registry/endpoint.go index 84b11a987..ce92668f4 100644 --- a/registry/endpoint.go +++ b/registry/endpoint.go @@ -1,7 +1,6 @@ package registry import ( - "crypto/tls" "encoding/json" "fmt" "io/ioutil" @@ -12,7 +11,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/distribution/registry/api/v2" - "github.com/docker/docker/pkg/requestdecorator" + "github.com/docker/docker/pkg/transport" ) // for mocking in unit tests @@ -43,9 +42,9 @@ func scanForAPIVersion(address string) (string, APIVersion) { } // NewEndpoint parses the given address to return a registry endpoint. -func NewEndpoint(index *IndexInfo) (*Endpoint, error) { +func NewEndpoint(index *IndexInfo, metaHeaders http.Header) (*Endpoint, error) { // *TODO: Allow per-registry configuration of endpoints. - endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure) + endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure, metaHeaders) if err != nil { return nil, err } @@ -83,7 +82,7 @@ func validateEndpoint(endpoint *Endpoint) error { return nil } -func newEndpoint(address string, secure bool) (*Endpoint, error) { +func newEndpoint(address string, secure bool, metaHeaders http.Header) (*Endpoint, error) { var ( endpoint = new(Endpoint) trimmedAddress string @@ -100,15 +99,18 @@ func newEndpoint(address string, secure bool) (*Endpoint, error) { return nil, err } endpoint.IsSecure = secure + tr := NewTransport(ConnectTimeout, endpoint.IsSecure) + endpoint.client = HTTPClient(transport.NewTransport(tr, DockerHeaders(metaHeaders)...)) return endpoint, nil } -func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) { - return NewEndpoint(repoInfo.Index) +func (repoInfo *RepositoryInfo) GetEndpoint(metaHeaders http.Header) (*Endpoint, error) { + return NewEndpoint(repoInfo.Index, metaHeaders) } // Endpoint stores basic information about a registry endpoint. type Endpoint struct { + client *http.Client URL *url.URL Version APIVersion IsSecure bool @@ -135,25 +137,24 @@ func (e *Endpoint) Path(path string) string { func (e *Endpoint) Ping() (RegistryInfo, error) { // The ping logic to use is determined by the registry endpoint version. - factory := HTTPRequestFactory(nil) switch e.Version { case APIVersion1: - return e.pingV1(factory) + return e.pingV1() case APIVersion2: - return e.pingV2(factory) + return e.pingV2() } // APIVersionUnknown // We should try v2 first... e.Version = APIVersion2 - regInfo, errV2 := e.pingV2(factory) + regInfo, errV2 := e.pingV2() if errV2 == nil { return regInfo, nil } // ... then fallback to v1. e.Version = APIVersion1 - regInfo, errV1 := e.pingV1(factory) + regInfo, errV1 := e.pingV1() if errV1 == nil { return regInfo, nil } @@ -162,7 +163,7 @@ func (e *Endpoint) Ping() (RegistryInfo, error) { return RegistryInfo{}, fmt.Errorf("unable to ping registry endpoint %s\nv2 ping attempt failed with error: %s\n v1 ping attempt failed with error: %s", e, errV2, errV1) } -func (e *Endpoint) pingV1(factory *requestdecorator.RequestFactory) (RegistryInfo, error) { +func (e *Endpoint) pingV1() (RegistryInfo, error) { logrus.Debugf("attempting v1 ping for registry endpoint %s", e) if e.String() == IndexServerAddress() { @@ -171,12 +172,12 @@ func (e *Endpoint) pingV1(factory *requestdecorator.RequestFactory) (RegistryInf return RegistryInfo{Standalone: false}, nil } - req, err := factory.NewRequest("GET", e.Path("_ping"), nil) + req, err := http.NewRequest("GET", e.Path("_ping"), nil) if err != nil { return RegistryInfo{Standalone: false}, err } - resp, _, err := doRequest(req, nil, ConnectTimeout, e.IsSecure) + resp, err := e.client.Do(req) if err != nil { return RegistryInfo{Standalone: false}, err } @@ -216,15 +217,15 @@ func (e *Endpoint) pingV1(factory *requestdecorator.RequestFactory) (RegistryInf return info, nil } -func (e *Endpoint) pingV2(factory *requestdecorator.RequestFactory) (RegistryInfo, error) { +func (e *Endpoint) pingV2() (RegistryInfo, error) { logrus.Debugf("attempting v2 ping for registry endpoint %s", e) - req, err := factory.NewRequest("GET", e.Path(""), nil) + req, err := http.NewRequest("GET", e.Path(""), nil) if err != nil { return RegistryInfo{}, err } - resp, _, err := doRequest(req, nil, ConnectTimeout, e.IsSecure) + resp, err := e.client.Do(req) if err != nil { return RegistryInfo{}, err } @@ -263,20 +264,3 @@ HeaderLoop: return RegistryInfo{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode)) } - -func (e *Endpoint) HTTPClient() *http.Client { - tlsConfig := tls.Config{ - MinVersion: tls.VersionTLS10, - } - if !e.IsSecure { - tlsConfig.InsecureSkipVerify = true - } - return &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tlsConfig, - }, - CheckRedirect: AddRequiredHeadersToRedirectedRequests, - } -} diff --git a/registry/endpoint_test.go b/registry/endpoint_test.go index 9567ba235..6f67867bb 100644 --- a/registry/endpoint_test.go +++ b/registry/endpoint_test.go @@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) { {"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"}, } for _, td := range testData { - e, err := newEndpoint(td.str, false) + e, err := newEndpoint(td.str, false, nil) if err != nil { t.Errorf("%q: %s", td.str, err) } @@ -60,6 +60,7 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) { testEndpoint := Endpoint{ URL: testServerURL, Version: APIVersionUnknown, + client: HTTPClient(NewTransport(ConnectTimeout, false)), } if err = validateEndpoint(&testEndpoint); err != nil { diff --git a/registry/httpfactory.go b/registry/httpfactory.go deleted file mode 100644 index f1b89e582..000000000 --- a/registry/httpfactory.go +++ /dev/null @@ -1,30 +0,0 @@ -package registry - -import ( - "runtime" - - "github.com/docker/docker/autogen/dockerversion" - "github.com/docker/docker/pkg/parsers/kernel" - "github.com/docker/docker/pkg/requestdecorator" -) - -func HTTPRequestFactory(metaHeaders map[string][]string) *requestdecorator.RequestFactory { - // FIXME: this replicates the 'info' job. - httpVersion := make([]requestdecorator.UAVersionInfo, 0, 4) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("docker", dockerversion.VERSION)) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("go", runtime.Version())) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("git-commit", dockerversion.GITCOMMIT)) - if kernelVersion, err := kernel.GetKernelVersion(); err == nil { - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("kernel", kernelVersion.String())) - } - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("os", runtime.GOOS)) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("arch", runtime.GOARCH)) - uad := &requestdecorator.UserAgentDecorator{ - Versions: httpVersion, - } - mhd := &requestdecorator.MetaHeadersDecorator{ - Headers: metaHeaders, - } - factory := requestdecorator.NewRequestFactory(uad, mhd) - return factory -} diff --git a/registry/registry.go b/registry/registry.go index aff28eaa4..b0706e348 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -8,13 +8,19 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httputil" "os" "path" + "runtime" "strings" "time" "github.com/Sirupsen/logrus" + "github.com/docker/docker/autogen/dockerversion" + "github.com/docker/docker/pkg/parsers/kernel" "github.com/docker/docker/pkg/timeoutconn" + "github.com/docker/docker/pkg/transport" + "github.com/docker/docker/pkg/useragent" ) var ( @@ -31,66 +37,38 @@ const ( ConnectTimeout ) -func newClient(jar http.CookieJar, roots *x509.CertPool, certs []tls.Certificate, timeout TimeoutType, secure bool) *http.Client { - tlsConfig := tls.Config{ - RootCAs: roots, - // Avoid fallback to SSL protocols < TLS1.0 - MinVersion: tls.VersionTLS10, - Certificates: certs, +// dockerUserAgent is the User-Agent the Docker client uses to identify itself. +// It is populated on init(), comprising version information of different components. +var dockerUserAgent string + +func init() { + httpVersion := make([]useragent.VersionInfo, 0, 6) + httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION}) + httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()}) + httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT}) + if kernelVersion, err := kernel.GetKernelVersion(); err == nil { + httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()}) } + httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS}) + httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH}) - if !secure { - tlsConfig.InsecureSkipVerify = true - } - - httpTransport := &http.Transport{ - DisableKeepAlives: true, - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tlsConfig, - } - - switch timeout { - case ConnectTimeout: - httpTransport.Dial = func(proto string, addr string) (net.Conn, error) { - // Set the connect timeout to 30 seconds to allow for slower connection - // times... - d := net.Dialer{Timeout: 30 * time.Second, DualStack: true} - - conn, err := d.Dial(proto, addr) - if err != nil { - return nil, err - } - // Set the recv timeout to 10 seconds - conn.SetDeadline(time.Now().Add(10 * time.Second)) - return conn, nil - } - case ReceiveTimeout: - httpTransport.Dial = func(proto string, addr string) (net.Conn, error) { - d := net.Dialer{DualStack: true} - - conn, err := d.Dial(proto, addr) - if err != nil { - return nil, err - } - conn = timeoutconn.New(conn, 1*time.Minute) - return conn, nil - } - } - - return &http.Client{ - Transport: httpTransport, - CheckRedirect: AddRequiredHeadersToRedirectedRequests, - Jar: jar, - } + dockerUserAgent = useragent.AppendVersions("", httpVersion...) } -func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType, secure bool) (*http.Response, *http.Client, error) { +type httpsRequestModifier struct{ tlsConfig *tls.Config } + +// DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip, +// it's because it's so as to match the current behavior in master: we generate the +// certpool on every-goddam-request. It's not great, but it allows people to just put +// the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would +// prefer an fsnotify implementation, but that was out of scope of my refactoring. +func (m *httpsRequestModifier) ModifyRequest(req *http.Request) error { var ( - pool *x509.CertPool + roots *x509.CertPool certs []tls.Certificate ) - if secure && req.URL.Scheme == "https" { + if req.URL.Scheme == "https" { hasFile := func(files []os.FileInfo, name string) bool { for _, f := range files { if f.Name() == name { @@ -104,31 +82,31 @@ func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType, secur logrus.Debugf("hostDir: %s", hostDir) fs, err := ioutil.ReadDir(hostDir) if err != nil && !os.IsNotExist(err) { - return nil, nil, err + return nil } for _, f := range fs { if strings.HasSuffix(f.Name(), ".crt") { - if pool == nil { - pool = x509.NewCertPool() + if roots == nil { + roots = x509.NewCertPool() } logrus.Debugf("crt: %s", hostDir+"/"+f.Name()) data, err := ioutil.ReadFile(path.Join(hostDir, f.Name())) if err != nil { - return nil, nil, err + return err } - pool.AppendCertsFromPEM(data) + roots.AppendCertsFromPEM(data) } if strings.HasSuffix(f.Name(), ".cert") { certName := f.Name() keyName := certName[:len(certName)-5] + ".key" logrus.Debugf("cert: %s", hostDir+"/"+f.Name()) if !hasFile(fs, keyName) { - return nil, nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName) + return fmt.Errorf("Missing key %s for certificate %s", keyName, certName) } cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName)) if err != nil { - return nil, nil, err + return err } certs = append(certs, cert) } @@ -137,24 +115,108 @@ func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType, secur certName := keyName[:len(keyName)-4] + ".cert" logrus.Debugf("key: %s", hostDir+"/"+f.Name()) if !hasFile(fs, certName) { - return nil, nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName) + return fmt.Errorf("Missing certificate %s for key %s", certName, keyName) } } } + m.tlsConfig.RootCAs = roots + m.tlsConfig.Certificates = certs + } + return nil +} + +func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { + tlsConfig := &tls.Config{ + // Avoid fallback to SSL protocols < TLS1.0 + MinVersion: tls.VersionTLS10, + InsecureSkipVerify: !secure, } - if len(certs) == 0 { - client := newClient(jar, pool, nil, timeout, secure) - res, err := client.Do(req) - if err != nil { - return nil, nil, err + tr := &http.Transport{ + DisableKeepAlives: true, + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: tlsConfig, + } + + switch timeout { + case ConnectTimeout: + tr.Dial = func(proto string, addr string) (net.Conn, error) { + // Set the connect timeout to 30 seconds to allow for slower connection + // times... + d := net.Dialer{Timeout: 30 * time.Second, DualStack: true} + + conn, err := d.Dial(proto, addr) + if err != nil { + return nil, err + } + // Set the recv timeout to 10 seconds + conn.SetDeadline(time.Now().Add(10 * time.Second)) + return conn, nil + } + case ReceiveTimeout: + tr.Dial = func(proto string, addr string) (net.Conn, error) { + d := net.Dialer{DualStack: true} + + conn, err := d.Dial(proto, addr) + if err != nil { + return nil, err + } + conn = timeoutconn.New(conn, 1*time.Minute) + return conn, nil } - return res, client, nil } - client := newClient(jar, pool, certs, timeout, secure) - res, err := client.Do(req) - return res, client, err + if secure { + // note: httpsTransport also handles http transport + // but for HTTPS, it sets up the certs + return transport.NewTransport(tr, &httpsRequestModifier{tlsConfig}) + } + + return tr +} + +// DockerHeaders returns request modifiers that ensure requests have +// the User-Agent header set to dockerUserAgent and that metaHeaders +// are added. +func DockerHeaders(metaHeaders http.Header) []transport.RequestModifier { + modifiers := []transport.RequestModifier{ + transport.NewHeaderRequestModifier(http.Header{"User-Agent": []string{dockerUserAgent}}), + } + if metaHeaders != nil { + modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders)) + } + return modifiers +} + +type debugTransport struct{ http.RoundTripper } + +func (tr debugTransport) RoundTrip(req *http.Request) (*http.Response, error) { + dump, err := httputil.DumpRequestOut(req, false) + if err != nil { + fmt.Println("could not dump request") + } + fmt.Println(string(dump)) + resp, err := tr.RoundTripper.RoundTrip(req) + if err != nil { + return nil, err + } + dump, err = httputil.DumpResponse(resp, false) + if err != nil { + fmt.Println("could not dump response") + } + fmt.Println(string(dump)) + return resp, err +} + +func HTTPClient(transport http.RoundTripper) *http.Client { + if transport == nil { + transport = NewTransport(ConnectTimeout, true) + } + + return &http.Client{ + Transport: transport, + CheckRedirect: AddRequiredHeadersToRedirectedRequests, + } } func trustedLocation(req *http.Request) bool { diff --git a/registry/registry_test.go b/registry/registry_test.go index 799d080ed..33e86ff43 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/docker/docker/cliconfig" - "github.com/docker/docker/pkg/requestdecorator" + "github.com/docker/docker/pkg/transport" ) var ( @@ -22,45 +22,34 @@ const ( func spawnTestRegistrySession(t *testing.T) *Session { authConfig := &cliconfig.AuthConfig{} - endpoint, err := NewEndpoint(makeIndex("/v1/")) + endpoint, err := NewEndpoint(makeIndex("/v1/"), nil) if err != nil { t.Fatal(err) } - r, err := NewSession(authConfig, requestdecorator.NewRequestFactory(), endpoint, true) + var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)} + tr = transport.NewTransport(AuthTransport(tr, authConfig, false), DockerHeaders(nil)...) + client := HTTPClient(tr) + r, err := NewSession(client, authConfig, endpoint) if err != nil { t.Fatal(err) } + // In a normal scenario for the v1 registry, the client should send a `X-Docker-Token: true` + // header while authenticating, in order to retrieve a token that can be later used to + // perform authenticated actions. + // + // The mock v1 registry does not support that, (TODO(tiborvass): support it), instead, + // it will consider authenticated any request with the header `X-Docker-Token: fake-token`. + // + // Because we know that the client's transport is an `*authTransport` we simply cast it, + // in order to set the internal cached token to the fake token, and thus send that fake token + // upon every subsequent requests. + r.client.Transport.(*authTransport).token = token return r } -func TestPublicSession(t *testing.T) { - authConfig := &cliconfig.AuthConfig{} - - getSessionDecorators := func(index *IndexInfo) int { - endpoint, err := NewEndpoint(index) - if err != nil { - t.Fatal(err) - } - r, err := NewSession(authConfig, requestdecorator.NewRequestFactory(), endpoint, true) - if err != nil { - t.Fatal(err) - } - return len(r.reqFactory.GetDecorators()) - } - - decorators := getSessionDecorators(makeIndex("/v1/")) - assertEqual(t, decorators, 0, "Expected no decorator on http session") - - decorators = getSessionDecorators(makeHttpsIndex("/v1/")) - assertNotEqual(t, decorators, 0, "Expected decorator on https session") - - decorators = getSessionDecorators(makePublicIndex()) - assertEqual(t, decorators, 0, "Expected no decorator on public session") -} - func TestPingRegistryEndpoint(t *testing.T) { testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) { - ep, err := NewEndpoint(index) + ep, err := NewEndpoint(index, nil) if err != nil { t.Fatal(err) } @@ -80,7 +69,7 @@ func TestPingRegistryEndpoint(t *testing.T) { func TestEndpoint(t *testing.T) { // Simple wrapper to fail test if err != nil expandEndpoint := func(index *IndexInfo) *Endpoint { - endpoint, err := NewEndpoint(index) + endpoint, err := NewEndpoint(index, nil) if err != nil { t.Fatal(err) } @@ -89,7 +78,7 @@ func TestEndpoint(t *testing.T) { assertInsecureIndex := func(index *IndexInfo) { index.Secure = true - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index") assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry error for insecure index") index.Secure = false @@ -97,7 +86,7 @@ func TestEndpoint(t *testing.T) { assertSecureIndex := func(index *IndexInfo) { index.Secure = true - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index") assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index") index.Secure = false @@ -163,14 +152,14 @@ func TestEndpoint(t *testing.T) { } for _, address := range badEndpoints { index.Name = address - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint") } } func TestGetRemoteHistory(t *testing.T) { r := spawnTestRegistrySession(t) - hist, err := r.GetRemoteHistory(imageID, makeURL("/v1/"), token) + hist, err := r.GetRemoteHistory(imageID, makeURL("/v1/")) if err != nil { t.Fatal(err) } @@ -182,16 +171,16 @@ func TestGetRemoteHistory(t *testing.T) { func TestLookupRemoteImage(t *testing.T) { r := spawnTestRegistrySession(t) - err := r.LookupRemoteImage(imageID, makeURL("/v1/"), token) + err := r.LookupRemoteImage(imageID, makeURL("/v1/")) assertEqual(t, err, nil, "Expected error of remote lookup to nil") - if err := r.LookupRemoteImage("abcdef", makeURL("/v1/"), token); err == nil { + if err := r.LookupRemoteImage("abcdef", makeURL("/v1/")); err == nil { t.Fatal("Expected error of remote lookup to not nil") } } func TestGetRemoteImageJSON(t *testing.T) { r := spawnTestRegistrySession(t) - json, size, err := r.GetRemoteImageJSON(imageID, makeURL("/v1/"), token) + json, size, err := r.GetRemoteImageJSON(imageID, makeURL("/v1/")) if err != nil { t.Fatal(err) } @@ -200,7 +189,7 @@ func TestGetRemoteImageJSON(t *testing.T) { t.Fatal("Expected non-empty json") } - _, _, err = r.GetRemoteImageJSON("abcdef", makeURL("/v1/"), token) + _, _, err = r.GetRemoteImageJSON("abcdef", makeURL("/v1/")) if err == nil { t.Fatal("Expected image not found error") } @@ -208,7 +197,7 @@ func TestGetRemoteImageJSON(t *testing.T) { func TestGetRemoteImageLayer(t *testing.T) { r := spawnTestRegistrySession(t) - data, err := r.GetRemoteImageLayer(imageID, makeURL("/v1/"), token, 0) + data, err := r.GetRemoteImageLayer(imageID, makeURL("/v1/"), 0) if err != nil { t.Fatal(err) } @@ -216,7 +205,7 @@ func TestGetRemoteImageLayer(t *testing.T) { t.Fatal("Expected non-nil data result") } - _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), token, 0) + _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), 0) if err == nil { t.Fatal("Expected image not found error") } @@ -224,14 +213,14 @@ func TestGetRemoteImageLayer(t *testing.T) { func TestGetRemoteTags(t *testing.T) { r := spawnTestRegistrySession(t) - tags, err := r.GetRemoteTags([]string{makeURL("/v1/")}, REPO, token) + tags, err := r.GetRemoteTags([]string{makeURL("/v1/")}, REPO) if err != nil { t.Fatal(err) } assertEqual(t, len(tags), 1, "Expected one tag") assertEqual(t, tags["latest"], imageID, "Expected tag latest to map to "+imageID) - _, err = r.GetRemoteTags([]string{makeURL("/v1/")}, "foo42/baz", token) + _, err = r.GetRemoteTags([]string{makeURL("/v1/")}, "foo42/baz") if err == nil { t.Fatal("Expected error when fetching tags for bogus repo") } @@ -265,7 +254,7 @@ func TestPushImageJSONRegistry(t *testing.T) { Checksum: "sha256:1ac330d56e05eef6d438586545ceff7550d3bdcb6b19961f12c5ba714ee1bb37", } - err := r.PushImageJSONRegistry(imgData, []byte{0x42, 0xdf, 0x0}, makeURL("/v1/"), token) + err := r.PushImageJSONRegistry(imgData, []byte{0x42, 0xdf, 0x0}, makeURL("/v1/")) if err != nil { t.Fatal(err) } @@ -274,7 +263,7 @@ func TestPushImageJSONRegistry(t *testing.T) { func TestPushImageLayerRegistry(t *testing.T) { r := spawnTestRegistrySession(t) layer := strings.NewReader("") - _, _, err := r.PushImageLayerRegistry(imageID, layer, makeURL("/v1/"), token, []byte{}) + _, _, err := r.PushImageLayerRegistry(imageID, layer, makeURL("/v1/"), []byte{}) if err != nil { t.Fatal(err) } @@ -694,7 +683,7 @@ func TestNewIndexInfo(t *testing.T) { func TestPushRegistryTag(t *testing.T) { r := spawnTestRegistrySession(t) - err := r.PushRegistryTag("foo42/bar", imageID, "stable", makeURL("/v1/"), token) + err := r.PushRegistryTag("foo42/bar", imageID, "stable", makeURL("/v1/")) if err != nil { t.Fatal(err) } diff --git a/registry/service.go b/registry/service.go index 87fc1d076..681174927 100644 --- a/registry/service.go +++ b/registry/service.go @@ -1,6 +1,10 @@ package registry -import "github.com/docker/docker/cliconfig" +import ( + "net/http" + + "github.com/docker/docker/cliconfig" +) type Service struct { Config *ServiceConfig @@ -27,12 +31,12 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) { if err != nil { return "", err } - endpoint, err := NewEndpoint(index) + endpoint, err := NewEndpoint(index, nil) if err != nil { return "", err } authConfig.ServerAddress = endpoint.String() - return Login(authConfig, endpoint, HTTPRequestFactory(nil)) + return Login(authConfig, endpoint) } // Search queries the public registry for images matching the specified @@ -42,12 +46,13 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers if err != nil { return nil, err } + // *TODO: Search multiple indexes. - endpoint, err := repoInfo.GetEndpoint() + endpoint, err := repoInfo.GetEndpoint(http.Header(headers)) if err != nil { return nil, err } - r, err := NewSession(authConfig, HTTPRequestFactory(headers), endpoint, true) + r, err := NewSession(endpoint.client, authConfig, endpoint) if err != nil { return nil, err } diff --git a/registry/session.go b/registry/session.go index e65f82cd6..8e54bc821 100644 --- a/registry/session.go +++ b/registry/session.go @@ -3,6 +3,8 @@ package registry import ( "bytes" "crypto/sha256" + "errors" + "sync" // this is required for some certificates _ "crypto/sha512" "encoding/hex" @@ -20,64 +22,143 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/docker/cliconfig" "github.com/docker/docker/pkg/httputils" - "github.com/docker/docker/pkg/requestdecorator" "github.com/docker/docker/pkg/tarsum" + "github.com/docker/docker/pkg/transport" ) type Session struct { - authConfig *cliconfig.AuthConfig - reqFactory *requestdecorator.RequestFactory indexEndpoint *Endpoint - jar *cookiejar.Jar - timeout TimeoutType + client *http.Client + // TODO(tiborvass): remove authConfig + authConfig *cliconfig.AuthConfig } -func NewSession(authConfig *cliconfig.AuthConfig, factory *requestdecorator.RequestFactory, endpoint *Endpoint, timeout bool) (r *Session, err error) { +type authTransport struct { + http.RoundTripper + *cliconfig.AuthConfig + + alwaysSetBasicAuth bool + token []string + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +// AuthTransport handles the auth layer when communicating with a v1 registry (private or official) +// +// For private v1 registries, set alwaysSetBasicAuth to true. +// +// For the official v1 registry, if there isn't already an Authorization header in the request, +// but there is an X-Docker-Token header set to true, then Basic Auth will be used to set the Authorization header. +// After sending the request with the provided base http.RoundTripper, if an X-Docker-Token header, representing +// a token, is present in the response, then it gets cached and sent in the Authorization header of all subsequent +// requests. +// +// If the server sends a token without the client having requested it, it is ignored. +// +// This RoundTripper also has a CancelRequest method important for correct timeout handling. +func AuthTransport(base http.RoundTripper, authConfig *cliconfig.AuthConfig, alwaysSetBasicAuth bool) http.RoundTripper { + if base == nil { + base = http.DefaultTransport + } + return &authTransport{ + RoundTripper: base, + AuthConfig: authConfig, + alwaysSetBasicAuth: alwaysSetBasicAuth, + modReq: make(map[*http.Request]*http.Request), + } +} + +func (tr *authTransport) RoundTrip(orig *http.Request) (*http.Response, error) { + req := transport.CloneRequest(orig) + tr.mu.Lock() + tr.modReq[orig] = req + tr.mu.Unlock() + + if tr.alwaysSetBasicAuth { + req.SetBasicAuth(tr.Username, tr.Password) + return tr.RoundTripper.RoundTrip(req) + } + + var askedForToken bool + + // Don't override + if req.Header.Get("Authorization") == "" { + if req.Header.Get("X-Docker-Token") == "true" { + req.SetBasicAuth(tr.Username, tr.Password) + askedForToken = true + } else if len(tr.token) > 0 { + req.Header.Set("Authorization", "Token "+strings.Join(tr.token, ",")) + } + } + resp, err := tr.RoundTripper.RoundTrip(req) + if err != nil { + delete(tr.modReq, orig) + return nil, err + } + if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 { + tr.token = resp.Header["X-Docker-Token"] + } + resp.Body = &transport.OnEOFReader{ + Rc: resp.Body, + Fn: func() { delete(tr.modReq, orig) }, + } + return resp, nil +} + +// CancelRequest cancels an in-flight request by closing its connection. +func (tr *authTransport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := tr.RoundTripper.(canceler); ok { + tr.mu.Lock() + modReq := tr.modReq[req] + delete(tr.modReq, req) + tr.mu.Unlock() + cr.CancelRequest(modReq) + } +} + +// TODO(tiborvass): remove authConfig param once registry client v2 is vendored +func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) { r = &Session{ authConfig: authConfig, + client: client, indexEndpoint: endpoint, } - if timeout { - r.timeout = ReceiveTimeout - } - - r.jar, err = cookiejar.New(nil) - if err != nil { - return nil, err - } + var alwaysSetBasicAuth bool // If we're working with a standalone private registry over HTTPS, send Basic Auth headers - // alongside our requests. - if r.indexEndpoint.VersionString(1) != IndexServerAddress() && r.indexEndpoint.URL.Scheme == "https" { - info, err := r.indexEndpoint.Ping() + // alongside all our requests. + if endpoint.VersionString(1) != IndexServerAddress() && endpoint.URL.Scheme == "https" { + info, err := endpoint.Ping() if err != nil { return nil, err } - if info.Standalone && authConfig != nil && factory != nil { - logrus.Debugf("Endpoint %s is eligible for private registry. Enabling decorator.", r.indexEndpoint.String()) - dec := requestdecorator.NewAuthDecorator(authConfig.Username, authConfig.Password) - factory.AddDecorator(dec) + + if info.Standalone && authConfig != nil { + logrus.Debugf("Endpoint %s is eligible for private registry. Enabling decorator.", endpoint.String()) + alwaysSetBasicAuth = true } } - r.reqFactory = factory - return r, nil -} + client.Transport = AuthTransport(client.Transport, authConfig, alwaysSetBasicAuth) -func (r *Session) doRequest(req *http.Request) (*http.Response, *http.Client, error) { - return doRequest(req, r.jar, r.timeout, r.indexEndpoint.IsSecure) + jar, err := cookiejar.New(nil) + if err != nil { + return nil, errors.New("cookiejar.New is not supposed to return an error") + } + client.Jar = jar + + return r, nil } // Retrieve the history of a given image from the Registry. // Return a list of the parent's json (requested image included) -func (r *Session) GetRemoteHistory(imgID, registry string, token []string) ([]string, error) { - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/ancestry", nil) - if err != nil { - return nil, err - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) +func (r *Session) GetRemoteHistory(imgID, registry string) ([]string, error) { + res, err := r.client.Get(registry + "images/" + imgID + "/ancestry") if err != nil { return nil, err } @@ -89,27 +170,18 @@ func (r *Session) GetRemoteHistory(imgID, registry string, token []string) ([]st return nil, httputils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to fetch remote history for %s", res.StatusCode, imgID), res) } - jsonString, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("Error while reading the http response: %s", err) + var history []string + if err := json.NewDecoder(res.Body).Decode(&history); err != nil { + return nil, fmt.Errorf("Error while reading the http response: %v", err) } - logrus.Debugf("Ancestry: %s", jsonString) - history := new([]string) - if err := json.Unmarshal(jsonString, history); err != nil { - return nil, err - } - return *history, nil + logrus.Debugf("Ancestry: %v", history) + return history, nil } // Check if an image exists in the Registry -func (r *Session) LookupRemoteImage(imgID, registry string, token []string) error { - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/json", nil) - if err != nil { - return err - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) +func (r *Session) LookupRemoteImage(imgID, registry string) error { + res, err := r.client.Get(registry + "images/" + imgID + "/json") if err != nil { return err } @@ -121,14 +193,8 @@ func (r *Session) LookupRemoteImage(imgID, registry string, token []string) erro } // Retrieve an image from the Registry. -func (r *Session) GetRemoteImageJSON(imgID, registry string, token []string) ([]byte, int, error) { - // Get the JSON - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/json", nil) - if err != nil { - return nil, -1, fmt.Errorf("Failed to download json: %s", err) - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) +func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int, error) { + res, err := r.client.Get(registry + "images/" + imgID + "/json") if err != nil { return nil, -1, fmt.Errorf("Failed to download json: %s", err) } @@ -147,44 +213,44 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string, token []string) ([] jsonString, err := ioutil.ReadAll(res.Body) if err != nil { - return nil, -1, fmt.Errorf("Failed to parse downloaded json: %s (%s)", err, jsonString) + return nil, -1, fmt.Errorf("Failed to parse downloaded json: %v (%s)", err, jsonString) } return jsonString, imageSize, nil } -func (r *Session) GetRemoteImageLayer(imgID, registry string, token []string, imgSize int64) (io.ReadCloser, error) { +func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) { var ( retries = 5 statusCode = 0 - client *http.Client res *http.Response + err error imageURL = fmt.Sprintf("%simages/%s/layer", registry, imgID) ) - req, err := r.reqFactory.NewRequest("GET", imageURL, nil) + req, err := http.NewRequest("GET", imageURL, nil) if err != nil { - return nil, fmt.Errorf("Error while getting from the server: %s\n", err) + return nil, fmt.Errorf("Error while getting from the server: %v", err) } - setTokenAuth(req, token) + // TODO: why are we doing retries at this level? + // These retries should be generic to both v1 and v2 for i := 1; i <= retries; i++ { statusCode = 0 - res, client, err = r.doRequest(req) - if err != nil { - logrus.Debugf("Error contacting registry: %s", err) - if res != nil { - if res.Body != nil { - res.Body.Close() - } - statusCode = res.StatusCode - } - if i == retries { - return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", - statusCode, imgID) - } - time.Sleep(time.Duration(i) * 5 * time.Second) - continue + res, err = r.client.Do(req) + if err == nil { + break } - break + logrus.Debugf("Error contacting registry %s: %v", registry, err) + if res != nil { + if res.Body != nil { + res.Body.Close() + } + statusCode = res.StatusCode + } + if i == retries { + return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", + statusCode, imgID) + } + time.Sleep(time.Duration(i) * 5 * time.Second) } if res.StatusCode != 200 { @@ -195,13 +261,13 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, token []string, im if res.Header.Get("Accept-Ranges") == "bytes" && imgSize > 0 { logrus.Debugf("server supports resume") - return httputils.ResumableRequestReaderWithInitialResponse(client, req, 5, imgSize, res), nil + return httputils.ResumableRequestReaderWithInitialResponse(r.client, req, 5, imgSize, res), nil } logrus.Debugf("server doesn't support resume") return res.Body, nil } -func (r *Session) GetRemoteTags(registries []string, repository string, token []string) (map[string]string, error) { +func (r *Session) GetRemoteTags(registries []string, repository string) (map[string]string, error) { if strings.Count(repository, "/") == 0 { // This will be removed once the Registry supports auto-resolution on // the "library" namespace @@ -209,13 +275,7 @@ func (r *Session) GetRemoteTags(registries []string, repository string, token [] } for _, host := range registries { endpoint := fmt.Sprintf("%srepositories/%s/tags", host, repository) - req, err := r.reqFactory.NewRequest("GET", endpoint, nil) - - if err != nil { - return nil, err - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) + res, err := r.client.Get(endpoint) if err != nil { return nil, err } @@ -263,16 +323,13 @@ func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { logrus.Debugf("[registry] Calling GET %s", repositoryTarget) - req, err := r.reqFactory.NewRequest("GET", repositoryTarget, nil) + req, err := http.NewRequest("GET", repositoryTarget, nil) if err != nil { return nil, err } - if r.authConfig != nil && len(r.authConfig.Username) > 0 { - req.SetBasicAuth(r.authConfig.Username, r.authConfig.Password) - } + // this will set basic auth in r.client.Transport and send cached X-Docker-Token headers for all subsequent requests req.Header.Set("X-Docker-Token", "true") - - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, err } @@ -292,11 +349,6 @@ func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { return nil, httputils.NewHTTPRequestError(fmt.Sprintf("Error: Status %d trying to pull repository %s: %q", res.StatusCode, remote, errBody), res) } - var tokens []string - if res.Header.Get("X-Docker-Token") != "" { - tokens = res.Header["X-Docker-Token"] - } - var endpoints []string if res.Header.Get("X-Docker-Endpoints") != "" { endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.VersionString(1)) @@ -322,29 +374,29 @@ func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { return &RepositoryData{ ImgList: imgsData, Endpoints: endpoints, - Tokens: tokens, }, nil } -func (r *Session) PushImageChecksumRegistry(imgData *ImgData, registry string, token []string) error { +func (r *Session) PushImageChecksumRegistry(imgData *ImgData, registry string) error { - logrus.Debugf("[registry] Calling PUT %s", registry+"images/"+imgData.ID+"/checksum") + u := registry + "images/" + imgData.ID + "/checksum" - req, err := r.reqFactory.NewRequest("PUT", registry+"images/"+imgData.ID+"/checksum", nil) + logrus.Debugf("[registry] Calling PUT %s", u) + + req, err := http.NewRequest("PUT", u, nil) if err != nil { return err } - setTokenAuth(req, token) req.Header.Set("X-Docker-Checksum", imgData.Checksum) req.Header.Set("X-Docker-Checksum-Payload", imgData.ChecksumPayload) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { - return fmt.Errorf("Failed to upload metadata: %s", err) + return fmt.Errorf("Failed to upload metadata: %v", err) } defer res.Body.Close() if len(res.Cookies()) > 0 { - r.jar.SetCookies(req.URL, res.Cookies()) + r.client.Jar.SetCookies(req.URL, res.Cookies()) } if res.StatusCode != 200 { errBody, err := ioutil.ReadAll(res.Body) @@ -363,18 +415,19 @@ func (r *Session) PushImageChecksumRegistry(imgData *ImgData, registry string, t } // Push a local image to the registry -func (r *Session) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, registry string, token []string) error { +func (r *Session) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, registry string) error { - logrus.Debugf("[registry] Calling PUT %s", registry+"images/"+imgData.ID+"/json") + u := registry + "images/" + imgData.ID + "/json" - req, err := r.reqFactory.NewRequest("PUT", registry+"images/"+imgData.ID+"/json", bytes.NewReader(jsonRaw)) + logrus.Debugf("[registry] Calling PUT %s", u) + + req, err := http.NewRequest("PUT", u, bytes.NewReader(jsonRaw)) if err != nil { return err } req.Header.Add("Content-type", "application/json") - setTokenAuth(req, token) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return fmt.Errorf("Failed to upload metadata: %s", err) } @@ -398,9 +451,11 @@ func (r *Session) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, regist return nil } -func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry string, token []string, jsonRaw []byte) (checksum string, checksumPayload string, err error) { +func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry string, jsonRaw []byte) (checksum string, checksumPayload string, err error) { - logrus.Debugf("[registry] Calling PUT %s", registry+"images/"+imgID+"/layer") + u := registry + "images/" + imgID + "/layer" + + logrus.Debugf("[registry] Calling PUT %s", u) tarsumLayer, err := tarsum.NewTarSum(layer, false, tarsum.Version0) if err != nil { @@ -411,17 +466,16 @@ func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry h.Write([]byte{'\n'}) checksumLayer := io.TeeReader(tarsumLayer, h) - req, err := r.reqFactory.NewRequest("PUT", registry+"images/"+imgID+"/layer", checksumLayer) + req, err := http.NewRequest("PUT", u, checksumLayer) if err != nil { return "", "", err } req.Header.Add("Content-Type", "application/octet-stream") req.ContentLength = -1 req.TransferEncoding = []string{"chunked"} - setTokenAuth(req, token) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { - return "", "", fmt.Errorf("Failed to upload layer: %s", err) + return "", "", fmt.Errorf("Failed to upload layer: %v", err) } if rc, ok := layer.(io.Closer); ok { if err := rc.Close(); err != nil { @@ -444,19 +498,18 @@ func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry // push a tag on the registry. // Remote has the format '/ -func (r *Session) PushRegistryTag(remote, revision, tag, registry string, token []string) error { +func (r *Session) PushRegistryTag(remote, revision, tag, registry string) error { // "jsonify" the string revision = "\"" + revision + "\"" path := fmt.Sprintf("repositories/%s/tags/%s", remote, tag) - req, err := r.reqFactory.NewRequest("PUT", registry+path, strings.NewReader(revision)) + req, err := http.NewRequest("PUT", registry+path, strings.NewReader(revision)) if err != nil { return err } req.Header.Add("Content-type", "application/json") - setTokenAuth(req, token) req.ContentLength = int64(len(revision)) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return err } @@ -491,7 +544,8 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate logrus.Debugf("[registry] PUT %s", u) logrus.Debugf("Image list pushed to index:\n%s", imgListJSON) headers := map[string][]string{ - "Content-type": {"application/json"}, + "Content-type": {"application/json"}, + // this will set basic auth in r.client.Transport and send cached X-Docker-Token headers for all subsequent requests "X-Docker-Token": {"true"}, } if validate { @@ -526,9 +580,6 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate } return nil, httputils.NewHTTPRequestError(fmt.Sprintf("Error: Status %d trying to push repository %s: %q", res.StatusCode, remote, errBody), res) } - if res.Header.Get("X-Docker-Token") == "" { - return nil, fmt.Errorf("Index response didn't contain an access token") - } tokens = res.Header["X-Docker-Token"] logrus.Debugf("Auth token: %v", tokens) @@ -539,8 +590,7 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate if err != nil { return nil, err } - } - if validate { + } else { if res.StatusCode != 204 { errBody, err := ioutil.ReadAll(res.Body) if err != nil { @@ -551,22 +601,20 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate } return &RepositoryData{ - Tokens: tokens, Endpoints: endpoints, }, nil } func (r *Session) putImageRequest(u string, headers map[string][]string, body []byte) (*http.Response, error) { - req, err := r.reqFactory.NewRequest("PUT", u, bytes.NewReader(body)) + req, err := http.NewRequest("PUT", u, bytes.NewReader(body)) if err != nil { return nil, err } - req.SetBasicAuth(r.authConfig.Username, r.authConfig.Password) req.ContentLength = int64(len(body)) for k, v := range headers { req.Header[k] = v } - response, _, err := r.doRequest(req) + response, err := r.client.Do(req) if err != nil { return nil, err } @@ -580,15 +628,7 @@ func shouldRedirect(response *http.Response) bool { func (r *Session) SearchRepositories(term string) (*SearchResults, error) { logrus.Debugf("Index server: %s", r.indexEndpoint) u := r.indexEndpoint.VersionString(1) + "search?q=" + url.QueryEscape(term) - req, err := r.reqFactory.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - if r.authConfig != nil && len(r.authConfig.Username) > 0 { - req.SetBasicAuth(r.authConfig.Username, r.authConfig.Password) - } - req.Header.Set("X-Docker-Token", "true") - res, _, err := r.doRequest(req) + res, err := r.client.Get(u) if err != nil { return nil, err } @@ -600,6 +640,7 @@ func (r *Session) SearchRepositories(term string) (*SearchResults, error) { return result, json.NewDecoder(res.Body).Decode(result) } +// TODO(tiborvass): remove this once registry client v2 is vendored func (r *Session) GetAuthConfig(withPasswd bool) *cliconfig.AuthConfig { password := "" if withPasswd { @@ -611,9 +652,3 @@ func (r *Session) GetAuthConfig(withPasswd bool) *cliconfig.AuthConfig { Email: r.authConfig.Email, } } - -func setTokenAuth(req *http.Request, token []string) { - if req.Header.Get("Authorization") == "" { // Don't override - req.Header.Set("Authorization", "Token "+strings.Join(token, ",")) - } -} diff --git a/registry/session_v2.go b/registry/session_v2.go index 4188e505b..b66017289 100644 --- a/registry/session_v2.go +++ b/registry/session_v2.go @@ -27,7 +27,7 @@ func getV2Builder(e *Endpoint) *v2.URLBuilder { func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) { // TODO check if should use Mirror if index.Official { - ep, err = newEndpoint(REGISTRYSERVER, true) + ep, err = newEndpoint(REGISTRYSERVER, true, nil) if err != nil { return } @@ -38,7 +38,7 @@ func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) } else if r.indexEndpoint.String() == index.GetAuthConfigKey() { ep = r.indexEndpoint } else { - ep, err = NewEndpoint(index) + ep, err = NewEndpoint(index, nil) if err != nil { return } @@ -77,14 +77,14 @@ func (r *Session) GetV2ImageManifest(ep *Endpoint, imageName, tagName string, au method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return nil, "", err } if err := auth.Authorize(req); err != nil { return nil, "", err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, "", err } @@ -118,14 +118,14 @@ func (r *Session) HeadV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Di method := "HEAD" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return false, err } if err := auth.Authorize(req); err != nil { return false, err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return false, err } @@ -152,14 +152,14 @@ func (r *Session) GetV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Dig method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return err } if err := auth.Authorize(req); err != nil { return err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return err } @@ -183,14 +183,14 @@ func (r *Session) GetV2ImageBlobReader(ep *Endpoint, imageName string, dgst dige method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return nil, 0, err } if err := auth.Authorize(req); err != nil { return nil, 0, err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, 0, err } @@ -220,7 +220,7 @@ func (r *Session) PutV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Dig method := "PUT" logrus.Debugf("[registry] Calling %q %s", method, location) - req, err := r.reqFactory.NewRequest(method, location, ioutil.NopCloser(blobRdr)) + req, err := http.NewRequest(method, location, ioutil.NopCloser(blobRdr)) if err != nil { return err } @@ -230,7 +230,7 @@ func (r *Session) PutV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Dig if err := auth.Authorize(req); err != nil { return err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return err } @@ -259,7 +259,7 @@ func (r *Session) initiateBlobUpload(ep *Endpoint, imageName string, auth *Reque } logrus.Debugf("[registry] Calling %q %s", "POST", routeURL) - req, err := r.reqFactory.NewRequest("POST", routeURL, nil) + req, err := http.NewRequest("POST", routeURL, nil) if err != nil { return "", err } @@ -267,7 +267,7 @@ func (r *Session) initiateBlobUpload(ep *Endpoint, imageName string, auth *Reque if err := auth.Authorize(req); err != nil { return "", err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return "", err } @@ -305,14 +305,14 @@ func (r *Session) PutV2ImageManifest(ep *Endpoint, imageName, tagName string, si method := "PUT" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, bytes.NewReader(signedManifest)) + req, err := http.NewRequest(method, routeURL, bytes.NewReader(signedManifest)) if err != nil { return "", err } if err := auth.Authorize(req); err != nil { return "", err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return "", err } @@ -366,14 +366,14 @@ func (r *Session) GetV2RemoteTags(ep *Endpoint, imageName string, auth *RequestA method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return nil, err } if err := auth.Authorize(req); err != nil { return nil, err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, err } diff --git a/registry/token.go b/registry/token.go index b03bd891b..e27cb6f52 100644 --- a/registry/token.go +++ b/registry/token.go @@ -7,15 +7,13 @@ import ( "net/http" "net/url" "strings" - - "github.com/docker/docker/pkg/requestdecorator" ) type tokenResponse struct { Token string `json:"token"` } -func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client, factory *requestdecorator.RequestFactory) (token string, err error) { +func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (token string, err error) { realm, ok := params["realm"] if !ok { return "", errors.New("no realm specified for token auth challenge") @@ -34,7 +32,7 @@ func getToken(username, password string, params map[string]string, registryEndpo } } - req, err := factory.NewRequest("GET", realmURL.String(), nil) + req, err := http.NewRequest("GET", realmURL.String(), nil) if err != nil { return "", err } @@ -58,7 +56,7 @@ func getToken(username, password string, params map[string]string, registryEndpo req.URL.RawQuery = reqParams.Encode() - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err }