mirror of https://github.com/grpc/grpc-go.git
proxy: support basic authentication (#2426)
Proxy-Authorization: https://tools.ietf.org/html/rfc7235#section-4.4 The 'Basic' HTTP Authentication Scheme: https://tools.ietf.org/html/rfc7617 updates #2422
This commit is contained in:
parent
04ea82009c
commit
63ae68c968
41
proxy.go
41
proxy.go
|
@ -21,6 +21,7 @@ package grpc
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -30,6 +31,8 @@ import (
|
|||
"net/url"
|
||||
)
|
||||
|
||||
const proxyAuthHeaderKey = "Proxy-Authorization"
|
||||
|
||||
var (
|
||||
// errDisabled indicates that proxy is disabled for the address.
|
||||
errDisabled = errors.New("proxy is disabled for the address")
|
||||
|
@ -37,7 +40,7 @@ var (
|
|||
httpProxyFromEnvironment = http.ProxyFromEnvironment
|
||||
)
|
||||
|
||||
func mapAddress(ctx context.Context, address string) (string, error) {
|
||||
func mapAddress(ctx context.Context, address string) (*url.URL, error) {
|
||||
req := &http.Request{
|
||||
URL: &url.URL{
|
||||
Scheme: "https",
|
||||
|
@ -46,12 +49,12 @@ func mapAddress(ctx context.Context, address string) (string, error) {
|
|||
}
|
||||
url, err := httpProxyFromEnvironment(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
if url == nil {
|
||||
return "", errDisabled
|
||||
return nil, errDisabled
|
||||
}
|
||||
return url.Host, nil
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
|
||||
|
@ -68,18 +71,28 @@ func (c *bufConn) Read(b []byte) (int, error) {
|
|||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) {
|
||||
func basicAuth(username, password string) string {
|
||||
auth := username + ":" + password
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
req := (&http.Request{
|
||||
req := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Host: addr},
|
||||
URL: &url.URL{Host: backendAddr},
|
||||
Header: map[string][]string{"User-Agent": {grpcUA}},
|
||||
})
|
||||
}
|
||||
if t := proxyURL.User; t != nil {
|
||||
u := t.Username()
|
||||
p, _ := t.Password()
|
||||
req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p))
|
||||
}
|
||||
|
||||
if err := sendHTTPRequest(ctx, req, conn); err != nil {
|
||||
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
|
||||
|
@ -107,22 +120,24 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_
|
|||
// provided dialer, does HTTP CONNECT handshake and returns the connection.
|
||||
func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
|
||||
return func(ctx context.Context, addr string) (conn net.Conn, err error) {
|
||||
var skipHandshake bool
|
||||
newAddr, err := mapAddress(ctx, addr)
|
||||
var newAddr string
|
||||
proxyURL, err := mapAddress(ctx, addr)
|
||||
if err != nil {
|
||||
if err != errDisabled {
|
||||
return nil, err
|
||||
}
|
||||
skipHandshake = true
|
||||
newAddr = addr
|
||||
} else {
|
||||
newAddr = proxyURL.Host
|
||||
}
|
||||
|
||||
conn, err = dialer(ctx, newAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !skipHandshake {
|
||||
conn, err = doHTTPConnectHandshake(ctx, conn, addr)
|
||||
if proxyURL != nil {
|
||||
// proxy is disabled if proxyURL is nil.
|
||||
conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -23,6 +23,8 @@ package grpc
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -53,6 +55,8 @@ type proxyServer struct {
|
|||
lis net.Listener
|
||||
in net.Conn
|
||||
out net.Conn
|
||||
|
||||
requestCheck func(*http.Request) error
|
||||
}
|
||||
|
||||
func (p *proxyServer) run() {
|
||||
|
@ -67,11 +71,11 @@ func (p *proxyServer) run() {
|
|||
p.t.Errorf("failed to read CONNECT req: %v", err)
|
||||
return
|
||||
}
|
||||
if req.Method != http.MethodConnect || req.UserAgent() != grpcUA {
|
||||
if err := p.requestCheck(req); err != nil {
|
||||
resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
|
||||
resp.Write(p.in)
|
||||
p.in.Close()
|
||||
p.t.Errorf("get wrong CONNECT req: %+v", req)
|
||||
p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -97,13 +101,17 @@ func (p *proxyServer) stop() {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHTTPConnect(t *testing.T) {
|
||||
func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
|
||||
defer leakcheck.Check(t)
|
||||
plis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
p := &proxyServer{t: t, lis: plis}
|
||||
p := &proxyServer{
|
||||
t: t,
|
||||
lis: plis,
|
||||
requestCheck: proxyReqCheck,
|
||||
}
|
||||
go p.run()
|
||||
defer p.stop()
|
||||
|
||||
|
@ -128,7 +136,7 @@ func TestHTTPConnect(t *testing.T) {
|
|||
|
||||
// Overwrite the function in the test and restore them in defer.
|
||||
hpfe := func(req *http.Request) (*url.URL, error) {
|
||||
return &url.URL{Host: plis.Addr().String()}, nil
|
||||
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
|
||||
}
|
||||
defer overwrite(hpfe)()
|
||||
|
||||
|
@ -157,6 +165,51 @@ func TestHTTPConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHTTPConnect(t *testing.T) {
|
||||
testHTTPConnect(t,
|
||||
func(in *url.URL) *url.URL {
|
||||
return in
|
||||
},
|
||||
func(req *http.Request) error {
|
||||
if req.Method != http.MethodConnect {
|
||||
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
|
||||
}
|
||||
if req.UserAgent() != grpcUA {
|
||||
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestHTTPConnectBasicAuth(t *testing.T) {
|
||||
const (
|
||||
user = "notAUser"
|
||||
password = "notAPassword"
|
||||
)
|
||||
testHTTPConnect(t,
|
||||
func(in *url.URL) *url.URL {
|
||||
in.User = url.UserPassword(user, password)
|
||||
return in
|
||||
},
|
||||
func(req *http.Request) error {
|
||||
if req.Method != http.MethodConnect {
|
||||
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
|
||||
}
|
||||
if req.UserAgent() != grpcUA {
|
||||
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
|
||||
}
|
||||
wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
|
||||
if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
|
||||
gotDecoded, _ := base64.StdEncoding.DecodeString(got)
|
||||
wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
|
||||
return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestMapAddressEnv(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
// Overwrite the function in the test and restore them in defer.
|
||||
|
@ -176,7 +229,7 @@ func TestMapAddressEnv(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if got != envProxyAddr {
|
||||
if got.Host != envProxyAddr {
|
||||
t.Errorf("want %v, got %v", envProxyAddr, got)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue