transport,grpc: Integrate delegating resolver and introduce dial options for target host resolution (#7881)

* Change proxy behaviour
This commit is contained in:
eshitachandwani 2025-01-24 12:10:11 +05:30 committed by GitHub
parent 66f64719c5
commit 2fd426d091
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 750 additions and 276 deletions

View File

@ -225,7 +225,12 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
// At the end of this method, we kick the channel out of idle, rather than
// waiting for the first rpc.
opts = append([]DialOption{withDefaultScheme("passthrough")}, opts...)
//
// WithLocalDNSResolution dial option in `grpc.Dial` ensures that it
// preserves behavior: when default scheme passthrough is used, skip
// hostname resolution, when "dns" is used for resolution, perform
// resolution on the client.
opts = append([]DialOption{withDefaultScheme("passthrough"), WithLocalDNSResolution()}, opts...)
cc, err := NewClient(target, opts...)
if err != nil {
return nil, err

View File

@ -94,6 +94,8 @@ type dialOptions struct {
idleTimeout time.Duration
defaultScheme string
maxCallAttempts int
enableLocalDNSResolution bool // Specifies if target hostnames should be resolved when proxying is enabled.
useProxy bool // Specifies if a server should be connected via proxy.
}
// DialOption configures how we set up the connection.
@ -377,7 +379,22 @@ func WithInsecure() DialOption {
// later release.
func WithNoProxy() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.copts.UseProxy = false
o.useProxy = false
})
}
// WithLocalDNSResolution forces local DNS name resolution even when a proxy is
// specified in the environment. By default, the server name is provided
// directly to the proxy as part of the CONNECT handshake. This is ignored if
// WithNoProxy is used.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WithLocalDNSResolution() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.enableLocalDNSResolution = true
})
}
@ -667,14 +684,15 @@ func defaultDialOptions() dialOptions {
copts: transport.ConnectOptions{
ReadBufferSize: defaultReadBufSize,
WriteBufferSize: defaultWriteBufSize,
UseProxy: true,
UserAgent: grpcUA,
BufferPool: mem.DefaultBufferPool(),
},
bs: internalbackoff.DefaultExponential,
idleTimeout: 30 * time.Minute,
defaultScheme: "dns",
maxCallAttempts: defaultMaxCallAttempts,
bs: internalbackoff.DefaultExponential,
idleTimeout: 30 * time.Minute,
defaultScheme: "dns",
maxCallAttempts: defaultMaxCallAttempts,
useProxy: true,
enableLocalDNSResolution: false,
}
}

View File

