Move all this config gubbins out of the way

There are two distinct configuration layers in this program:  the 'raw'
types provided by the flag library, and the 'validated' types we present
to the rest of the program.  This commit makes that distinction clear,
and internalises some pointer muck from the flag lib.
This commit is contained in:
Saj Goonatilleke 2019-05-15 19:03:53 +10:00
parent bdc39cee65
commit c9b7e27f76
3 changed files with 164 additions and 91 deletions

View File

@ -2,7 +2,7 @@ FROM golang:alpine as builder
RUN apk add git
WORKDIR /go/src/github.com/discourse/discourse-auth-proxy
COPY internal ./internal/
COPY main.go .
COPY *.go ./
RUN go get && go build
FROM alpine:latest

133
config.go Normal file
View File

@ -0,0 +1,133 @@
package main
import (
"fmt"
"net/url"
"os"
"time"
"github.com/namsral/flag"
"github.com/pborman/uuid"
)
type Config struct {
OriginURL *url.URL
ProxyURL *url.URL
ProxyURLString string // memoised - derives from ProxyURL
ListenAddr string
SSOURL *url.URL
SSOURLString string // memoised - derives from SSOURL
SSOSecret string
CookieSecret string
AllowAll bool
BasicAuth string
Whitelist string
UsernameHeader string
GroupsHeader string
Timeout time.Duration
}
func ParseConfig() (*Config, error) {
missing := func(name string) error {
return fmt.Errorf("missing mandatory flag: %s", name)
}
rc := parseRawConfig()
if *rc.OriginURL == "" {
return nil, missing("origin-url")
}
if *rc.ProxyURL == "" {
return nil, missing("proxy-url")
}
if *rc.SSOURL == "" {
return nil, missing("sso-url")
}
if *rc.SSOSecret == "" {
return nil, missing("sso-secret")
}
c := &Config{}
{
u, err := url.Parse(*rc.OriginURL)
if err != nil {
return nil, fmt.Errorf("invalid origin URL: %s", rc.OriginURL)
}
c.OriginURL = u
}
{
u, err := url.Parse(*rc.ProxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %s", rc.ProxyURL)
}
c.ProxyURL = u
c.ProxyURLString = u.String()
}
{
u, err := url.Parse(*rc.SSOURL)
if err != nil {
return nil, fmt.Errorf("invalid SSO URL - should point at Discourse site with enable sso: %s", rc.SSOURL)
}
c.SSOURL = u
c.SSOURLString = u.String()
}
if *rc.ListenURL == "" {
c.ListenAddr = c.ProxyURL.Host
} else {
c.ListenAddr = *rc.ListenURL
}
c.SSOSecret = *rc.SSOSecret
c.AllowAll = *rc.AllowAll
c.BasicAuth = *rc.BasicAuth
c.Whitelist = *rc.Whitelist
c.UsernameHeader = *rc.UsernameHeader
c.GroupsHeader = *rc.GroupsHeader
c.Timeout = time.Duration(*rc.Timeout) * time.Second
c.CookieSecret = uuid.New()
return c, nil
}
type rawConfig struct {
OriginURL *string
ProxyURL *string
ListenURL *string // Not actually a URL. This is a TCP socket address, i.e.: 'host:port'.
SSOURL *string
SSOSecret *string
AllowAll *bool
BasicAuth *string
Whitelist *string
UsernameHeader *string
GroupsHeader *string
Timeout *int
}
func parseRawConfig() *rawConfig {
c := &rawConfig{
OriginURL: flag.String("origin-url", "", "origin to proxy eg: http://localhost:2002"),
ProxyURL: flag.String("proxy-url", "", "outer url of this host eg: http://secrets.example.com"),
ListenURL: flag.String("listen-url", "", "url to listen on eg: localhost:2001. leave blank to set equal to proxy-url"),
SSOURL: flag.String("sso-url", "", "SSO endpoint eg: http://discourse.forum.com"),
SSOSecret: flag.String("sso-secret", "", "SSO secret for origin"),
AllowAll: flag.Bool("allow-all", false, "allow all discourse users (default: admin users only)"),
BasicAuth: flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly"),
Whitelist: flag.String("whitelist", "", "Path which does not require authorization"),
UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"),
GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"),
Timeout: flag.Int("timeout", 10, "Read/write timeout"),
}
flag.Parse()
return c
}
func usage(err error) {
flag.Usage()
if err != nil {
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, err)
}
os.Exit(2)
}

120
main.go
View File

