fix flaky unit test (#2749)

Signed-off-by: David Fridrich <fridrich.david19@gmail.com>
This commit is contained in:
David Fridrich 2025-03-18 15:56:14 +01:00 committed by GitHub
parent 525761a199
commit 817c77bbec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 92 additions and 78 deletions

View File

@ -92,7 +92,7 @@ func TestCheckAuth(t *testing.T) {
incorrectPwd = "badpwd"
)
localhost, localhostTLS := startServer(t, uname, pwd)
localhost, localhostTLS, cert := startServer(t, uname, pwd)
_, portTLS, err := net.SplitHostPort(localhostTLS)
if err != nil {
@ -132,7 +132,6 @@ func TestCheckAuth(t *testing.T) {
},
wantErr: false,
},
{
name: "correct credentials non-localhost",
args: args{
@ -170,7 +169,30 @@ func TestCheckAuth(t *testing.T) {
Username: tt.args.username,
Password: tt.args.password,
}
if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, http.DefaultTransport); (err != nil) != tt.wantErr {
// create trusted certificates pool and add our certificate
certPool := x509.NewCertPool()
certPool.AddCert(cert)
// client transport with the certificate
transport := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certPool,
},
}
dialer := &net.Dialer{}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
h, p, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if h == "test.io" {
h = "localhost"
}
return dialer.DialContext(ctx, network, net.JoinHostPort(h, p))
}
if err := creds.CheckAuth(tt.args.ctx, tt.args.registry+"/someorg/someimage:sometag", c, transport); (err != nil) != tt.wantErr {
t.Errorf("CheckAuth() error = %v, wantErr %v", err, tt.wantErr)
}
})
@ -179,52 +201,15 @@ func TestCheckAuth(t *testing.T) {
func TestCheckAuthEmptyCreds(t *testing.T) {
localhost, _ := startServer(t, "", "")
localhost, _, _ := startServer(t, "", "")
err := creds.CheckAuth(context.Background(), localhost+"/someorg/someimage:sometag", docker.Credentials{}, http.DefaultTransport)
if err != nil {
t.Error(err)
}
}
func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) {
// TODO: this should be refactored to use OS-chosen ports so as not to
// fail when a user is running a function on the default port.)
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
addr = listener.Addr().String()
listenerTLS, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
addrTLS = listenerTLS.Addr().String()
handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
if uname == "" || pwd == "" {
if req.Method == http.MethodPost {
resp.WriteHeader(http.StatusCreated)
} else {
resp.WriteHeader(http.StatusOK)
}
return
}
// TODO add also test for token based auth
resp.Header().Add("WWW-Authenticate", "basic")
if u, p, ok := req.BasicAuth(); ok {
if u == uname && p == pwd {
if req.Method == http.MethodPost {
resp.WriteHeader(http.StatusCreated)
} else {
resp.WriteHeader(http.StatusOK)
}
return
}
}
resp.WriteHeader(http.StatusUnauthorized)
})
// generate Certificates
func generateCert(t *testing.T) (tls.Certificate, *x509.Certificate) {
var randReader io.Reader = rand.Reader
caPublicKey, caPrivateKey, err := ed25519.GenerateKey(randReader)
@ -232,15 +217,13 @@ func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) {
t.Fatal(err)
}
ca := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "localhost",
},
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
DNSNames: []string{"localhost", "test.io"},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
NotAfter: time.Now().AddDate(1, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
ExtraExtensions: []pkix.Extension{},
@ -248,72 +231,103 @@ func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string) {
BasicConstraintsValid: true,
}
caBytes, err := x509.CreateCertificate(randReader, ca, ca, caPublicKey, caPrivateKey)
caBytes, err := x509.CreateCertificate(randReader, caTemplate, caTemplate, caPublicKey, caPrivateKey)
if err != nil {
t.Fatal(err)
}
ca, err = x509.ParseCertificate(caBytes)
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
t.Fatal(err)
}
cert := tls.Certificate{
tls := tls.Certificate{
Certificate: [][]byte{caBytes},
PrivateKey: caPrivateKey,
Leaf: ca,
}
return tls, ca
}
func startServer(t *testing.T, uname, pwd string) (addr, addrTLS string, ca *x509.Certificate) {
// create a custom handler function
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// no authentication required, empty creds
if uname == "" || pwd == "" {
if r.Method == http.MethodPost {
w.WriteHeader(http.StatusCreated)
} else {
w.WriteHeader(http.StatusOK)
}
return
}
w.Header().Add("WWW-Authenticate", "basic")
if u, p, ok := r.BasicAuth(); ok {
if u == uname && p == pwd {
if r.Method == http.MethodPost {
w.WriteHeader(http.StatusCreated)
} else {
w.WriteHeader(http.StatusOK)
}
return
}
}
w.WriteHeader(http.StatusUnauthorized)
})
// Setup certificates
// tls Cert for the TLS server (has ca as Leaf)
// x509 certificate which is its own CA for client
tlsCert, ca := generateCert(t)
// create Server config
server := http.Server{
Handler: handler,
TLSConfig: &tls.Config{
ServerName: "localhost",
Certificates: []tls.Certificate{cert},
ServerName: "localhost",
// with the TLS certificate
Certificates: []tls.Certificate{tlsCert},
},
}
// non-TLS listener
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
// TLS listener
listenerTLS, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
addr = listener.Addr().String()
addrTLS = listenerTLS.Addr().String()
// listen for requests
go func() {
err := server.ServeTLS(listenerTLS, "", "")
if err != nil && !strings.Contains(err.Error(), "Server closed") {
if err != nil && err != http.ErrServerClosed {
panic(err)
}
}()
go func() {
err := server.Serve(listener)
if err != nil && !strings.Contains(err.Error(), "Server closed") {
if err != nil && err != http.ErrServerClosed {
panic(err)
}
}()
// make the testing CA trusted by default HTTP transport/client
oldDefaultTransport := http.DefaultTransport
newDefaultTransport := http.DefaultTransport.(*http.Transport).Clone()
http.DefaultTransport = newDefaultTransport
caPool := x509.NewCertPool()
caPool.AddCert(ca)
newDefaultTransport.TLSClientConfig.RootCAs = caPool
dc := newDefaultTransport.DialContext
newDefaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
h, p, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if h == "test.io" {
h = "localhost"
}
addr = net.JoinHostPort(h, p)
return dc(ctx, network, addr)
}
// shutdown servers at cleanup
t.Cleanup(func() {
err := server.Shutdown(context.Background())
if err != nil {
t.Fatal(err)
}
http.DefaultTransport = oldDefaultTransport
})
return addr, addrTLS
return
}
const (