diff --git a/backends/forward.go b/backends/forward.go index 6cbf23c431..1c1a8041d2 100644 --- a/backends/forward.go +++ b/backends/forward.go @@ -229,27 +229,36 @@ func newClient(peer, version string) (*client, error) { if err != nil { return nil, err } + protoAddrParts := strings.SplitN(peer, "://", 2) c := &client{ URL: u, + proto: protoAddrParts[0], + addr: protoAddrParts[1], version: version, } c.URL.Scheme = "http" return c, nil } +func (c *client) dial() (net.Conn, error) { + return net.Dial(c.proto, c.addr) +} + func (c *client) call(method, path, body string) (*http.Response, error) { path = fmt.Sprintf("/%s%s", c.version, path) u, err := url.Parse(path) if err != nil { return nil, err } - u.Host = c.URL.Host + u.Host = "dummy.host" u.Scheme = c.URL.Scheme req, err := http.NewRequest(method, u.String(), strings.NewReader(body)) if err != nil { return nil, err } - resp, err := http.DefaultClient.Do(req) + tr := &http.Transport{Dial: func(_, _ string) (net.Conn, error) { return c.dial() }} + client := &http.Client{Transport: tr} + resp, err := client.Do(req) if err != nil { return nil, err } @@ -258,7 +267,7 @@ func (c *client) call(method, path, body string) (*http.Response, error) { func (c *client) hijack(method, path string, in io.ReadCloser, stdout, stderr io.Writer) error { path = fmt.Sprintf("/%s%s", c.version, path) - dial, err := net.Dial("tcp", c.URL.Host) + dial, err := c.dial() if err != nil { return err }