From fcb0c7139f11f01c5db4b5fb49eb06b39b2ed9c4 Mon Sep 17 00:00:00 2001 From: Matej Vasek Date: Tue, 7 Dec 2021 16:53:58 +0100 Subject: [PATCH] src: CheckAuth() calls registry directly (#704) CheckAuth() calls registry directly not using docker daemon as a middle-man Signed-off-by: Matej Vasek --- docker/pusher.go | 51 ++++++++--- docker/pusher_test.go | 204 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 1 + vendor/modules.txt | 1 + 4 files changed, 246 insertions(+), 11 deletions(-) diff --git a/docker/pusher.go b/docker/pusher.go index 73e0fff3..b892d2f1 100644 --- a/docker/pusher.go +++ b/docker/pusher.go @@ -8,11 +8,16 @@ import ( "errors" "fmt" "io" + "net/http" "os" "path/filepath" "regexp" "strings" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/remote/transport" + "github.com/docker/docker/client" fn "knative.dev/kn-plugin-func" @@ -20,7 +25,6 @@ import ( "github.com/containers/image/v5/pkg/docker/config" containersTypes "github.com/containers/image/v5/types" "github.com/docker/docker/api/types" - "github.com/docker/docker/errdefs" ) type Opt func(*Pusher) error @@ -41,24 +45,49 @@ var ErrUnauthorized = errors.New("bad credentials") type VerifyCredentialsCallback func(ctx context.Context, username, password, registry string) error func CheckAuth(ctx context.Context, username, password, registry string) error { - cli, _, err := NewClient(client.DefaultDockerHost) + serverAddress := registry + if !strings.HasPrefix(serverAddress, "https://") && !strings.HasPrefix(serverAddress, "http://") { + serverAddress = "https://" + serverAddress + } + + url := fmt.Sprintf("%s/v2", serverAddress) + + authenticator := &authn.Basic{ + Username: username, + Password: password, + } + + reg, err := name.NewRegistry(registry) if err != nil { return err } - defer cli.Close() - _, err = cli.RegistryLogin(ctx, types.AuthConfig{Username: username, Password: password, ServerAddress: registry}) - if err != nil && strings.Contains(err.Error(), "401 Unauthorized") { - return ErrUnauthorized + tr, err := transport.NewWithContext(ctx, reg, authenticator, http.DefaultTransport, nil) + if err != nil { + return err } - // podman hack until https://github.com/containers/podman/pull/11595 is merged - // podman returns 400 (instead of 500) and body in unexpected shape - if errdefs.IsInvalidParameter(err) { - return ErrUnauthorized + cli := http.Client{Transport: tr} + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return err } - return err + resp, err := cli.Do(req) + if err != nil { + return fmt.Errorf("failed to verify credentials: %w", err) + } + defer resp.Body.Close() + + switch { + case resp.StatusCode == http.StatusUnauthorized: + return ErrUnauthorized + case resp.StatusCode != http.StatusOK: + return fmt.Errorf("failed to verify credentials: status code: %d", resp.StatusCode) + default: + return nil + } } type ChooseCredentialHelperCallback func(available []string) (string, error) diff --git a/docker/pusher_test.go b/docker/pusher_test.go index ecda1c1f..d504cc21 100644 --- a/docker/pusher_test.go +++ b/docker/pusher_test.go @@ -5,11 +5,18 @@ package docker import ( "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "encoding/base64" "encoding/json" "errors" "fmt" + "io" "io/ioutil" + "math/big" "net" "net/http" "os" @@ -654,3 +661,200 @@ func (i *inMemoryHelper) Delete(serverURL string) error { return credentials.NewErrCredentialsNotFound() } + +func TestCheckAuth(t *testing.T) { + localhost, localhostTLS, stopServer := startServer(t) + defer stopServer() + + _, portTLS, err := net.SplitHostPort(localhostTLS) + if err != nil { + t.Fatal(err) + } + + nonLocalhostTLS := "test.io:" + portTLS + + type args struct { + ctx context.Context + username string + password string + registry string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "correct credentials localhost no-TLS", + args: args{ + ctx: context.Background(), + username: "testuser", + password: "testpwd", + registry: localhost, + }, + wantErr: false, + }, + { + name: "correct credentials localhost", + args: args{ + ctx: context.Background(), + username: "testuser", + password: "testpwd", + registry: localhostTLS, + }, + wantErr: false, + }, + + { + name: "correct credentials non-localhost", + args: args{ + ctx: context.Background(), + username: "testuser", + password: "testpwd", + registry: nonLocalhostTLS, + }, + wantErr: false, + }, + { + name: "incorrect credentials localhost no-TLS", + args: args{ + ctx: context.Background(), + username: "testuser", + password: "badpwd", + registry: localhost, + }, + wantErr: true, + }, + { + name: "incorrect credentials localhost", + args: args{ + ctx: context.Background(), + username: "testuser", + password: "badpwd", + registry: localhostTLS, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CheckAuth(tt.args.ctx, tt.args.username, tt.args.password, tt.args.registry); (err != nil) != tt.wantErr { + t.Errorf("CheckAuth() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func startServer(t *testing.T) (addr, addrTLS string, stopServer func()) { + listener, err := net.Listen("tcp", "localhost:8080") + if err != nil { + t.Fatal(err) + } + addr = listener.Addr().String() + + listenerTLS, err := net.Listen("tcp", "localhost:4433") + if err != nil { + t.Fatal(err) + } + addrTLS = listenerTLS.Addr().String() + + handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + resp.Header().Add("WWW-Authenticate", "basic") + if user, pwd, ok := req.BasicAuth(); ok { + if user == "testuser" && pwd == "testpwd" { + resp.WriteHeader(http.StatusOK) + return + } + } + resp.WriteHeader(http.StatusUnauthorized) + }) + + var randReader io.Reader = rand.Reader + + caPublicKey, caPrivateKey, err := ed25519.GenerateKey(randReader) + if err != nil { + t.Fatal(err) + } + + ca := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "localhost", + }, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + DNSNames: []string{"localhost", "test.io"}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + ExtraExtensions: []pkix.Extension{}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caBytes, err := x509.CreateCertificate(randReader, ca, ca, caPublicKey, caPrivateKey) + if err != nil { + t.Fatal(err) + } + + ca, err = x509.ParseCertificate(caBytes) + if err != nil { + t.Fatal(err) + } + + cert := tls.Certificate{ + Certificate: [][]byte{caBytes}, + PrivateKey: caPrivateKey, + Leaf: ca, + } + + server := http.Server{ + Handler: handler, + TLSConfig: &tls.Config{ + ServerName: "localhost", + Certificates: []tls.Certificate{cert}, + }, + } + + go func() { + err := server.ServeTLS(listenerTLS, "", "") + if err != nil && !strings.Contains(err.Error(), "Server closed") { + panic(err) + } + }() + + go func() { + err := server.Serve(listener) + if err != nil && !strings.Contains(err.Error(), "Server closed") { + panic(err) + } + }() + + // make the testing CA trusted by default HTTP transport/client + oldDefaultTransport := http.DefaultTransport + newDefaultTransport := http.DefaultTransport.(*http.Transport).Clone() + http.DefaultTransport = newDefaultTransport + caPool := x509.NewCertPool() + caPool.AddCert(ca) + newDefaultTransport.TLSClientConfig.RootCAs = caPool + dc := newDefaultTransport.DialContext + newDefaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + h, p, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if h == "test.io" { + h = "localhost" + } + addr = net.JoinHostPort(h, p) + return dc(ctx, network, addr) + } + + return addr, addrTLS, func() { + err := server.Shutdown(context.Background()) + if err != nil { + t.Fatal(err) + } + http.DefaultTransport = oldDefaultTransport + } +} diff --git a/go.mod b/go.mod index 9bb36b0b..85b828b6 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-git/go-billy/v5 v5.3.1 github.com/go-git/go-git/v5 v5.4.2 github.com/google/go-cmp v0.5.6 + github.com/google/go-containerregistry v0.6.0 github.com/google/uuid v1.3.0 github.com/hinshun/vt10x v0.0.0-20180809195222-d55458df857c github.com/markbates/pkger v0.17.1 diff --git a/vendor/modules.txt b/vendor/modules.txt index 50db1d68..99cb90e9 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -396,6 +396,7 @@ github.com/google/go-cmp/cmp/internal/flags github.com/google/go-cmp/cmp/internal/function github.com/google/go-cmp/cmp/internal/value # github.com/google/go-containerregistry v0.6.0 +## explicit github.com/google/go-containerregistry/internal/and github.com/google/go-containerregistry/internal/estargz github.com/google/go-containerregistry/internal/gzip