From 5bd67436f64368eef87e48e164a4dfd8e66049c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-R=C3=A9my=20Bancel?= Date: Tue, 3 Sep 2019 14:00:02 -0700 Subject: [PATCH] Revert "Spoof at the Resolver layer, not the Request layer. (#620)" (#629) This reverts commit 3dd5d66573f6dc922b3c9357d8dc97dbba3c9f3b. --- test/clients.go | 5 ++- test/request.go | 24 +++---------- test/spoof/spoof.go | 88 +++++++++++++++++++++------------------------ 3 files changed, 48 insertions(+), 69 deletions(-) diff --git a/test/clients.go b/test/clients.go index 0bc864c5e..c3f43bdab 100644 --- a/test/clients.go +++ b/test/clients.go @@ -20,7 +20,6 @@ package test import ( "fmt" - "net/http" "strings" corev1 "k8s.io/api/core/v1" @@ -39,8 +38,8 @@ type KubeClient struct { } // NewSpoofingClient returns a spoofing client to make requests -func NewSpoofingClient(client *KubeClient, transport *http.Transport, logf logging.FormatLogger, domain string, resolvable bool) (*spoof.SpoofingClient, error) { - return spoof.New(client.Kube, transport, logf, domain, resolvable, Flags.IngressEndpoint) +func NewSpoofingClient(client *KubeClient, logf logging.FormatLogger, domain string, resolvable bool) (*spoof.SpoofingClient, error) { + return spoof.New(client.Kube, logf, domain, resolvable, Flags.IngressEndpoint) } // NewKubeClient instantiates and returns several clientsets required for making request to the diff --git a/test/request.go b/test/request.go index 501208ff1..d6d84f356 100644 --- a/test/request.go +++ b/test/request.go @@ -134,15 +134,8 @@ func MatchesAllOf(checkers ...spoof.ResponseChecker) spoof.ResponseChecker { // the domain in the request headers, otherwise it will make the request directly to domain. // desc will be used to name the metric that is emitted to track how long it took for the // domain to get into the state checked by inState. Commas in `desc` must be escaped. -func WaitForEndpointState( - kubeClient *KubeClient, - logf logging.FormatLogger, - theURL string, - inState spoof.ResponseChecker, - desc string, - resolvable bool, - opts ...RequestOption) (*spoof.Response, error) { - return WaitForEndpointStateWithTimeout(kubeClient, http.DefaultTransport.(*http.Transport), logf, theURL, inState, desc, resolvable, spoof.RequestTimeout, opts...) +func WaitForEndpointState(kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker, desc string, resolvable bool, opts ...RequestOption) (*spoof.Response, error) { + return WaitForEndpointStateWithTimeout(kubeClient, logf, theURL, inState, desc, resolvable, spoof.RequestTimeout, opts...) } // WaitForEndpointStateWithTimeout will poll an endpoint until inState indicates the state is achieved @@ -152,15 +145,8 @@ func WaitForEndpointState( // desc will be used to name the metric that is emitted to track how long it took for the // domain to get into the state checked by inState. Commas in `desc` must be escaped. func WaitForEndpointStateWithTimeout( - kubeClient *KubeClient, - transport *http.Transport, - logf logging.FormatLogger, - theURL string, - inState spoof.ResponseChecker, - desc string, - resolvable bool, - timeout time.Duration, - opts ...RequestOption) (*spoof.Response, error) { + kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker, + desc string, resolvable bool, timeout time.Duration, opts ...RequestOption) (*spoof.Response, error) { defer logging.GetEmitableSpan(context.Background(), fmt.Sprintf("WaitForEndpointState/%s", desc)).End() // Try parsing the "theURL" with and without a scheme. @@ -181,7 +167,7 @@ func WaitForEndpointStateWithTimeout( opt(req) } - client, err := NewSpoofingClient(kubeClient, transport, logf, asURL.Hostname(), resolvable) + client, err := NewSpoofingClient(kubeClient, logf, asURL.Hostname(), resolvable) if err != nil { return nil, err } diff --git a/test/spoof/spoof.go b/test/spoof/spoof.go index cd8618c6e..125a37a9f 100644 --- a/test/spoof/spoof.go +++ b/test/spoof/spoof.go @@ -19,12 +19,9 @@ limitations under the License. package spoof import ( - "context" "fmt" "io/ioutil" - "net" "net/http" - "strings" "time" "github.com/pkg/errors" @@ -85,75 +82,72 @@ type SpoofingClient struct { RequestInterval time.Duration RequestTimeout time.Duration + endpoint string + domain string + logf logging.FormatLogger } -// New returns a SpoofingClient that rewrites requests if the target domain is not `resolvable`. +// New returns a SpoofingClient that rewrites requests if the target domain is not `resolveable`. // It does this by looking up the ingress at construction time, so reusing a client will not // follow the ingress if it moves (or if there are multiple ingresses). // // If that's a problem, see test/request.go#WaitForEndpointState for oneshot spoofing. -func New( - kubeClientset *kubernetes.Clientset, - transport *http.Transport, - logf logging.FormatLogger, - domain string, - resolvable bool, - endpointOverride string) (*SpoofingClient, error) { - // Spoof the hostname at the resolver level - endpoint, err := ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride) - if err != nil { - fmt.Errorf("failed get the cluster endpoint: %v", err) - } - oldDialContext := transport.DialContext - if oldDialContext == nil { - oldDialContext = (&net.Dialer{}).DialContext - } - transport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, e error) { - spoofed := addr - if i := strings.LastIndex(addr, ":"); i != -1 && domain == addr[:i] { - // The original hostname:port is spoofed by replacing the hostname by the value - // returned by ResolveEndpoint. - spoofed = endpoint + ":" + addr[i+1:] - logf("Spoofing %s -> %s", addr, spoofed) - } - return oldDialContext(ctx, network, spoofed) - } - - // Enable Zipkin tracing - roundTripper := &ochttp.Transport{ - Base: transport, - Propagation: &b3.HTTPFormat{}, - } - +func New(kubeClientset *kubernetes.Clientset, logf logging.FormatLogger, domain string, resolvable bool, endpointOverride string) (*SpoofingClient, error) { sc := SpoofingClient{ - Client: &http.Client{Transport: roundTripper}, + Client: &http.Client{Transport: &ochttp.Transport{Propagation: &b3.HTTPFormat{}}}, // Using ochttp Transport required for zipkin-tracing RequestInterval: requestInterval, RequestTimeout: RequestTimeout, logf: logf, } + + var err error + if sc.endpoint, err = ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride); err != nil { + return nil, err + } + + if !resolvable { + sc.domain = domain + } + return &sc, nil } // ResolveEndpoint resolves the endpoint address considering whether the domain is resolvable and taking into // account whether the user overrode the endpoint address externally func ResolveEndpoint(kubeClientset *kubernetes.Clientset, domain string, resolvable bool, endpointOverride string) (string, error) { - // If the domain is resolvable, it can be used directly - if resolvable { - return domain, nil + // If the domain is resolvable, we can use it directly when we make requests. + endpoint := domain + if !resolvable { + e := endpointOverride + if endpointOverride == "" { + var err error + // If the domain that the Route controller is configured to assign to Route.Status.Domain + // (the domainSuffix) is not resolvable, we need to retrieve the endpoint and spoof + // the Host in our requests. + if e, err = ingress.GetIngressEndpoint(kubeClientset); err != nil { + return "", err + } + } + endpoint = e } - // If an override is provided, use it - if endpointOverride != "" { - return endpointOverride, nil - } - // Otherwise, use the actual cluster endpoint - return ingress.GetIngressEndpoint(kubeClientset) + return endpoint, nil } // Do dispatches to the underlying http.Client.Do, spoofing domains as needed // and transforming the http.Response into a spoof.Response. // Each response is augmented with "ZipkinTraceID" header that identifies the zipkin trace corresponding to the request. func (sc *SpoofingClient) Do(req *http.Request) (*Response, error) { + // Controls the Host header, for spoofing. + if sc.domain != "" { + req.Host = sc.domain + } + + // Controls the actual resolution. + if sc.endpoint != "" { + req.URL.Host = sc.endpoint + } + // Starting span to capture zipkin trace. traceContext, span := trace.StartSpan(req.Context(), "SpoofingClient-Trace") defer span.End()