@ -33,7 +33,7 @@ const proxyOptionsKey = keyType("grpc.resolver.delegatingresolver.proxyOptions")
// Options holds the proxy connection details needed during the CONNECT
// handshake.
type Options struct {
User url.Userinfo
User *url.Userinfo
ConnectAddr string
}
@ -44,7 +44,8 @@ func Set(addr resolver.Address, opts Options) resolver.Address {
}
// Get returns the Options for the proxy [resolver.Address] and a boolean
// value representing if the attribute is present or not.
// value representing if the attribute is present or not. The returned data
// should not be mutated.
func Get(addr resolver.Address) (Options, bool) {
if a := addr.Attributes.Value(proxyOptionsKey); a != nil {
return a.(Options), true

View File

@ -42,7 +42,7 @@ func (s) TestGet(t *testing.T) {
name string
addr resolver.Address
wantConnectAddr string
wantUser url.Userinfo
wantUser *url.Userinfo
wantAttrPresent bool
}{
{
@ -61,10 +61,10 @@ func (s) TestGet(t *testing.T) {
addr: resolver.Address{
Addr: "test-address",
Attributes: attributes.New(proxyOptionsKey, Options{
User: *user,
User: user,
}),
},
wantUser: *user,
wantUser: user,
wantAttrPresent: true,
},
{
@ -97,7 +97,7 @@ func (s) TestGet(t *testing.T) {
func (s) TestSet(t *testing.T) {
addr := resolver.Address{Addr: "test-address"}
pOpts := Options{
User: *url.UserPassword("username", "password"),
User: url.UserPassword("username", "password"),
ConnectAddr: "proxy-address",
}
@ -108,7 +108,7 @@ func (s) TestSet(t *testing.T) {
t.Errorf("Get(%v) = %v, want %v ", populatedAddr, attrPresent, true)
}
if got, want := gotOption.ConnectAddr, pOpts.ConnectAddr; got != want {
t.Errorf("Unexpected ConnectAddr proxy atrribute = %v, want %v", got, want)
t.Errorf("unexpected ConnectAddr proxy atrribute = %v, want %v", got, want)
}
if got, want := gotOption.User, pOpts.User; got != want {
t.Errorf("unexpected User proxy attribute = %v, want %v", got, want)

View File

@ -205,13 +205,9 @@ func (r *delegatingResolver) updateClientConnStateLocked() error {
proxyAddr = resolver.Address{Addr: r.proxyURL.Host}
}
var addresses []resolver.Address
var user url.Userinfo
if r.proxyURL.User != nil {
user = *r.proxyURL.User
}
for _, targetAddr := range (*r.targetResolverState).Addresses {
addresses = append(addresses, proxyattributes.Set(proxyAddr, proxyattributes.Options{
User: user,
User: r.proxyURL.User,
ConnectAddr: targetAddr.Addr,
}))
}
@ -229,7 +225,7 @@ func (r *delegatingResolver) updateClientConnStateLocked() error {
for _, proxyAddr := range r.proxyAddrs {
for _, targetAddr := range endpt.Addresses {
addrs = append(addrs, proxyattributes.Set(proxyAddr, proxyattributes.Options{
User: user,
User: r.proxyURL.User,
ConnectAddr: targetAddr.Addr,
}))
}

View File

@ -0,0 +1,134 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package proxyserver provides an implementation of a proxy server for testing purposes.
// The server supports only a single incoming connection at a time and is not concurrent.
// It handles only HTTP CONNECT requests; other HTTP methods are not supported.
package proxyserver
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"testing"
"time"
"google.golang.org/grpc/internal/testutils"
)
// ProxyServer represents a test proxy server.
type ProxyServer struct {
lis net.Listener
in net.Conn // Connection from the client to the proxy.
out net.Conn // Connection from the proxy to the backend.
onRequest func(*http.Request) // Function to check the request sent to proxy.
Addr string // Address of the proxy
}
const defaultTestTimeout = 10 * time.Second
// Stop closes the ProxyServer and its connections to client and server.
func (p *ProxyServer) stop() {
p.lis.Close()
if p.in != nil {
p.in.Close()
}
if p.out != nil {
p.out.Close()
}
}
func (p *ProxyServer) handleRequest(t *testing.T, in net.Conn, waitForServerHello bool) {
req, err := http.ReadRequest(bufio.NewReader(in))
if err != nil {
t.Errorf("failed to read CONNECT req: %v", err)
return
}
if req.Method != http.MethodConnect {
t.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
p.onRequest(req)
t.Logf("Dialing to %s", req.URL.Host)
out, err := net.Dial("tcp", req.URL.Host)
if err != nil {
in.Close()
t.Logf("failed to dial to server: %v", err)
return
}
out.SetDeadline(time.Now().Add(defaultTestTimeout))
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
var buf bytes.Buffer
resp.Write(&buf)
if waitForServerHello {
// Batch the first message from the server with the http connect
// response. This is done to test the cases in which the grpc client has
// the response to the connect request and proxied packets from the
// destination server when it reads the transport.
b := make([]byte, 50)
bytesRead, err := out.Read(b)
if err != nil {
t.Errorf("Got error while reading server hello: %v", err)
in.Close()
out.Close()
return
}
buf.Write(b[0:bytesRead])
}
p.in = in
p.in.Write(buf.Bytes())
p.out = out
go io.Copy(p.in, p.out)
go io.Copy(p.out, p.in)
}
// New initializes and starts a proxy server, registers a cleanup to
// stop it, and returns a ProxyServer.
func New(t *testing.T, reqCheck func(*http.Request), waitForServerHello bool) *ProxyServer {
t.Helper()
pLis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
p := &ProxyServer{
lis: pLis,
onRequest: reqCheck,
Addr: pLis.Addr().String(),
}
// Start the proxy server.
go func() {
for {
in, err := p.lis.Accept()
if err != nil {
return
}
// p.handleRequest is not invoked in a goroutine because the test
// proxy currently supports handling only one connection at a time.
p.handleRequest(t, in, waitForServerHello)
}
}()
t.Logf("Started proxy at: %q", pLis.Addr().String())
t.Cleanup(p.stop)
return p
}

View File

@ -43,6 +43,7 @@ import (
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/proxyattributes"
istatus "google.golang.org/grpc/internal/status"
isyscall "google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype"
@ -153,7 +154,7 @@ type http2Client struct {
logger *grpclog.PrefixLogger
}
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) {
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, grpcUA string) (net.Conn, error) {
address := addr.Addr
networkType, ok := networktype.Get(addr)
if fn != nil {
@ -177,8 +178,8 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error
if !ok {
networkType, address = parseDialTarget(address)
}
if networkType == "tcp" && useProxy {
return proxyDial(ctx, address, grpcUA)
if opts, present := proxyattributes.Get(addr); present {
return proxyDial(ctx, addr, grpcUA, opts)
}
return internal.NetDialerWithTCPKeepalive().DialContext(ctx, networkType, address)
}
@ -217,7 +218,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
// address specific arbitrary data to reach custom dialers and credential handshakers.
connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
conn, err := dial(connectCtx, opts.Dialer, addr, opts.UseProxy, opts.UserAgent)
conn, err := dial(connectCtx, opts.Dialer, addr, opts.UserAgent)
if err != nil {
if opts.FailOnNonTempDialError {
return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)

View File

@ -30,34 +30,16 @@ import (
"net/url"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/proxyattributes"
"google.golang.org/grpc/resolver"
)
const proxyAuthHeaderKey = "Proxy-Authorization"
var (
// The following variable will be overwritten in the tests.
httpProxyFromEnvironment = http.ProxyFromEnvironment
)
func mapAddress(address string) (*url.URL, error) {
req := &http.Request{
URL: &url.URL{
Scheme: "https",
Host: address,
},
}
url, err := httpProxyFromEnvironment(req)
if err != nil {
return nil, err
}
return url, nil
}
// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
// It's possible that this reader reads more than what's need for the response and stores
// those bytes in the buffer.
// bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the
// bytes in the buffer.
// It's possible that this reader reads more than what's need for the response
// and stores those bytes in the buffer. bufConn wraps the original net.Conn
// and the bufio.Reader to make sure we don't lose the bytes in the buffer.
type bufConn struct {
net.Conn
r io.Reader
@ -72,7 +54,7 @@ func basicAuth(username, password string) string {
return base64.StdEncoding.EncodeToString([]byte(auth))
}
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) {
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, grpcUA string, opts proxyattributes.Options) (_ net.Conn, err error) {
defer func() {
if err != nil {
conn.Close()
@ -81,15 +63,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: backendAddr},
URL: &url.URL{Host: opts.ConnectAddr},
Header: map[string][]string{"User-Agent": {grpcUA}},
}
if t := proxyURL.User; t != nil {
u := t.Username()
p, _ := t.Password()
if user := opts.User; user != nil {
u := user.Username()
p, _ := user.Password()
req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p))
}
if err := sendHTTPRequest(ctx, req, conn); err != nil {
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
}
@ -117,28 +98,13 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
return conn, nil
}
// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
// is necessary, dials, does the HTTP CONNECT handshake, and returns the
// connection.
func proxyDial(ctx context.Context, addr string, grpcUA string) (net.Conn, error) {
newAddr := addr
proxyURL, err := mapAddress(addr)
// proxyDial establishes a TCP connection to the specified address and performs an HTTP CONNECT handshake.
func proxyDial(ctx context.Context, addr resolver.Address, grpcUA string, opts proxyattributes.Options) (net.Conn, error) {
conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", addr.Addr)
if err != nil {
return nil, err
}
if proxyURL != nil {
newAddr = proxyURL.Host
}
conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", newAddr)
if err != nil {
return nil, err
}
if proxyURL == nil {
// proxy is disabled if proxyURL is nil.
return conn, err
}
return doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA)
return doHTTPConnectHandshake(ctx, conn, grpcUA, opts)
}
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {

View File

@ -0,0 +1,518 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport_test
import (
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"golang.org/x/net/http/httpproxy"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/resolver/delegatingresolver"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/proxyserver"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
const defaultTestTimeout = 10 * time.Second
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
func startBackendServer(t *testing.T) *stubserver.StubServer {
t.Helper()
backend := &stubserver.StubServer{
EmptyCallF: func(context.Context, *testgrpc.Empty) (*testgrpc.Empty, error) { return &testgrpc.Empty{}, nil },
}
if err := backend.StartServer(); err != nil {
t.Fatalf("failed to start backend: %v", err)
}
t.Logf("Started TestService backend at: %q", backend.Address)
t.Cleanup(backend.Stop)
return backend
}
func isIPAddr(addr string) bool {
_, err := netip.ParseAddr(addr)
return err == nil
}
// Tests the scenario where grpc.Dial is performed using a proxy with the
// default resolver in the target URI. The test verifies that the connection is
// established to the proxy server, sends the unresolved target URI in the HTTP
// CONNECT request and is successfully connected to the backend server.
func (s) TestGRPCDialWithProxy(t *testing.T) {
backend := startBackendServer(t)
unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address))
proxyCalled := false
reqCheck := func(req *http.Request) {
proxyCalled = true
host, _, err := net.SplitHostPort(req.URL.Host)
if err != nil {
t.Error(err)
}
if got, want := host, "localhost"; got != want {
t.Errorf(" Unexpected request host: %s , want = %s ", got, want)
}
}
pServer := proxyserver.New(t, reqCheck, false)
// Use "localhost:<port>" to verify the proxy address is handled
// correctly by the delegating resolver and connects to the proxy server
// correctly even when unresolved.
pAddr := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, pServer.Addr))
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == unresolvedTargetURI {
return &url.URL{
Scheme: "https",
Host: pAddr,
}, nil
}
t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI)
return nil, nil
}
orighpfe := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = hpfe
defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.Dial(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.Dial(%s) failed: %v", unresolvedTargetURI, err)
}
defer conn.Close()
// Send an empty RPC to the backend through the proxy.
client := testgrpc.NewTestServiceClient(conn)
if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil {
t.Fatalf("EmptyCall failed: %v", err)
}
if !proxyCalled {
t.Fatalf("Proxy not connected")
}
}
// Tests the scenario where `grpc.Dial` is performed with a proxy and the "dns"
// scheme for the target. The test verifies that the proxy URI is correctly
// resolved and that the target URI resolution on the client preserves the
// original behavior of `grpc.Dial`. It also ensures that a connection is
// established to the proxy server, with the resolved target URI sent in the
// HTTP CONNECT request, successfully connecting to the backend server.
func (s) TestGRPCDialWithDNSAndProxy(t *testing.T) {
backend := startBackendServer(t)
unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address))
proxyCalled := false
reqCheck := func(req *http.Request) {
proxyCalled = true
host, _, err := net.SplitHostPort(req.URL.Host)
if err != nil {
t.Error(err)
}
if got, want := isIPAddr(host), true; got != want {
t.Errorf("isIPAddr(%q) = %t, want = %t", host, got, want)
}
}
pServer := proxyserver.New(t, reqCheck, false)
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == unresolvedTargetURI {
return &url.URL{
Scheme: "https",
Host: pServer.Addr,
}, nil
}
t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI)
return nil, nil
}
orighpfe := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = hpfe
defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.Dial("dns:///"+unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.Dial(%s) failed: %v", "dns:///"+unresolvedTargetURI, err)
}
defer conn.Close()
// Send an empty RPC to the backend through the proxy.
client := testgrpc.NewTestServiceClient(conn)
if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil {
t.Fatalf("EmptyCall failed: %v", err)
}
if !proxyCalled {
t.Fatalf("Proxy not connected")
}
}
// Tests the scenario where `grpc.NewClient` is used with the default DNS
// resolver for the target URI and a proxy is configured. The test verifies
// that the client resolves proxy URI, connects to the proxy server, sends the
// unresolved target URI in the HTTP CONNECT request, and successfully
// establishes a connection to the backend server.
func (s) TestNewClientWithProxy(t *testing.T) {
backend := startBackendServer(t)
unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address))
proxyCalled := false
reqCheck := func(req *http.Request) {
proxyCalled = true
host, _, err := net.SplitHostPort(req.URL.Host)
if err != nil {
t.Error(err)
}
if got, want := host, "localhost"; got != want {
t.Errorf(" Unexpected request host: %s , want = %s ", got, want)
}
}
pServer := proxyserver.New(t, reqCheck, false)
// Use "localhost:<port>" to verify the proxy address is handled
// correctly by the delegating resolver and connects to the proxy server
// correctly even when unresolved.
pAddr := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, pServer.Addr))
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == unresolvedTargetURI {
return &url.URL{
Scheme: "https",
Host: pAddr,
}, nil
}
t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI)
return nil, nil
}
orighpfe := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = hpfe
defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err)
}
defer conn.Close()
// Send an empty RPC to the backend through the proxy.
client := testgrpc.NewTestServiceClient(conn)
if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil {
t.Fatalf("EmptyCall failed: %v", err)
}
if !proxyCalled {
t.Fatalf("Proxy not connected")
}
}
// Tests the scenario where grpc.NewClient is used with a custom target URI
// scheme and a proxy is configured. The test verifies that the client
// successfully connects to the proxy server, resolves the proxy URI correctly,
// includes the resolved target URI in the HTTP CONNECT request, and
// establishes a connection to the backend server.
func (s) TestNewClientWithProxyAndCustomResolver(t *testing.T) {
backend := startBackendServer(t)
unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address))
proxyCalled := false
reqCheck := func(req *http.Request) {
proxyCalled = true
host, _, err := net.SplitHostPort(req.URL.Host)
if err != nil {
t.Error(err)
}
if got, want := isIPAddr(host), true; got != want {
t.Errorf("isIPAddr(%q) = %t, want = %t", host, got, want)
}
}
pServer := proxyserver.New(t, reqCheck, false)
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == unresolvedTargetURI {
return &url.URL{
Scheme: "https",
Host: pServer.Addr,
}, nil
}
t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI)
return nil, nil
}
orighpfe := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = hpfe
defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }()
// Create and update a custom resolver for target URI.
targetResolver := manual.NewBuilderWithScheme("test")
resolver.Register(targetResolver)
targetResolver.InitialState(resolver.State{Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: backend.Address}}}}})
// Dial to the proxy server.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.NewClient(targetResolver.Scheme()+":///"+unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%s) failed: %v", targetResolver.Scheme()+":///"+unresolvedTargetURI, err)
}
defer conn.Close()
// Send an empty RPC to the backend through the proxy.
client := testgrpc.NewTestServiceClient(conn)
if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil {
t.Fatalf("EmptyCall() failed: %v", err)
}
if !proxyCalled {
t.Fatalf("Proxy not connected")
}
}
// Tests the scenario where grpc.NewClient is used with the default "dns"
// resolver and the dial option grpc.WithLocalDNSResolution() is set,
// enabling target resolution on the client. The test verifies that target
// resolution happens on the client by sending resolved target URI in HTTP
// CONNECT request, the proxy URI is resolved correctly, and the connection is
// successfully established with the backend server through the proxy.
func (s) TestNewClientWithProxyAndTargetResolutionEnabled(t *testing.T) {
backend := startBackendServer(t)
unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address))
proxyCalled := false
reqCheck := func(req *http.Request) {
proxyCalled = true
host, _, err := net.SplitHostPort(req.URL.Host)
if err != nil {
t.Error(err)
}
if got, want := isIPAddr(host), true; got != want {
t.Errorf("isIPAddr(%q) = %t, want = %t", host, got, want)
}
}
pServer := proxyserver.New(t, reqCheck, false)
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == unresolvedTargetURI {
return &url.URL{
Scheme: "https",
Host: pServer.Addr,
}, nil
}
t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI)
return nil, nil
}
orighpfe := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = hpfe
defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithLocalDNSResolution(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err)
}
defer conn.Close()
// Send an empty RPC to the backend through the proxy.
client := testgrpc.NewTestServiceClient(conn)
if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil {
t.Fatalf("EmptyCall failed: %v", err)
}
if !proxyCalled {
t.Fatalf("Proxy not connected")
}
}
// Tests the scenario where grpc.NewClient is used with grpc.WithNoProxy() set,
// explicitly disabling proxy usage. The test verifies that the client does not
// dial the proxy but directly connects to the backend server. It also checks
// that the proxy resolution function is not called and that the proxy server
// never receives a connection request.
func (s) TestNewClientWithNoProxy(t *testing.T) {
backend := startBackendServer(t)
unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address))
reqCheck := func(_ *http.Request) { t.Error("proxy server should not have received a Connect request") }
pServer := proxyserver.New(t, reqCheck, false)
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == unresolvedTargetURI {
return &url.URL{
Scheme: "https",
Host: pServer.Addr,
}, nil
}
t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI)
return nil, nil
}
orighpfe := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = hpfe
defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }()
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithNoProxy(), // Disable proxy.
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.NewClient(unresolvedTargetURI, dopts...)
if err != nil {
t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err)
}
defer conn.Close()
// Create a test service client and make an RPC call.
client := testgrpc.NewTestServiceClient(conn)
if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil {
t.Fatalf("EmptyCall() failed: %v", err)
}
}
// Tests the scenario where grpc.NewClient is used with grpc.WithContextDialer()
// set. The test verifies that the client bypasses proxy dialing and uses the
// custom dialer instead. It ensures that the proxy server is never dialed, the
// proxy resolution function is not triggered, and the custom dialer is invoked
// as expected.
func (s) TestNewClientWithContextDialer(t *testing.T) {
backend := startBackendServer(t)
unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address))
reqCheck := func(_ *http.Request) { t.Error("proxy server should not have received a Connect request") }
pServer := proxyserver.New(t, reqCheck, false)
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == unresolvedTargetURI {
return &url.URL{
Scheme: "https",
Host: pServer.Addr,
}, nil
}
t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI)
return nil, nil
}
orighpfe := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = hpfe
defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }()
// Create a custom dialer that directly dials the backend.
customDialer := func(_ context.Context, unresolvedTargetURI string) (net.Conn, error) {
return net.Dial("tcp", unresolvedTargetURI)
}
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(customDialer), // Use a custom dialer.
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.NewClient(unresolvedTargetURI, dopts...)
if err != nil {
t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err)
}
defer conn.Close()
client := testgrpc.NewTestServiceClient(conn)
if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil {
t.Fatalf("EmptyCall() failed: %v", err)
}
}
// Tests the scenario where grpc.NewClient is used with the default DNS resolver
// for targetURI and a proxy. The test verifies that the client connects to the
// proxy server, sends the unresolved target URI in the HTTP CONNECT request,
// and successfully connects to the backend. Additionally, it checks that the
// correct user information is included in the Proxy-Authorization header of
// the CONNECT request. The test also ensures that target resolution does not
// happen on the client.
func (s) TestBasicAuthInNewClientWithProxy(t *testing.T) {
unresolvedTargetURI := "example.test"
const (
user = "notAUser"
password = "notAPassword"
)
proxyCalled := false
reqCheck := func(req *http.Request) {
proxyCalled = true
if got, want := req.URL.Host, "example.test"; got != want {
t.Errorf(" Unexpected request host: %s , want = %s ", got, want)
}
wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
if got := req.Header.Get("Proxy-Authorization"); got != wantProxyAuthStr {
gotDecoded, err := base64.StdEncoding.DecodeString(got)
if err != nil {
t.Errorf("failed to decode Proxy-Authorization header: %v", err)
}
wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
t.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
}
}
pServer := proxyserver.New(t, reqCheck, false)
t.Setenv("HTTPS_PROXY", user+":"+password+"@"+pServer.Addr)
// Use the httpproxy package functions instead of `http.ProxyFromEnvironment`
// because the latter reads proxy-related environment variables only once at
// initialization. This behavior causes issues when running test multiple
// times, as changes to environment variables during tests would be ignored.
// By using `httpproxy.FromEnvironment()`, we ensure proxy settings are read dynamically.
origHTTPSProxyFromEnvironment := delegatingresolver.HTTPSProxyFromEnvironment
delegatingresolver.HTTPSProxyFromEnvironment = func(req *http.Request) (*url.URL, error) {
return httpproxy.FromEnvironment().ProxyFunc()(req.URL)
}
defer func() {
delegatingresolver.HTTPSProxyFromEnvironment = origHTTPSProxyFromEnvironment
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err)
}
defer conn.Close()
// Send an empty RPC to the backend through the proxy.
client := testgrpc.NewTestServiceClient(conn)
client.EmptyCall(ctx, &testgrpc.Empty{})
if !proxyCalled {
t.Fatalf("Proxy not connected")
}
}

