diff --git a/pkg/docker/creds/credentials_test.go b/pkg/docker/creds/credentials_test.go index 0a09a4e5..ff3f31dc 100644 --- a/pkg/docker/creds/credentials_test.go +++ b/pkg/docker/creds/credentials_test.go @@ -92,7 +92,7 @@ func TestCheckAuth(t *testing.T) { incorrectPwd = "badpwd" ) - localhost, localhostTLS := startServer(t, uname, pwd) + localhost, localhostTLS, cert := startServer(t, uname, pwd) _, portTLS, err := net.SplitHostPort(localhostTLS) if err != nil { @@ -132,7 +132,6 @@ func TestCheckAuth(t *testing.T) { }, wantErr: false, }, - { name: "correct credentials non-localhost", args: args{ @@ -170,7 +169,30 @@ func TestCheckAuth(t *testing.T) { Username: tt.args.username, Password: tt.args.password, } - if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, http.DefaultTransport); (err != nil) != tt.wantErr { + // create trusted certificates pool and add our certificate + certPool := x509.NewCertPool() + certPool.AddCert(cert) + + // client transport with the certificate + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + }, + } + + dialer := &net.Dialer{} + + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + h, p, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if h == "test.io" { + h = "localhost" + } + return dialer.DialContext(ctx, network, net.JoinHostPort(h, p)) + } + if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, transport); (err != nil) != tt.wantErr { t.Errorf("CheckAuth() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -179,52 +201,15 @@ func TestCheckAuth(t *testing.T) { func TestCheckAuthEmptyCreds(t *testing.T) { - localhost, _ := startServer(t, "", "") + localhost, _, _ := startServer(t, "", "") err := creds.CheckAuth(context.Background(), localhost+"/someorg/someimage:sometag", docker.Credentials{}, http.DefaultTransport) if err != nil { t.Error(err) } } -func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) { - // TODO: this should be refactored to use OS-chosen ports so as not to - // fail when a user is running a function on the default port.) - listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - addr = listener.Addr().String() - - listenerTLS, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatal(err) - } - addrTLS = listenerTLS.Addr().String() - - handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - if uname == "" || pwd == "" { - if req.Method == http.MethodPost { - resp.WriteHeader(http.StatusCreated) - } else { - resp.WriteHeader(http.StatusOK) - } - return - } - // TODO add also test for token based auth - resp.Header().Add("WWW-Authenticate", "basic") - if u, p, ok := req.BasicAuth(); ok { - if u == uname && p == pwd { - if req.Method == http.MethodPost { - resp.WriteHeader(http.StatusCreated) - } else { - resp.WriteHeader(http.StatusOK) - } - return - } - } - resp.WriteHeader(http.StatusUnauthorized) - }) - +// generate Certificates +func generateCert(t *testing.T) (tls.Certificate, *x509.Certificate) { var randReader io.Reader = rand.Reader caPublicKey, caPrivateKey, err := ed25519.GenerateKey(randReader) @@ -232,15 +217,13 @@ func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) { t.Fatal(err) } - ca := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - CommonName: "localhost", - }, + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "localhost"}, IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, DNSNames: []string{"localhost", "test.io"}, NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotAfter: time.Now().AddDate(1, 0, 0), IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, ExtraExtensions: []pkix.Extension{}, @@ -248,72 +231,103 @@ func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) { BasicConstraintsValid: true, } - caBytes, err := x509.CreateCertificate(randReader, ca, ca, caPublicKey, caPrivateKey) + caBytes, err := x509.CreateCertificate(randReader, caTemplate, caTemplate, caPublicKey, caPrivateKey) if err != nil { t.Fatal(err) } - ca, err = x509.ParseCertificate(caBytes) + ca, err := x509.ParseCertificate(caBytes) if err != nil { t.Fatal(err) } - cert := tls.Certificate{ + tls := tls.Certificate{ Certificate: [][]byte{caBytes}, PrivateKey: caPrivateKey, Leaf: ca, } + return tls, ca +} +func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string, ca *x509.Certificate) { + // create a custom handler function + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // no authentication required, empty creds + if uname == "" || pwd == "" { + if r.Method == http.MethodPost { + w.WriteHeader(http.StatusCreated) + } else { + w.WriteHeader(http.StatusOK) + } + return + } + + w.Header().Add("WWW-Authenticate", "basic") + if u, p, ok := r.BasicAuth(); ok { + if u == uname && p == pwd { + if r.Method == http.MethodPost { + w.WriteHeader(http.StatusCreated) + } else { + w.WriteHeader(http.StatusOK) + } + return + } + } + w.WriteHeader(http.StatusUnauthorized) + }) + + // Setup certificates + // tls Cert for the TLS server (has ca as Leaf) + // x509 certificate which is its own CA for client + tlsCert, ca := generateCert(t) + + // create Server config server := http.Server{ Handler: handler, TLSConfig: &tls.Config{ - ServerName: "localhost", - Certificates: []tls.Certificate{cert}, + ServerName: "localhost", + // with the TLS certificate + Certificates: []tls.Certificate{tlsCert}, }, } + // non-TLS listener + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + + // TLS listener + listenerTLS, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + addr = listener.Addr().String() + addrTLS = listenerTLS.Addr().String() + + // listen for requests go func() { err := server.ServeTLS(listenerTLS, "", "") - if err != nil && !strings.Contains(err.Error(), "Server closed") { + if err != nil && err != http.ErrServerClosed { panic(err) } }() go func() { err := server.Serve(listener) - if err != nil && !strings.Contains(err.Error(), "Server closed") { + if err != nil && err != http.ErrServerClosed { panic(err) } }() - // make the testing CA trusted by default HTTP transport/client - oldDefaultTransport := http.DefaultTransport - newDefaultTransport := http.DefaultTransport.(*http.Transport).Clone() - http.DefaultTransport = newDefaultTransport - caPool := x509.NewCertPool() - caPool.AddCert(ca) - newDefaultTransport.TLSClientConfig.RootCAs = caPool - dc := newDefaultTransport.DialContext - newDefaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - h, p, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - if h == "test.io" { - h = "localhost" - } - addr = net.JoinHostPort(h, p) - return dc(ctx, network, addr) - } - + // shutdown servers at cleanup t.Cleanup(func() { err := server.Shutdown(context.Background()) if err != nil { t.Fatal(err) } - http.DefaultTransport = oldDefaultTransport }) - return addr, addrTLS + return } const (