Enable custom CA for specified server names (#770)

* src: refactor

Signed-off-by: Matej Vasek <mvasek@redhat.com>

* src: Allow usage of custom CA for

Allows TLS verify against custom CA for chosen server names.

Signed-off-by: Matej Vasek <mvasek@redhat.com>

* fixup: style

Signed-off-by: Matej Vasek <mvasek@redhat.com>

* fixup: lint

Signed-off-by: Matej Vasek <mvasek@redhat.com>

* fixup: cleanup

Signed-off-by: Matej Vasek <mvasek@redhat.com>
This commit is contained in:
Matej Vasek 2022-01-21 15:50:04 +01:00 committed by GitHub
parent ce938122d8
commit 8ceb325142
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 315 additions and 53 deletions

View File

@ -2,45 +2,86 @@ package http
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
"knative.dev/kn-plugin-func/k8s"
)
type ContextDialer interface {
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
Close() error
}
type RoundTripCloser interface {
http.RoundTripper
io.Closer
}
type options struct {
selectCA func(ctx context.Context, serverName string) (*x509.Certificate, error)
inClusterDialer ContextDialer
}
type Option func(*options)
func WithSelectCA(selectCA func(ctx context.Context, serverName string) (*x509.Certificate, error)) Option {
return func(o *options) {
o.selectCA = selectCA
}
}
func WithInClusterDialer(inClusterDialer ContextDialer) Option {
return func(o *options) {
o.inClusterDialer = inClusterDialer
}
}
// NewRoundTripper returns new closable RoundTripper that first tries to dial connection in standard way,
// if the dial operation fails due to hostname resolution the RoundTripper tries to dial from in cluster pod.
//
// This is useful for accessing cluster internal services (pushing a CloudEvent into Knative broker).
func NewRoundTripper() RoundTripCloser {
result := &roundTripCloser{}
func NewRoundTripper(opts ...Option) RoundTripCloser {
o := options{
inClusterDialer: k8s.NewLazyInitInClusterDialer(),
}
for _, option := range opts {
option(&o)
}
httpTransport := newHTTPTransport()
primaryDialer := dialContextFn(httpTransport.DialContext)
secondaryDialer := o.inClusterDialer
combinedDialer := newDialerWithFallback(primaryDialer, secondaryDialer)
httpTransport.DialContext = combinedDialer.DialContext
httpTransport.DialTLSContext = newDialTLSContext(combinedDialer, httpTransport.TLSClientConfig, o.selectCA)
return &roundTripCloser{
Transport: httpTransport,
dialer: combinedDialer,
}
}
func newHTTPTransport() *http.Transport {
if dt, ok := http.DefaultTransport.(*http.Transport); ok {
d := &dialer{
defaultDialContext: dt.DialContext,
}
result.d = d
result.Transport = dt.Clone()
result.Transport.DialContext = d.DialContext
return dt.Clone()
} else {
d := &dialer{
defaultDialContext: (&net.Dialer{
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
}
result.d = d
result.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: d.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
@ -48,27 +89,31 @@ func NewRoundTripper() RoundTripCloser {
ExpectContinueTimeout: 1 * time.Second,
}
}
return result
}
type roundTripCloser struct {
*http.Transport
d *dialer
dialer ContextDialer
}
func (r *roundTripCloser) Close() error {
return r.d.Close()
return r.dialer.Close()
}
type dialer struct {
o sync.Once
defaultDialContext func(ctx context.Context, network, address string) (net.Conn, error)
inClusterDialer k8s.ContextDialer
func newDialerWithFallback(primaryDialer ContextDialer, fallbackDialer ContextDialer) *dialerWithFallback {
return &dialerWithFallback{
primaryDialer: primaryDialer,
fallbackDialer: fallbackDialer,
}
}
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.defaultDialContext(ctx, network, address)
type dialerWithFallback struct {
primaryDialer ContextDialer
fallbackDialer ContextDialer
}
func (d *dialerWithFallback) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.primaryDialer.DialContext(ctx, network, address)
if err == nil {
return conn, nil
}
@ -77,26 +122,73 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (net.
if !(errors.As(err, &dnsErr) && dnsErr.IsNotFound) {
return nil, err
}
err = nil
d.o.Do(func() {
d.inClusterDialer, err = k8s.NewInClusterDialer(ctx)
})
if err != nil {
return nil, err
}
if d.inClusterDialer == nil {
return nil, errors.New("failed to init in cluster dialer")
}
return d.inClusterDialer.DialContext(ctx, network, address)
return d.fallbackDialer.DialContext(ctx, network, address)
}
func (d *dialer) Close() error {
if d.inClusterDialer != nil {
return d.inClusterDialer.Close()
func (d *dialerWithFallback) Close() error {
var err error
errs := make([]error, 0, 2)
err = d.primaryDialer.Close()
if err != nil {
errs = append(errs, err)
}
err = d.fallbackDialer.Close()
if err != nil {
errs = append(errs, err)
}
if len(errs) > 0 {
return fmt.Errorf("failed to Close(): %v", errs)
}
return nil
}
type dialContextFn func(ctx context.Context, network string, addr string) (net.Conn, error)
func (d dialContextFn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
return d(ctx, network, addr)
}
func (d dialContextFn) Close() error { return nil }
func newDialTLSContext(dialer ContextDialer, config *tls.Config, selectCA func(ctx context.Context, serverName string) (*x509.Certificate, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
if selectCA == nil {
return nil
}
return func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
var cfg *tls.Config
if config != nil {
cfg = config.Clone()
} else {
cfg = &tls.Config{}
}
serverName, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if cfg.ServerName == "" {
cfg.ServerName = serverName
}
if ca, err := selectCA(ctx, serverName); ca != nil && err == nil {
caPool := x509.NewCertPool()
caPool.AddCert(ca)
cfg.RootCAs = caPool
}
tlsConn := tls.Client(conn, cfg)
return tlsConn, nil
}
}

150
http/transport_test.go Normal file
View File

@ -0,0 +1,150 @@
package http_test
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"net/http"
"testing"
"time"
fnhttp "knative.dev/kn-plugin-func/http"
)
const inClusterHostName = "image-registry.openshift-image-registry.svc"
func TestCustomCA(t *testing.T) {
var err error
inClusterAddr, inClusterCA := startServer(t, inClusterHostName)
localhostAddr, localhostCA := startServer(t, "localhost")
mockSelectCA := func(ctx context.Context, serverName string) (*x509.Certificate, error) {
if serverName == inClusterHostName {
return inClusterCA, nil
}
if serverName == "localhost" {
return localhostCA, nil
}
return nil, nil
}
mockInCusterDialer := mockInClusterDialer{
backingAddr: inClusterAddr,
}
tr := fnhttp.NewRoundTripper(
fnhttp.WithSelectCA(mockSelectCA),
fnhttp.WithInClusterDialer(mockInCusterDialer))
defer tr.Close()
client := http.Client{Transport: tr}
_, p, err := net.SplitHostPort(localhostAddr)
if err != nil {
t.Fatal(err)
}
resp, err := client.Get(fmt.Sprintf("https://localhost:%s", p))
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
resp, err = client.Get(fmt.Sprintf("https://%s:5000/v2/", inClusterHostName))
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
}
type mockInClusterDialer struct {
backingAddr string
}
func (m mockInClusterDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
hostname, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if hostname == inClusterHostName {
return net.Dial(network, m.backingAddr)
}
return net.Dial(network, addr)
}
func (m mockInClusterDialer) Close() error {
return nil
}
func startServer(t *testing.T, hostname string) (addr string, ca *x509.Certificate) {
randReader := rand.Reader
caPublicKey, caPrivateKey, err := ed25519.GenerateKey(randReader)
if err != nil {
t.Fatal(err)
}
ca = &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: hostname,
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
DNSNames: []string{"localhost", hostname},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
ExtraExtensions: []pkix.Extension{},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caBytes, err := x509.CreateCertificate(randReader, ca, ca, caPublicKey, caPrivateKey)
if err != nil {
t.Fatal()
}
ca, err = x509.ParseCertificate(caBytes)
if err != nil {
t.Fatal(err)
}
cert := tls.Certificate{
Certificate: [][]byte{caBytes},
PrivateKey: caPrivateKey,
Leaf: ca,
}
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
addr = listener.Addr().String()
handler := http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
})
server := http.Server{
Handler: handler,
TLSConfig: &tls.Config{
ServerName: hostname,
Certificates: []tls.Certificate{cert},
},
}
t.Cleanup(func() {
server.Close()
})
go func() {
_ = server.ServeTLS(listener, "", "")
}()
return
}

View File

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net"
"sync"
"time"
coreV1 "k8s.io/api/core/v1"
@ -25,11 +26,6 @@ const (
socatImage = "alpine/socat:1.7.4.2-r0"
)
type ContextDialer interface {
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
Close() error
}
// NewInClusterDialer creates context dialer that will dial TCP connections via POD running in k8s cluster.
// This is useful when accessing k8s services that are not exposed outside cluster (e.g. openshift image registry).
//
@ -48,7 +44,7 @@ type ContextDialer interface {
// var client = http.Client{
// Transport: transport,
// }
func NewInClusterDialer(ctx context.Context) (ContextDialer, error) {
func NewInClusterDialer(ctx context.Context) (*contextDialer, error) {
c := &contextDialer{
detachChan: make(chan struct{}),
}
@ -372,3 +368,30 @@ func newConn(execDone <-chan struct{}) (*io.PipeReader, *io.PipeWriter, conn) {
rwc := conn{pr: pr0, pw: pw1, execDone: execDone}
return pr1, pw0, rwc
}
func NewLazyInitInClusterDialer() *lazyInitInClusterDialer {
return &lazyInitInClusterDialer{}
}
type lazyInitInClusterDialer struct {
contextDialer *contextDialer
initErr error
o sync.Once
}
func (l *lazyInitInClusterDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
l.o.Do(func() {
l.contextDialer, l.initErr = NewInClusterDialer(ctx)
})
if l.initErr != nil {
return nil, l.initErr
}
return l.contextDialer.DialContext(ctx, network, addr)
}
func (l *lazyInitInClusterDialer) Close() error {
if l.contextDialer != nil {
return l.contextDialer.Close()
}
return nil
}

View File

@ -109,10 +109,7 @@ func TestDialInClusterService(t *testing.T) {
// wait for service to start
time.Sleep(time.Second * 10)
dialer, err := k8s.NewInClusterDialer(ctx)
if err != nil {
t.Fatal(err)
}
dialer := k8s.NewLazyInitInClusterDialer()
t.Cleanup(func() {
dialer.Close()
})