View File

@ -22,126 +22,39 @@
package transport
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/url"
"net/netip"
"testing"
"time"
"google.golang.org/grpc/internal/proxyattributes"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/proxyserver"
"google.golang.org/grpc/resolver"
)
const (
envTestAddr = "1.2.3.4:8080"
envProxyAddr = "2.3.4.5:7687"
)
// overwriteAndRestore overwrite function httpProxyFromEnvironment and
// returns a function to restore the default values.
func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() {
backHPFE := httpProxyFromEnvironment
httpProxyFromEnvironment = hpfe
return func() {
httpProxyFromEnvironment = backHPFE
}
}
type proxyServer struct {
t *testing.T
lis net.Listener
in net.Conn
out net.Conn
requestCheck func(*http.Request) error
}
func (p *proxyServer) run(waitForServerHello bool) {
in, err := p.lis.Accept()
func (s) TestHTTPConnectWithServerHello(t *testing.T) {
serverMessage := []byte("server-hello")
blis, err := testutils.LocalTCPListener()
if err != nil {
return
t.Fatalf("failed to listen: %v", err)
}
p.in = in
req, err := http.ReadRequest(bufio.NewReader(in))
if err != nil {
p.t.Errorf("failed to read CONNECT req: %v", err)
return
}
if err := p.requestCheck(req); err != nil {
resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
resp.Write(p.in)
p.in.Close()
p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
return
}
out, err := net.Dial("tcp", req.URL.Host)
if err != nil {
p.t.Errorf("failed to dial to server: %v", err)
return
}
out.SetDeadline(time.Now().Add(defaultTestTimeout))
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
var buf bytes.Buffer
resp.Write(&buf)
if waitForServerHello {
// Batch the first message from the server with the http connect
// response. This is done to test the cases in which the grpc client has
// the response to the connect request and proxied packets from the
// destination server when it reads the transport.
b := make([]byte, 50)
bytesRead, err := out.Read(b)
if err != nil {
p.t.Errorf("Got error while reading server hello: %v", err)
in.Close()
out.Close()
return
reqCheck := func(req *http.Request) {
if req.Method != http.MethodConnect {
t.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
host, _, err := net.SplitHostPort(req.URL.Host)
if err != nil {
t.Error(err)
}
_, err = netip.ParseAddr(host)
if err != nil {
t.Error(err)
}
buf.Write(b[0:bytesRead])
}
p.in.Write(buf.Bytes())
p.out = out
go io.Copy(p.in, p.out)
go io.Copy(p.out, p.in)
}
func (p *proxyServer) stop() {
p.lis.Close()
if p.in != nil {
p.in.Close()
}
if p.out != nil {
p.out.Close()
}
}
type testArgs struct {
proxyURLModify func(*url.URL) *url.URL
proxyReqCheck func(*http.Request) error
serverMessage []byte
}
func testHTTPConnect(t *testing.T, args testArgs) {
plis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
p := &proxyServer{
t: t,
lis: plis,
requestCheck: args.proxyReqCheck,
}
go p.run(len(args.serverMessage) > 0)
defer p.stop()
blis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
pServer := proxyserver.New(t, reqCheck, true)
msg := []byte{4, 3, 5, 2}
recvBuf := make([]byte, len(msg))
@ -153,21 +66,15 @@ func testHTTPConnect(t *testing.T, args testArgs) {
return
}
defer in.Close()
in.Write(args.serverMessage)
in.Write(serverMessage)
in.Read(recvBuf)
done <- nil
}()
// Overwrite the function in the test and restore them in defer.
hpfe := func(*http.Request) (*url.URL, error) {
return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
}
defer overwrite(hpfe)()
// Dial to proxy server.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
c, err := proxyDial(ctx, blis.Addr().String(), "test")
c, err := proxyDial(ctx, resolver.Address{Addr: pServer.Addr}, "test", proxyattributes.Options{ConnectAddr: blis.Addr().String()})
if err != nil {
t.Fatalf("HTTP connect Dial failed: %v", err)
}
@ -185,94 +92,14 @@ func testHTTPConnect(t *testing.T, args testArgs) {
t.Fatalf("Received msg: %v, want %v", recvBuf, msg)
}
if len(args.serverMessage) > 0 {
gotServerMessage := make([]byte, len(args.serverMessage))
if len(serverMessage) > 0 {
gotServerMessage := make([]byte, len(serverMessage))
if _, err := c.Read(gotServerMessage); err != nil {
t.Errorf("Got error while reading message from server: %v", err)
return
}
if string(gotServerMessage) != string(args.serverMessage) {
t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage)
if string(gotServerMessage) != string(serverMessage) {
t.Errorf("Message from server: %v, want %v", gotServerMessage, serverMessage)
}
}
}
func (s) TestHTTPConnect(t *testing.T) {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
}
testHTTPConnect(t, args)
}
func (s) TestHTTPConnectWithServerHello(t *testing.T) {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
serverMessage: []byte("server-hello"),
}
testHTTPConnect(t, args)
}
func (s) TestHTTPConnectBasicAuth(t *testing.T) {
const (
user = "notAUser"
password = "notAPassword"
)
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
in.User = url.UserPassword(user, password)
return in
},
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
gotDecoded, _ := base64.StdEncoding.DecodeString(got)
wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
}
return nil
},
}
testHTTPConnect(t, args)
}
func (s) TestMapAddressEnv(t *testing.T) {
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
if req.URL.Host == envTestAddr {
return &url.URL{
Scheme: "https",
Host: envProxyAddr,
}, nil
}
return nil, nil
}
defer overwrite(hpfe)()
// envTestAddr should be handled by ProxyFromEnvironment.
got, err := mapAddress(envTestAddr)
if err != nil {
t.Error(err)
}
if got.Host != envProxyAddr {
t.Errorf("want %v, got %v", envProxyAddr, got)
}
}

View File

@ -502,8 +502,6 @@ type ConnectOptions struct {
ChannelzParent *channelz.SubChannel
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
MaxHeaderListSize *uint32
// UseProxy specifies if a proxy should be used.
UseProxy bool
// The mem.BufferPool to use when reading/writing to the wire.
BufferPool mem.BufferPool
}

View File

@ -26,6 +26,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/internal/resolver/delegatingresolver"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
@ -78,7 +79,16 @@ func (ccr *ccResolverWrapper) start() error {
Authority: ccr.cc.authority,
}
var err error
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
// The delegating resolver is used unless:
// - A custom dialer is provided via WithContextDialer dialoption or
// - Proxy usage is disabled through WithNoProxy dialoption.
// In these cases, the resolver is built based on the scheme of target,
// using the appropriate resolver builder.
if ccr.cc.dopts.copts.Dialer != nil || !ccr.cc.dopts.useProxy {
ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts)
} else {
ccr.resolver, err = delegatingresolver.New(ccr.cc.parsedTarget, ccr, opts, ccr.cc.resolverBuilder, ccr.cc.dopts.enableLocalDNSResolution)
}
errCh <- err
})
return <-errCh