func/pkg/http/transport.go

205 lines
4.8 KiB
Go

package http
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"net/http"
"time"
"knative.dev/func/pkg/k8s"
)
// ContextDialer opens network connections under a context and can be closed
// once it is no longer needed.
type ContextDialer interface {
	// DialContext connects to addr on the named network using ctx.
	DialContext(ctx context.Context, network, addr string) (net.Conn, error)
	// Close shuts the dialer down.
	Close() error
}
// RoundTripCloser is an http.RoundTripper that must be closed when the caller
// is done with it, so that any resources it may hold can be released.
type RoundTripCloser interface {
	io.Closer
	http.RoundTripper
}
// options holds the settings NewRoundTripper collects from Option values.
type options struct {
	// selectCA, when non-nil, is asked for a CA certificate for each TLS server
	// name and enables the custom DialTLSContext built by newDialTLSContext.
	selectCA func(ctx context.Context, serverName string) (*x509.Certificate, error)

	// inClusterDialer is the fallback dialer tried when the standard dial fails
	// with a DNS resolution error.
	inClusterDialer ContextDialer

	// insecureSkipVerify is copied into the transport's tls.Config and, when
	// true, disables server certificate verification.
	insecureSkipVerify bool
}
// Option customizes the settings used by NewRoundTripper.
type Option func(*options)

// WithSelectCA sets the callback NewRoundTripper uses to choose a CA certificate
// for a TLS server name. Leaving it unset keeps the transport's standard TLS dialing.
func WithSelectCA(selectCA func(ctx context.Context, serverName string) (*x509.Certificate, error)) Option {
	setSelectCA := func(o *options) { o.selectCA = selectCA }
	return setSelectCA
}

// WithInClusterDialer replaces the default in-cluster dialer, which is used as the
// fallback when a standard dial fails with a DNS resolution error.
func WithInClusterDialer(inClusterDialer ContextDialer) Option {
	setDialer := func(o *options) { o.inClusterDialer = inClusterDialer }
	return setDialer
}

// WithInsecureSkipVerify sets tls.Config.InsecureSkipVerify on the transport;
// true disables verification of server certificates.
func WithInsecureSkipVerify(insecureSkipVerify bool) Option {
	setSkip := func(o *options) { o.insecureSkipVerify = insecureSkipVerify }
	return setSkip
}
// NewRoundTripper returns a new closable http.RoundTripper.
//
// Connections are first dialed the standard way (the transport's own DialContext).
// If that dial fails because the host name cannot be resolved (a *net.DNSError),
// the dial is retried through an in-cluster dialer: by default
// k8s.NewLazyInitInClusterDialer(k8s.GetClientConfig()), replaceable with
// WithInClusterDialer.
//
// This is useful for accessing cluster-internal services (e.g. pushing a
// CloudEvent into a Knative broker).
//
// WithSelectCA installs a custom DialTLSContext that may trust a per-server CA;
// WithInsecureSkipVerify controls certificate verification. Call Close on the
// result to release the dialers.
func NewRoundTripper(opts ...Option) RoundTripCloser {
	// insecureSkipVerify defaults to false via the zero value.
	settings := options{
		inClusterDialer: k8s.NewLazyInitInClusterDialer(k8s.GetClientConfig()),
	}
	for _, apply := range opts {
		apply(&settings)
	}

	transport := newHTTPTransport()

	// The transport's original DialContext is captured as the primary dialer
	// before it is replaced by the combined dialer below.
	dialer := newDialerWithFallback(dialContextFn(transport.DialContext), settings.inClusterDialer)

	// TLSClientConfig must be assigned before newDialTLSContext, which keeps the pointer.
	transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: settings.insecureSkipVerify}
	transport.DialContext = dialer.DialContext
	transport.DialTLSContext = newDialTLSContext(dialer, transport.TLSClientConfig, settings.selectCA)

	return &roundTripCloser{
		Transport: transport,
		dialer:    dialer,
	}
}
// newHTTPTransport returns a fresh *http.Transport that the caller may customize.
//
// When http.DefaultTransport is an *http.Transport, a clone of it is returned, so
// its proxy, timeout and connection-pool settings are inherited without mutating
// the shared global. If DefaultTransport has been replaced by some other
// RoundTripper, a transport with the explicit settings below is built instead.
func newHTTPTransport() *http.Transport {
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		return dt.Clone()
	}
	return &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		DialContext: (&net.Dialer{
			Timeout:   time.Minute,
			KeepAlive: time.Minute,
		}).DialContext,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
}
// roundTripCloser is the RoundTripCloser returned by NewRoundTripper. It embeds an
// *http.Transport (so RoundTrip and the other Transport methods are promoted) and
// pairs it with the dialer that must be closed together with it.
type roundTripCloser struct {
	*http.Transport
	dialer ContextDialer
}

// Close drops the transport's idle keep-alive connections and then closes the
// dialer, returning the dialer's error.
//
// Idle connections are closed first because they were established through the
// dialer being closed. Leaving them pooled would keep their sockets open until
// the idle timeout even though the caller is done with the round tripper.
func (r *roundTripCloser) Close() error {
	if r.Transport != nil {
		r.Transport.CloseIdleConnections()
	}
	return r.dialer.Close()
}
// newDialerWithFallback returns a dialer that dials with primaryDialer first and
// falls back to fallbackDialer only when the primary dial fails with a DNS
// resolution error. A nil fallbackDialer disables the fallback.
func newDialerWithFallback(primaryDialer ContextDialer, fallbackDialer ContextDialer) *dialerWithFallback {
	return &dialerWithFallback{
		primaryDialer:  primaryDialer,
		fallbackDialer: fallbackDialer,
	}
}

// dialerWithFallback is a ContextDialer combining a required primary dialer with
// an optional fallback dialer (see newDialerWithFallback).
type dialerWithFallback struct {
	primaryDialer  ContextDialer
	fallbackDialer ContextDialer
}

// DialContext dials address with the primary dialer.
//
// If the primary dial fails with an error that wraps *net.DNSError, the dial is
// retried with the fallback dialer and the fallback's result is returned as-is.
// Any other primary error is returned unchanged. The DNS error itself is also
// returned unchanged when no fallback dialer is configured.
func (d *dialerWithFallback) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := d.primaryDialer.DialContext(ctx, network, address)
	if err == nil {
		return conn, nil
	}
	var dnsErr *net.DNSError
	if !errors.As(err, &dnsErr) || d.fallbackDialer == nil {
		return nil, err
	}
	return d.fallbackDialer.DialContext(ctx, network, address)
}

// Close closes the primary and then the fallback dialer. Both are always
// attempted, and nil dialers are skipped. Any failures are reported together
// in a single error.
func (d *dialerWithFallback) Close() error {
	var errs []error
	for _, dialer := range []ContextDialer{d.primaryDialer, d.fallbackDialer} {
		if dialer == nil {
			continue
		}
		if err := dialer.Close(); err != nil {
			errs = append(errs, err)
		}
	}
	if len(errs) > 0 {
		return fmt.Errorf("failed to Close(): %v", errs)
	}
	return nil
}
// dialContextFn adapts a plain dial function, such as http.Transport.DialContext,
// to the ContextDialer interface. Its Close does nothing.
type dialContextFn func(ctx context.Context, network, addr string) (net.Conn, error)

// DialContext invokes the wrapped function with the same arguments and returns its result.
func (fn dialContextFn) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	return fn(ctx, network, addr)
}

// Close always returns nil.
func (dialContextFn) Close() error {
	return nil
}
// newDialTLSContext builds a DialTLSContext function for http.Transport that
// dials through dialer and wraps the connection in a TLS client.
//
// It returns nil when selectCA is nil, leaving the transport to set up TLS itself.
//
// For each dial, the returned function does the following:
//   - It rejects an addr that is not "host:port" before dialing, so no connection is left open.
//   - It clones config, or starts from an empty tls.Config when config is nil. Later changes
//     to *config are therefore picked up, and per-connection edits never touch the shared config.
//   - It sets ServerName to the host part of addr unless config already specifies one.
//   - After a successful dial, it asks selectCA for a CA certificate for that host. This is
//     best effort: only a non-nil certificate with a nil error is used, and it becomes the sole
//     trusted root. Otherwise the clone's RootCAs is kept (nil means the host's root CA set).
//
// tls.Client does not perform the handshake itself; it runs on the first Read or
// Write, or on an explicit Handshake call.
func newDialTLSContext(dialer ContextDialer, config *tls.Config, selectCA func(ctx context.Context, serverName string) (*x509.Certificate, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
	if selectCA == nil {
		return nil
	}
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		// Parse before dialing: failing afterwards would leak the dialed connection.
		serverName, _, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, err
		}

		conn, err := dialer.DialContext(ctx, network, addr)
		if err != nil {
			return nil, err
		}

		var cfg *tls.Config
		if config != nil {
			cfg = config.Clone()
		} else {
			cfg = &tls.Config{}
		}
		if cfg.ServerName == "" {
			cfg.ServerName = serverName
		}
		// Best effort: a lookup failure falls back to the configured/default roots.
		if ca, caErr := selectCA(ctx, serverName); ca != nil && caErr == nil {
			caPool := x509.NewCertPool()
			caPool.AddCert(ca)
			cfg.RootCAs = caPool
		}

		return tls.Client(conn, cfg), nil
	}
}