proxy: support basic authentication (#2426)

Proxy-Authorization: https://tools.ietf.org/html/rfc7235#section-4.4
The 'Basic' HTTP Authentication Scheme: https://tools.ietf.org/html/rfc7617

updates #2422
This commit is contained in:
Menghan Li 2018-11-13 14:59:16 -08:00 committed by GitHub
parent 04ea82009c
commit 63ae68c968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 19 deletions

View File

@ -21,6 +21,7 @@ package grpc
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@ -30,6 +31,8 @@ import (
"net/url"
)
const proxyAuthHeaderKey = "Proxy-Authorization"
var (
// errDisabled indicates that proxy is disabled for the address.
errDisabled = errors.New("proxy is disabled for the address")
@ -37,7 +40,7 @@ var (
httpProxyFromEnvironment = http.ProxyFromEnvironment
)
func mapAddress(ctx context.Context, address string) (string, error) {
func mapAddress(ctx context.Context, address string) (*url.URL, error) {
req := &http.Request{
URL: &url.URL{
Scheme: "https",
@ -46,12 +49,12 @@ func mapAddress(ctx context.Context, address string) (string, error) {
}
url, err := httpProxyFromEnvironment(req)
if err != nil {
return "", err
return nil, err
}
if url == nil {
return "", errDisabled
return nil, errDisabled
}
return url.Host, nil
return url, nil
}
// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
@ -68,18 +71,28 @@ func (c *bufConn) Read(b []byte) (int, error) {
return c.r.Read(b)
}
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) {
func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) {
defer func() {
if err != nil {
conn.Close()
}
}()
req := (&http.Request{
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
URL: &url.URL{Host: backendAddr},
Header: map[string][]string{"User-Agent": {grpcUA}},
})
}
if t := proxyURL.User; t != nil {
u := t.Username()
p, _ := t.Password()
req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p))
}
if err := sendHTTPRequest(ctx, req, conn); err != nil {
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
@ -107,22 +120,24 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_
// provided dialer, does HTTP CONNECT handshake and returns the connection.
func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (conn net.Conn, err error) {
var skipHandshake bool
newAddr, err := mapAddress(ctx, addr)
var newAddr string
proxyURL, err := mapAddress(ctx, addr)
if err != nil {
if err != errDisabled {
return nil, err
}
skipHandshake = true
newAddr = addr
} else {
newAddr = proxyURL.Host
}
conn, err = dialer(ctx, newAddr)
if err != nil {
return
}
if !skipHandshake {
conn, err = doHTTPConnectHandshake(ctx, conn, addr)
if proxyURL != nil {
// proxy is disabled if proxyURL is nil.
conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL)
}
return
}

View File

@ -23,6 +23,8 @@ package grpc
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
@ -53,6 +55,8 @@ type proxyServer struct {
lis net.Listener
in net.Conn
out net.Conn
requestCheck func(*http.Request) error
}
func (p *proxyServer) run() {
@ -67,11 +71,11 @@ func (p *proxyServer) run() {
p.t.Errorf("failed to read CONNECT req: %v", err)
return
}
if req.Method != http.MethodConnect || req.UserAgent() != grpcUA {
if err := p.requestCheck(req); err != nil {
resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
resp.Write(p.in)
p.in.Close()
p.t.Errorf("get wrong CONNECT req: %+v", req)
p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
return
}
@ -97,13 +101,17 @@ func (p *proxyServer) stop() {
}
}
func TestHTTPConnect(t *testing.T) {
func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
defer leakcheck.Check(t)
plis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
p := &proxyServer{t: t, lis: plis}
p := &proxyServer{
t: t,
lis: plis,
requestCheck: proxyReqCheck,
}
go p.run()
defer p.stop()
@ -128,7 +136,7 @@ func TestHTTPConnect(t *testing.T) {
// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
return &url.URL{Host: plis.Addr().String()}, nil
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
}
defer overwrite(hpfe)()
@ -157,6 +165,51 @@ func TestHTTPConnect(t *testing.T) {
}
}
func TestHTTPConnect(t *testing.T) {
testHTTPConnect(t,
func(in *url.URL) *url.URL {
return in
},
func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
if req.UserAgent() != grpcUA {
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
}
return nil
},
)
}
func TestHTTPConnectBasicAuth(t *testing.T) {
const (
user = "notAUser"
password = "notAPassword"
)
testHTTPConnect(t,
func(in *url.URL) *url.URL {
in.User = url.UserPassword(user, password)
return in
},
func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
if req.UserAgent() != grpcUA {
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
}
wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
gotDecoded, _ := base64.StdEncoding.DecodeString(got)
wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
}
return nil
},
)
}
func TestMapAddressEnv(t *testing.T) {
defer leakcheck.Check(t)
// Overwrite the function in the test and restore them in defer.
@ -176,7 +229,7 @@ func TestMapAddressEnv(t *testing.T) {
if err != nil {
t.Error(err)
}
if got != envProxyAddr {
if got.Host != envProxyAddr {
t.Errorf("want %v, got %v", envProxyAddr, got)
}
}