From 75da3911c0205df6c1b8acb74662d96a6e37b305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-R=C3=A9my=20Bancel?= Date: Thu, 5 Sep 2019 07:36:05 -0700 Subject: [PATCH] Spoof at the Resolver layer, not the Request layer. (#630) --- test/clients.go | 4 +- test/request.go | 29 ++++++++++--- test/spoof/spoof.go | 99 ++++++++++++++++++++++++++------------------- 3 files changed, 83 insertions(+), 49 deletions(-) diff --git a/test/clients.go b/test/clients.go index c3f43bdab..6f9347f2a 100644 --- a/test/clients.go +++ b/test/clients.go @@ -38,8 +38,8 @@ type KubeClient struct { } // NewSpoofingClient returns a spoofing client to make requests -func NewSpoofingClient(client *KubeClient, logf logging.FormatLogger, domain string, resolvable bool) (*spoof.SpoofingClient, error) { - return spoof.New(client.Kube, logf, domain, resolvable, Flags.IngressEndpoint) +func NewSpoofingClient(client *KubeClient, logf logging.FormatLogger, domain string, resolvable bool, opts ...spoof.TransportOption) (*spoof.SpoofingClient, error) { + return spoof.New(client.Kube, logf, domain, resolvable, Flags.IngressEndpoint, opts...) } // NewKubeClient instantiates and returns several clientsets required for making request to the diff --git a/test/request.go b/test/request.go index d6d84f356..8b38cb29a 100644 --- a/test/request.go +++ b/test/request.go @@ -134,7 +134,14 @@ func MatchesAllOf(checkers ...spoof.ResponseChecker) spoof.ResponseChecker { // the domain in the request headers, otherwise it will make the request directly to domain. // desc will be used to name the metric that is emitted to track how long it took for the // domain to get into the state checked by inState. Commas in `desc` must be escaped. -func WaitForEndpointState(kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker, desc string, resolvable bool, opts ...RequestOption) (*spoof.Response, error) { +func WaitForEndpointState( + kubeClient *KubeClient, + logf logging.FormatLogger, + theURL string, + inState spoof.ResponseChecker, + desc string, + resolvable bool, + opts ...interface{}) (*spoof.Response, error) { return WaitForEndpointStateWithTimeout(kubeClient, logf, theURL, inState, desc, resolvable, spoof.RequestTimeout, opts...) } @@ -145,8 +152,14 @@ func WaitForEndpointState(kubeClient *KubeClient, logf logging.FormatLogger, the // desc will be used to name the metric that is emitted to track how long it took for the // domain to get into the state checked by inState. Commas in `desc` must be escaped. func WaitForEndpointStateWithTimeout( - kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker, - desc string, resolvable bool, timeout time.Duration, opts ...RequestOption) (*spoof.Response, error) { + kubeClient *KubeClient, + logf logging.FormatLogger, + theURL string, + inState spoof.ResponseChecker, + desc string, + resolvable bool, + timeout time.Duration, + opts ...interface{}) (*spoof.Response, error) { defer logging.GetEmitableSpan(context.Background(), fmt.Sprintf("WaitForEndpointState/%s", desc)).End() // Try parsing the "theURL" with and without a scheme. @@ -163,11 +176,17 @@ func WaitForEndpointStateWithTimeout( return nil, err } + var tOpts []spoof.TransportOption for _, opt := range opts { - opt(req) + rOpt, ok := opt.(RequestOption) + if ok { + rOpt(req) + } else if tOpt, ok := opt.(spoof.TransportOption); ok { + tOpts = append(tOpts, tOpt) + } } - client, err := NewSpoofingClient(kubeClient, logf, asURL.Hostname(), resolvable) + client, err := NewSpoofingClient(kubeClient, logf, asURL.Hostname(), resolvable, tOpts...) if err != nil { return nil, err } diff --git a/test/spoof/spoof.go b/test/spoof/spoof.go index 125a37a9f..a102bd1c5 100644 --- a/test/spoof/spoof.go +++ b/test/spoof/spoof.go @@ -19,9 +19,12 @@ limitations under the License. package spoof import ( + "context" "fmt" "io/ioutil" + "net" "net/http" + "strings" "time" "github.com/pkg/errors" @@ -66,7 +69,10 @@ type Interface interface { } // https://medium.com/stupid-gopher-tricks/ensuring-go-interface-satisfaction-at-compile-time-1ed158e8fa17 -var _ Interface = (*SpoofingClient)(nil) +var ( + _ Interface = (*SpoofingClient)(nil) + dialContext = (&net.Dialer{}).DialContext +) // ResponseChecker is used to determine when SpoofinClient.Poll is done polling. // This allows you to predicate wait.PollImmediate on the request's http.Response. @@ -82,72 +88,81 @@ type SpoofingClient struct { RequestInterval time.Duration RequestTimeout time.Duration - endpoint string - domain string - logf logging.FormatLogger } -// New returns a SpoofingClient that rewrites requests if the target domain is not `resolveable`. +// TransportOption allows callers to customize the http.Transport used by a SpoofingClient +type TransportOption func(transport *http.Transport) *http.Transport + +// New returns a SpoofingClient that rewrites requests if the target domain is not `resolvable`. // It does this by looking up the ingress at construction time, so reusing a client will not // follow the ingress if it moves (or if there are multiple ingresses). // // If that's a problem, see test/request.go#WaitForEndpointState for oneshot spoofing. -func New(kubeClientset *kubernetes.Clientset, logf logging.FormatLogger, domain string, resolvable bool, endpointOverride string) (*SpoofingClient, error) { +func New( + kubeClientset *kubernetes.Clientset, + logf logging.FormatLogger, + domain string, + resolvable bool, + endpointOverride string, + opts ...TransportOption) (*SpoofingClient, error) { + endpoint, err := ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride) + if err != nil { + fmt.Errorf("failed get the cluster endpoint: %v", err) + } + + // Spoof the hostname at the resolver level + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (conn net.Conn, e error) { + spoofed := addr + if i := strings.LastIndex(addr, ":"); i != -1 && domain == addr[:i] { + // The original hostname:port is spoofed by replacing the hostname by the value + // returned by ResolveEndpoint. + spoofed = endpoint + ":" + addr[i+1:] + logf("Spoofing %s -> %s", addr, spoofed) + } + return dialContext(ctx, network, spoofed) + }, + } + + for _, opt := range opts { + transport = opt(transport) + } + + // Enable Zipkin tracing + roundTripper := &ochttp.Transport{ + Base: transport, + Propagation: &b3.HTTPFormat{}, + } + sc := SpoofingClient{ - Client: &http.Client{Transport: &ochttp.Transport{Propagation: &b3.HTTPFormat{}}}, // Using ochttp Transport required for zipkin-tracing + Client: &http.Client{Transport: roundTripper}, RequestInterval: requestInterval, RequestTimeout: RequestTimeout, logf: logf, } - - var err error - if sc.endpoint, err = ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride); err != nil { - return nil, err - } - - if !resolvable { - sc.domain = domain - } - return &sc, nil } // ResolveEndpoint resolves the endpoint address considering whether the domain is resolvable and taking into // account whether the user overrode the endpoint address externally func ResolveEndpoint(kubeClientset *kubernetes.Clientset, domain string, resolvable bool, endpointOverride string) (string, error) { - // If the domain is resolvable, we can use it directly when we make requests. - endpoint := domain - if !resolvable { - e := endpointOverride - if endpointOverride == "" { - var err error - // If the domain that the Route controller is configured to assign to Route.Status.Domain - // (the domainSuffix) is not resolvable, we need to retrieve the endpoint and spoof - // the Host in our requests. - if e, err = ingress.GetIngressEndpoint(kubeClientset); err != nil { - return "", err - } - } - endpoint = e + // If the domain is resolvable, it can be used directly + if resolvable { + return domain, nil } - return endpoint, nil + // If an override is provided, use it + if endpointOverride != "" { + return endpointOverride, nil + } + // Otherwise, use the actual cluster endpoint + return ingress.GetIngressEndpoint(kubeClientset) } // Do dispatches to the underlying http.Client.Do, spoofing domains as needed // and transforming the http.Response into a spoof.Response. // Each response is augmented with "ZipkinTraceID" header that identifies the zipkin trace corresponding to the request. func (sc *SpoofingClient) Do(req *http.Request) (*Response, error) { - // Controls the Host header, for spoofing. - if sc.domain != "" { - req.Host = sc.domain - } - - // Controls the actual resolution. - if sc.endpoint != "" { - req.URL.Host = sc.endpoint - } - // Starting span to capture zipkin trace. traceContext, span := trace.StartSpan(req.Context(), "SpoofingClient-Trace") defer span.End()