mirror of https://github.com/knative/pkg.git
This reverts commit 3dd5d66573.
This commit is contained in:
parent
4749553105
commit
5bd67436f6
|
|
@ -20,7 +20,6 @@ package test
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
|
|
@ -39,8 +38,8 @@ type KubeClient struct {
|
|||
}
|
||||
|
||||
// NewSpoofingClient returns a spoofing client to make requests
|
||||
func NewSpoofingClient(client *KubeClient, transport *http.Transport, logf logging.FormatLogger, domain string, resolvable bool) (*spoof.SpoofingClient, error) {
|
||||
return spoof.New(client.Kube, transport, logf, domain, resolvable, Flags.IngressEndpoint)
|
||||
func NewSpoofingClient(client *KubeClient, logf logging.FormatLogger, domain string, resolvable bool) (*spoof.SpoofingClient, error) {
|
||||
return spoof.New(client.Kube, logf, domain, resolvable, Flags.IngressEndpoint)
|
||||
}
|
||||
|
||||
// NewKubeClient instantiates and returns several clientsets required for making request to the
|
||||
|
|
|
|||
|
|
@ -134,15 +134,8 @@ func MatchesAllOf(checkers ...spoof.ResponseChecker) spoof.ResponseChecker {
|
|||
// the domain in the request headers, otherwise it will make the request directly to domain.
|
||||
// desc will be used to name the metric that is emitted to track how long it took for the
|
||||
// domain to get into the state checked by inState. Commas in `desc` must be escaped.
|
||||
func WaitForEndpointState(
|
||||
kubeClient *KubeClient,
|
||||
logf logging.FormatLogger,
|
||||
theURL string,
|
||||
inState spoof.ResponseChecker,
|
||||
desc string,
|
||||
resolvable bool,
|
||||
opts ...RequestOption) (*spoof.Response, error) {
|
||||
return WaitForEndpointStateWithTimeout(kubeClient, http.DefaultTransport.(*http.Transport), logf, theURL, inState, desc, resolvable, spoof.RequestTimeout, opts...)
|
||||
func WaitForEndpointState(kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker, desc string, resolvable bool, opts ...RequestOption) (*spoof.Response, error) {
|
||||
return WaitForEndpointStateWithTimeout(kubeClient, logf, theURL, inState, desc, resolvable, spoof.RequestTimeout, opts...)
|
||||
}
|
||||
|
||||
// WaitForEndpointStateWithTimeout will poll an endpoint until inState indicates the state is achieved
|
||||
|
|
@ -152,15 +145,8 @@ func WaitForEndpointState(
|
|||
// desc will be used to name the metric that is emitted to track how long it took for the
|
||||
// domain to get into the state checked by inState. Commas in `desc` must be escaped.
|
||||
func WaitForEndpointStateWithTimeout(
|
||||
kubeClient *KubeClient,
|
||||
transport *http.Transport,
|
||||
logf logging.FormatLogger,
|
||||
theURL string,
|
||||
inState spoof.ResponseChecker,
|
||||
desc string,
|
||||
resolvable bool,
|
||||
timeout time.Duration,
|
||||
opts ...RequestOption) (*spoof.Response, error) {
|
||||
kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker,
|
||||
desc string, resolvable bool, timeout time.Duration, opts ...RequestOption) (*spoof.Response, error) {
|
||||
defer logging.GetEmitableSpan(context.Background(), fmt.Sprintf("WaitForEndpointState/%s", desc)).End()
|
||||
|
||||
// Try parsing the "theURL" with and without a scheme.
|
||||
|
|
@ -181,7 +167,7 @@ func WaitForEndpointStateWithTimeout(
|
|||
opt(req)
|
||||
}
|
||||
|
||||
client, err := NewSpoofingClient(kubeClient, transport, logf, asURL.Hostname(), resolvable)
|
||||
client, err := NewSpoofingClient(kubeClient, logf, asURL.Hostname(), resolvable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,12 +19,9 @@ limitations under the License.
|
|||
package spoof
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -85,75 +82,72 @@ type SpoofingClient struct {
|
|||
RequestInterval time.Duration
|
||||
RequestTimeout time.Duration
|
||||
|
||||
endpoint string
|
||||
domain string
|
||||
|
||||
logf logging.FormatLogger
|
||||
}
|
||||
|
||||
// New returns a SpoofingClient that rewrites requests if the target domain is not `resolvable`.
|
||||
// New returns a SpoofingClient that rewrites requests if the target domain is not `resolveable`.
|
||||
// It does this by looking up the ingress at construction time, so reusing a client will not
|
||||
// follow the ingress if it moves (or if there are multiple ingresses).
|
||||
//
|
||||
// If that's a problem, see test/request.go#WaitForEndpointState for oneshot spoofing.
|
||||
func New(
|
||||
kubeClientset *kubernetes.Clientset,
|
||||
transport *http.Transport,
|
||||
logf logging.FormatLogger,
|
||||
domain string,
|
||||
resolvable bool,
|
||||
endpointOverride string) (*SpoofingClient, error) {
|
||||
// Spoof the hostname at the resolver level
|
||||
endpoint, err := ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride)
|
||||
if err != nil {
|
||||
fmt.Errorf("failed get the cluster endpoint: %v", err)
|
||||
}
|
||||
oldDialContext := transport.DialContext
|
||||
if oldDialContext == nil {
|
||||
oldDialContext = (&net.Dialer{}).DialContext
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
|
||||
spoofed := addr
|
||||
if i := strings.LastIndex(addr, ":"); i != -1 && domain == addr[:i] {
|
||||
// The original hostname:port is spoofed by replacing the hostname by the value
|
||||
// returned by ResolveEndpoint.
|
||||
spoofed = endpoint + ":" + addr[i+1:]
|
||||
logf("Spoofing %s -> %s", addr, spoofed)
|
||||
}
|
||||
return oldDialContext(ctx, network, spoofed)
|
||||
}
|
||||
|
||||
// Enable Zipkin tracing
|
||||
roundTripper := &ochttp.Transport{
|
||||
Base: transport,
|
||||
Propagation: &b3.HTTPFormat{},
|
||||
}
|
||||
|
||||
func New(kubeClientset *kubernetes.Clientset, logf logging.FormatLogger, domain string, resolvable bool, endpointOverride string) (*SpoofingClient, error) {
|
||||
sc := SpoofingClient{
|
||||
Client: &http.Client{Transport: roundTripper},
|
||||
Client: &http.Client{Transport: &ochttp.Transport{Propagation: &b3.HTTPFormat{}}}, // Using ochttp Transport required for zipkin-tracing
|
||||
RequestInterval: requestInterval,
|
||||
RequestTimeout: RequestTimeout,
|
||||
logf: logf,
|
||||
}
|
||||
|
||||
var err error
|
||||
if sc.endpoint, err = ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resolvable {
|
||||
sc.domain = domain
|
||||
}
|
||||
|
||||
return &sc, nil
|
||||
}
|
||||
|
||||
// ResolveEndpoint resolves the endpoint address considering whether the domain is resolvable and taking into
|
||||
// account whether the user overrode the endpoint address externally
|
||||
func ResolveEndpoint(kubeClientset *kubernetes.Clientset, domain string, resolvable bool, endpointOverride string) (string, error) {
|
||||
// If the domain is resolvable, it can be used directly
|
||||
if resolvable {
|
||||
return domain, nil
|
||||
// If the domain is resolvable, we can use it directly when we make requests.
|
||||
endpoint := domain
|
||||
if !resolvable {
|
||||
e := endpointOverride
|
||||
if endpointOverride == "" {
|
||||
var err error
|
||||
// If the domain that the Route controller is configured to assign to Route.Status.Domain
|
||||
// (the domainSuffix) is not resolvable, we need to retrieve the endpoint and spoof
|
||||
// the Host in our requests.
|
||||
if e, err = ingress.GetIngressEndpoint(kubeClientset); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
endpoint = e
|
||||
}
|
||||
// If an override is provided, use it
|
||||
if endpointOverride != "" {
|
||||
return endpointOverride, nil
|
||||
}
|
||||
// Otherwise, use the actual cluster endpoint
|
||||
return ingress.GetIngressEndpoint(kubeClientset)
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
// Do dispatches to the underlying http.Client.Do, spoofing domains as needed
|
||||
// and transforming the http.Response into a spoof.Response.
|
||||
// Each response is augmented with "ZipkinTraceID" header that identifies the zipkin trace corresponding to the request.
|
||||
func (sc *SpoofingClient) Do(req *http.Request) (*Response, error) {
|
||||
// Controls the Host header, for spoofing.
|
||||
if sc.domain != "" {
|
||||
req.Host = sc.domain
|
||||
}
|
||||
|
||||
// Controls the actual resolution.
|
||||
if sc.endpoint != "" {
|
||||
req.URL.Host = sc.endpoint
|
||||
}
|
||||
|
||||
// Starting span to capture zipkin trace.
|
||||
traceContext, span := trace.StartSpan(req.Context(), "SpoofingClient-Trace")
|
||||
defer span.End()
|
||||
|
|
|
|||
Loading…
Reference in New Issue