mirror of https://github.com/knative/pkg.git
Spoof at the Resolver layer, not the Request layer. (#630)
This commit is contained in:
parent
a5172fd917
commit
75da3911c0
|
|
@ -38,8 +38,8 @@ type KubeClient struct {
|
|||
}
|
||||
|
||||
// NewSpoofingClient returns a spoofing client to make requests
|
||||
func NewSpoofingClient(client *KubeClient, logf logging.FormatLogger, domain string, resolvable bool) (*spoof.SpoofingClient, error) {
|
||||
return spoof.New(client.Kube, logf, domain, resolvable, Flags.IngressEndpoint)
|
||||
func NewSpoofingClient(client *KubeClient, logf logging.FormatLogger, domain string, resolvable bool, opts ...spoof.TransportOption) (*spoof.SpoofingClient, error) {
|
||||
return spoof.New(client.Kube, logf, domain, resolvable, Flags.IngressEndpoint, opts...)
|
||||
}
|
||||
|
||||
// NewKubeClient instantiates and returns several clientsets required for making request to the
|
||||
|
|
|
|||
|
|
@ -134,7 +134,14 @@ func MatchesAllOf(checkers ...spoof.ResponseChecker) spoof.ResponseChecker {
|
|||
// the domain in the request headers, otherwise it will make the request directly to domain.
|
||||
// desc will be used to name the metric that is emitted to track how long it took for the
|
||||
// domain to get into the state checked by inState. Commas in `desc` must be escaped.
|
||||
func WaitForEndpointState(kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker, desc string, resolvable bool, opts ...RequestOption) (*spoof.Response, error) {
|
||||
func WaitForEndpointState(
|
||||
kubeClient *KubeClient,
|
||||
logf logging.FormatLogger,
|
||||
theURL string,
|
||||
inState spoof.ResponseChecker,
|
||||
desc string,
|
||||
resolvable bool,
|
||||
opts ...interface{}) (*spoof.Response, error) {
|
||||
return WaitForEndpointStateWithTimeout(kubeClient, logf, theURL, inState, desc, resolvable, spoof.RequestTimeout, opts...)
|
||||
}
|
||||
|
||||
|
|
@ -145,8 +152,14 @@ func WaitForEndpointState(kubeClient *KubeClient, logf logging.FormatLogger, the
|
|||
// desc will be used to name the metric that is emitted to track how long it took for the
|
||||
// domain to get into the state checked by inState. Commas in `desc` must be escaped.
|
||||
func WaitForEndpointStateWithTimeout(
|
||||
kubeClient *KubeClient, logf logging.FormatLogger, theURL string, inState spoof.ResponseChecker,
|
||||
desc string, resolvable bool, timeout time.Duration, opts ...RequestOption) (*spoof.Response, error) {
|
||||
kubeClient *KubeClient,
|
||||
logf logging.FormatLogger,
|
||||
theURL string,
|
||||
inState spoof.ResponseChecker,
|
||||
desc string,
|
||||
resolvable bool,
|
||||
timeout time.Duration,
|
||||
opts ...interface{}) (*spoof.Response, error) {
|
||||
defer logging.GetEmitableSpan(context.Background(), fmt.Sprintf("WaitForEndpointState/%s", desc)).End()
|
||||
|
||||
// Try parsing the "theURL" with and without a scheme.
|
||||
|
|
@ -163,11 +176,17 @@ func WaitForEndpointStateWithTimeout(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var tOpts []spoof.TransportOption
|
||||
for _, opt := range opts {
|
||||
opt(req)
|
||||
rOpt, ok := opt.(RequestOption)
|
||||
if ok {
|
||||
rOpt(req)
|
||||
} else if tOpt, ok := opt.(spoof.TransportOption); ok {
|
||||
tOpts = append(tOpts, tOpt)
|
||||
}
|
||||
}
|
||||
|
||||
client, err := NewSpoofingClient(kubeClient, logf, asURL.Hostname(), resolvable)
|
||||
client, err := NewSpoofingClient(kubeClient, logf, asURL.Hostname(), resolvable, tOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,9 +19,12 @@ limitations under the License.
|
|||
package spoof
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -66,7 +69,10 @@ type Interface interface {
|
|||
}
|
||||
|
||||
// https://medium.com/stupid-gopher-tricks/ensuring-go-interface-satisfaction-at-compile-time-1ed158e8fa17
|
||||
var _ Interface = (*SpoofingClient)(nil)
|
||||
var (
|
||||
_ Interface = (*SpoofingClient)(nil)
|
||||
dialContext = (&net.Dialer{}).DialContext
|
||||
)
|
||||
|
||||
// ResponseChecker is used to determine when SpoofinClient.Poll is done polling.
|
||||
// This allows you to predicate wait.PollImmediate on the request's http.Response.
|
||||
|
|
@ -82,72 +88,81 @@ type SpoofingClient struct {
|
|||
RequestInterval time.Duration
|
||||
RequestTimeout time.Duration
|
||||
|
||||
endpoint string
|
||||
domain string
|
||||
|
||||
logf logging.FormatLogger
|
||||
}
|
||||
|
||||
// New returns a SpoofingClient that rewrites requests if the target domain is not `resolveable`.
|
||||
// TransportOption allows callers to customize the http.Transport used by a SpoofingClient
|
||||
type TransportOption func(transport *http.Transport) *http.Transport
|
||||
|
||||
// New returns a SpoofingClient that rewrites requests if the target domain is not `resolvable`.
|
||||
// It does this by looking up the ingress at construction time, so reusing a client will not
|
||||
// follow the ingress if it moves (or if there are multiple ingresses).
|
||||
//
|
||||
// If that's a problem, see test/request.go#WaitForEndpointState for oneshot spoofing.
|
||||
func New(kubeClientset *kubernetes.Clientset, logf logging.FormatLogger, domain string, resolvable bool, endpointOverride string) (*SpoofingClient, error) {
|
||||
func New(
|
||||
kubeClientset *kubernetes.Clientset,
|
||||
logf logging.FormatLogger,
|
||||
domain string,
|
||||
resolvable bool,
|
||||
endpointOverride string,
|
||||
opts ...TransportOption) (*SpoofingClient, error) {
|
||||
endpoint, err := ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride)
|
||||
if err != nil {
|
||||
fmt.Errorf("failed get the cluster endpoint: %v", err)
|
||||
}
|
||||
|
||||
// Spoof the hostname at the resolver level
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
|
||||
spoofed := addr
|
||||
if i := strings.LastIndex(addr, ":"); i != -1 && domain == addr[:i] {
|
||||
// The original hostname:port is spoofed by replacing the hostname by the value
|
||||
// returned by ResolveEndpoint.
|
||||
spoofed = endpoint + ":" + addr[i+1:]
|
||||
logf("Spoofing %s -> %s", addr, spoofed)
|
||||
}
|
||||
return dialContext(ctx, network, spoofed)
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
transport = opt(transport)
|
||||
}
|
||||
|
||||
// Enable Zipkin tracing
|
||||
roundTripper := &ochttp.Transport{
|
||||
Base: transport,
|
||||
Propagation: &b3.HTTPFormat{},
|
||||
}
|
||||
|
||||
sc := SpoofingClient{
|
||||
Client: &http.Client{Transport: &ochttp.Transport{Propagation: &b3.HTTPFormat{}}}, // Using ochttp Transport required for zipkin-tracing
|
||||
Client: &http.Client{Transport: roundTripper},
|
||||
RequestInterval: requestInterval,
|
||||
RequestTimeout: RequestTimeout,
|
||||
logf: logf,
|
||||
}
|
||||
|
||||
var err error
|
||||
if sc.endpoint, err = ResolveEndpoint(kubeClientset, domain, resolvable, endpointOverride); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resolvable {
|
||||
sc.domain = domain
|
||||
}
|
||||
|
||||
return &sc, nil
|
||||
}
|
||||
|
||||
// ResolveEndpoint resolves the endpoint address considering whether the domain is resolvable and taking into
|
||||
// account whether the user overrode the endpoint address externally
|
||||
func ResolveEndpoint(kubeClientset *kubernetes.Clientset, domain string, resolvable bool, endpointOverride string) (string, error) {
|
||||
// If the domain is resolvable, we can use it directly when we make requests.
|
||||
endpoint := domain
|
||||
if !resolvable {
|
||||
e := endpointOverride
|
||||
if endpointOverride == "" {
|
||||
var err error
|
||||
// If the domain that the Route controller is configured to assign to Route.Status.Domain
|
||||
// (the domainSuffix) is not resolvable, we need to retrieve the endpoint and spoof
|
||||
// the Host in our requests.
|
||||
if e, err = ingress.GetIngressEndpoint(kubeClientset); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
endpoint = e
|
||||
// If the domain is resolvable, it can be used directly
|
||||
if resolvable {
|
||||
return domain, nil
|
||||
}
|
||||
return endpoint, nil
|
||||
// If an override is provided, use it
|
||||
if endpointOverride != "" {
|
||||
return endpointOverride, nil
|
||||
}
|
||||
// Otherwise, use the actual cluster endpoint
|
||||
return ingress.GetIngressEndpoint(kubeClientset)
|
||||
}
|
||||
|
||||
// Do dispatches to the underlying http.Client.Do, spoofing domains as needed
|
||||
// and transforming the http.Response into a spoof.Response.
|
||||
// Each response is augmented with "ZipkinTraceID" header that identifies the zipkin trace corresponding to the request.
|
||||
func (sc *SpoofingClient) Do(req *http.Request) (*Response, error) {
|
||||
// Controls the Host header, for spoofing.
|
||||
if sc.domain != "" {
|
||||
req.Host = sc.domain
|
||||
}
|
||||
|
||||
// Controls the actual resolution.
|
||||
if sc.endpoint != "" {
|
||||
req.URL.Host = sc.endpoint
|
||||
}
|
||||
|
||||
// Starting span to capture zipkin trace.
|
||||
traceContext, span := trace.StartSpan(req.Context(), "SpoofingClient-Trace")
|
||||
defer span.End()
|
||||
|
|
|
|||
Loading…
Reference in New Issue