@ -11,102 +11,43 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/golang/groupcache/lru"
"github.com/namsral/flag"
"github.com/pborman/uuid"
"github.com/discourse/discourse-auth-proxy/internal/httpproxy"
)
var nonceCache = lru.New(20)
var nonceMutex = &sync.Mutex{}
var (
config *Config
type Config struct {
ListenUri *string
ProxyUri *string
OriginUri *string
SsoSecret *string
SsoUri *string
BasicAuth *string
UsernameHeader *string
GroupsHeader *string
Timeout *int
CookieSecret string
AllowAll *bool
Whitelist *string
}
nonceCache = lru.New(20)
nonceMutex = &sync.Mutex{}
)
func main() {
config := new(Config)
config.ListenUri = flag.String("listen-url", "", "uri to listen on eg: localhost:2001. leave blank to set equal to proxy-url")
config.ProxyUri = flag.String("proxy-url", "", "outer url of this host eg: http://secrets.example.com")
config.OriginUri = flag.String("origin-url", "", "origin to proxy eg: http://localhost:2002")
config.SsoSecret = flag.String("sso-secret", "", "SSO secret for origin")
config.SsoUri = flag.String("sso-url", "", "SSO endpoint eg: http://discourse.forum.com")
config.AllowAll = flag.Bool("allow-all", false, "allow all discourse users (default: admin users only)")
config.BasicAuth = flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly")
config.UsernameHeader = flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into")
config.GroupsHeader = flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into")
config.Timeout = flag.Int("timeout", 10, "Read/write timeout")
config.Whitelist = flag.String("whitelist", "", "Path which does not require authorization")
flag.Parse()
originUrl, err := url.Parse(*config.OriginUri)
if err != nil {
flag.Usage()
log.Fatal("invalid origin url")
{
var err error
config, err = ParseConfig()
if err != nil {
usage(err)
}
}
_, err = url.Parse(*config.SsoUri)
if err != nil {
flag.Usage()
log.Fatal("invalid sso url, should point at Discourse site with enable sso")
}
proxyUrl, err2 := url.Parse(*config.ProxyUri)
if err2 != nil {
flag.Usage()
log.Fatal("invalid proxy uri")
}
if *config.ListenUri == "" {
log.Println("Defaulting to listening on the proxy url")
*config.ListenUri = proxyUrl.Host
}
if *config.ProxyUri == "" || *config.OriginUri == "" || *config.SsoSecret == "" || *config.SsoUri == "" || *config.ListenUri == "" {
flag.Usage()
os.Exit(1)
return
}
if *config.BasicAuth != "" {
log.Println("Enabling basic auth support")
}
config.CookieSecret = uuid.New()
dnssrv := httpproxy.NewDNSSRVBackend(originUrl)
dnssrv := httpproxy.NewDNSSRVBackend(config.OriginURL)
go dnssrv.Lookup(context.Background(), 50*time.Second, 10*time.Second, 5*time.Minute)
proxy := &httputil.ReverseProxy{Director: dnssrv.Director}
handler := authProxyHandler(proxy, config)
server := &http.Server{
Addr: *config.ListenUri,
Addr: config.ListenAddr,
Handler: handler,
ReadTimeout: time.Duration(*config.Timeout) * time.Second,
WriteTimeout: time.Duration(*config.Timeout) * time.Second,
ReadTimeout: config.Timeout,
WriteTimeout: config.Timeout,
MaxHeaderBytes: 1 << 20,
}
@ -115,19 +56,18 @@ func main() {
func authProxyHandler(handler http.Handler, config *Config) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if checkWhitelist(handler, r, w, config) {
if checkWhitelist(handler, r, w) {
return
}
if checkAuthorizationHeader(handler, r, w, config) {
if checkAuthorizationHeader(handler, r, w) {
return
}
redirectIfNoCookie(handler, r, w, config)
redirectIfNoCookie(handler, r, w)
})
}
func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) bool {
if *config.BasicAuth == "" {
// Can't auth if we don't have anything to auth against
func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.ResponseWriter) bool {
if config.BasicAuth == "" {
return false
}
@ -140,13 +80,13 @@ func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.Resp
log.Println("Received request with basic auth creds")
b_creds, _ := base64.StdEncoding.DecodeString(auth_header[6:])
creds := string(b_creds)
if creds == *config.BasicAuth {
if creds == config.BasicAuth {
colon_idx := strings.Index(creds, ":")
if colon_idx == -1 {
return false
}
username := creds[0:colon_idx]
r.Header.Set(*config.UsernameHeader, username)
r.Header.Set(config.UsernameHeader, username)
r.Header.Del("Authorization")
log.Printf("Accepted basic auth creds for %s\n", username)
handler.ServeHTTP(w, r)
@ -159,8 +99,8 @@ func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.Resp
return false
}
func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) bool {
if r.URL.Path == *(config.Whitelist) {
func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter) bool {
if r.URL.Path == config.Whitelist {
handler.ServeHTTP(w, r)
return true
}
@ -168,7 +108,7 @@ func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter
return false
}
func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) {
func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWriter) {
cookie, err := r.Cookie("__discourse_proxy")
var username, groups string
@ -177,8 +117,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
}
if err == nil {
r.Header.Set(*config.UsernameHeader, username)
r.Header.Set(*config.GroupsHeader, groups)
r.Header.Set(config.UsernameHeader, username)
r.Header.Set(config.GroupsHeader, groups)
handler.ServeHTTP(w, r)
return
}
@ -188,7 +128,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
sig := query.Get("sig")
if len(sso) == 0 {
url := *config.SsoUri + "/session/sso_provider?" + sso_payload(*config.SsoSecret, *config.ProxyUri, r.URL.String())
url := config.SSOURLString + "/session/sso_provider?" + sso_payload(config.SSOSecret, config.ProxyURLString, r.URL.String())
http.Redirect(w, r, url, 302)
} else {
decoded, _ := base64.StdEncoding.DecodeString(sso)
@ -202,8 +142,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
groupsArray := strings.Split(groups[0], ",")
if len(nonce) > 0 && len(admin) > 0 && len(username) > 0 && (admin[0] == "true" || *config.AllowAll) {
returnUrl, err := getReturnUrl(*config.SsoSecret, sso, sig, nonce[0])
if len(nonce) > 0 && len(admin) > 0 && len(username) > 0 && (admin[0] == "true" || config.AllowAll) {
returnUrl, err := getReturnUrl(config.SSOSecret, sso, sig, nonce[0])
if err != nil {
fmt.Fprintf(w, "Invalid request")