Move all this config gubbins out of the way
There are two distinct configuration layers in this program: the 'raw' types provided by the flag library, and the 'validated' types we present to the rest of the program. This commit makes that distinction clear, and internalises some pointer muck from the flag lib.
This commit is contained in:
parent
bdc39cee65
commit
c9b7e27f76
|
@ -2,7 +2,7 @@ FROM golang:alpine as builder
|
|||
RUN apk add git
|
||||
WORKDIR /go/src/github.com/discourse/discourse-auth-proxy
|
||||
COPY internal ./internal/
|
||||
COPY main.go .
|
||||
COPY *.go ./
|
||||
RUN go get && go build
|
||||
|
||||
FROM alpine:latest
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/namsral/flag"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
OriginURL *url.URL
|
||||
ProxyURL *url.URL
|
||||
ProxyURLString string // memoised - derives from ProxyURL
|
||||
ListenAddr string
|
||||
SSOURL *url.URL
|
||||
SSOURLString string // memoised - derives from SSOURL
|
||||
SSOSecret string
|
||||
CookieSecret string
|
||||
AllowAll bool
|
||||
BasicAuth string
|
||||
Whitelist string
|
||||
UsernameHeader string
|
||||
GroupsHeader string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func ParseConfig() (*Config, error) {
|
||||
missing := func(name string) error {
|
||||
return fmt.Errorf("missing mandatory flag: %s", name)
|
||||
}
|
||||
|
||||
rc := parseRawConfig()
|
||||
|
||||
if *rc.OriginURL == "" {
|
||||
return nil, missing("origin-url")
|
||||
}
|
||||
if *rc.ProxyURL == "" {
|
||||
return nil, missing("proxy-url")
|
||||
}
|
||||
if *rc.SSOURL == "" {
|
||||
return nil, missing("sso-url")
|
||||
}
|
||||
if *rc.SSOSecret == "" {
|
||||
return nil, missing("sso-secret")
|
||||
}
|
||||
|
||||
c := &Config{}
|
||||
{
|
||||
u, err := url.Parse(*rc.OriginURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid origin URL: %s", rc.OriginURL)
|
||||
}
|
||||
c.OriginURL = u
|
||||
}
|
||||
{
|
||||
u, err := url.Parse(*rc.ProxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %s", rc.ProxyURL)
|
||||
}
|
||||
c.ProxyURL = u
|
||||
c.ProxyURLString = u.String()
|
||||
}
|
||||
{
|
||||
u, err := url.Parse(*rc.SSOURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid SSO URL - should point at Discourse site with enable sso: %s", rc.SSOURL)
|
||||
}
|
||||
c.SSOURL = u
|
||||
c.SSOURLString = u.String()
|
||||
}
|
||||
|
||||
if *rc.ListenURL == "" {
|
||||
c.ListenAddr = c.ProxyURL.Host
|
||||
} else {
|
||||
c.ListenAddr = *rc.ListenURL
|
||||
}
|
||||
|
||||
c.SSOSecret = *rc.SSOSecret
|
||||
c.AllowAll = *rc.AllowAll
|
||||
c.BasicAuth = *rc.BasicAuth
|
||||
c.Whitelist = *rc.Whitelist
|
||||
c.UsernameHeader = *rc.UsernameHeader
|
||||
c.GroupsHeader = *rc.GroupsHeader
|
||||
c.Timeout = time.Duration(*rc.Timeout) * time.Second
|
||||
|
||||
c.CookieSecret = uuid.New()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type rawConfig struct {
|
||||
OriginURL *string
|
||||
ProxyURL *string
|
||||
ListenURL *string // Not actually a URL. This is a TCP socket address, i.e.: 'host:port'.
|
||||
SSOURL *string
|
||||
SSOSecret *string
|
||||
AllowAll *bool
|
||||
BasicAuth *string
|
||||
Whitelist *string
|
||||
UsernameHeader *string
|
||||
GroupsHeader *string
|
||||
Timeout *int
|
||||
}
|
||||
|
||||
func parseRawConfig() *rawConfig {
|
||||
c := &rawConfig{
|
||||
OriginURL: flag.String("origin-url", "", "origin to proxy eg: http://localhost:2002"),
|
||||
ProxyURL: flag.String("proxy-url", "", "outer url of this host eg: http://secrets.example.com"),
|
||||
ListenURL: flag.String("listen-url", "", "url to listen on eg: localhost:2001. leave blank to set equal to proxy-url"),
|
||||
SSOURL: flag.String("sso-url", "", "SSO endpoint eg: http://discourse.forum.com"),
|
||||
SSOSecret: flag.String("sso-secret", "", "SSO secret for origin"),
|
||||
AllowAll: flag.Bool("allow-all", false, "allow all discourse users (default: admin users only)"),
|
||||
BasicAuth: flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly"),
|
||||
Whitelist: flag.String("whitelist", "", "Path which does not require authorization"),
|
||||
UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"),
|
||||
GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"),
|
||||
Timeout: flag.Int("timeout", 10, "Read/write timeout"),
|
||||
}
|
||||
flag.Parse()
|
||||
return c
|
||||
}
|
||||
|
||||
func usage(err error) {
|
||||
flag.Usage()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
}
|
||||
os.Exit(2)
|
||||
}
|
118
main.go
118
main.go
|
@ -11,102 +11,43 @@ import (
|
|||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/groupcache/lru"
|
||||
"github.com/namsral/flag"
|
||||
"github.com/pborman/uuid"
|
||||
|
||||
"github.com/discourse/discourse-auth-proxy/internal/httpproxy"
|
||||
)
|
||||
|
||||
var nonceCache = lru.New(20)
|
||||
var nonceMutex = &sync.Mutex{}
|
||||
var (
|
||||
config *Config
|
||||
|
||||
type Config struct {
|
||||
ListenUri *string
|
||||
ProxyUri *string
|
||||
OriginUri *string
|
||||
SsoSecret *string
|
||||
SsoUri *string
|
||||
BasicAuth *string
|
||||
UsernameHeader *string
|
||||
GroupsHeader *string
|
||||
Timeout *int
|
||||
CookieSecret string
|
||||
AllowAll *bool
|
||||
Whitelist *string
|
||||
}
|
||||
nonceCache = lru.New(20)
|
||||
nonceMutex = &sync.Mutex{}
|
||||
)
|
||||
|
||||
func main() {
|
||||
config := new(Config)
|
||||
|
||||
config.ListenUri = flag.String("listen-url", "", "uri to listen on eg: localhost:2001. leave blank to set equal to proxy-url")
|
||||
config.ProxyUri = flag.String("proxy-url", "", "outer url of this host eg: http://secrets.example.com")
|
||||
config.OriginUri = flag.String("origin-url", "", "origin to proxy eg: http://localhost:2002")
|
||||
config.SsoSecret = flag.String("sso-secret", "", "SSO secret for origin")
|
||||
config.SsoUri = flag.String("sso-url", "", "SSO endpoint eg: http://discourse.forum.com")
|
||||
config.AllowAll = flag.Bool("allow-all", false, "allow all discourse users (default: admin users only)")
|
||||
config.BasicAuth = flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly")
|
||||
config.UsernameHeader = flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into")
|
||||
config.GroupsHeader = flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into")
|
||||
config.Timeout = flag.Int("timeout", 10, "Read/write timeout")
|
||||
config.Whitelist = flag.String("whitelist", "", "Path which does not require authorization")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
originUrl, err := url.Parse(*config.OriginUri)
|
||||
|
||||
{
|
||||
var err error
|
||||
config, err = ParseConfig()
|
||||
if err != nil {
|
||||
flag.Usage()
|
||||
log.Fatal("invalid origin url")
|
||||
usage(err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = url.Parse(*config.SsoUri)
|
||||
|
||||
if err != nil {
|
||||
flag.Usage()
|
||||
log.Fatal("invalid sso url, should point at Discourse site with enable sso")
|
||||
}
|
||||
|
||||
proxyUrl, err2 := url.Parse(*config.ProxyUri)
|
||||
|
||||
if err2 != nil {
|
||||
flag.Usage()
|
||||
log.Fatal("invalid proxy uri")
|
||||
}
|
||||
|
||||
if *config.ListenUri == "" {
|
||||
log.Println("Defaulting to listening on the proxy url")
|
||||
*config.ListenUri = proxyUrl.Host
|
||||
}
|
||||
|
||||
if *config.ProxyUri == "" || *config.OriginUri == "" || *config.SsoSecret == "" || *config.SsoUri == "" || *config.ListenUri == "" {
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
if *config.BasicAuth != "" {
|
||||
log.Println("Enabling basic auth support")
|
||||
}
|
||||
|
||||
config.CookieSecret = uuid.New()
|
||||
|
||||
dnssrv := httpproxy.NewDNSSRVBackend(originUrl)
|
||||
dnssrv := httpproxy.NewDNSSRVBackend(config.OriginURL)
|
||||
go dnssrv.Lookup(context.Background(), 50*time.Second, 10*time.Second, 5*time.Minute)
|
||||
proxy := &httputil.ReverseProxy{Director: dnssrv.Director}
|
||||
|
||||
handler := authProxyHandler(proxy, config)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: *config.ListenUri,
|
||||
Addr: config.ListenAddr,
|
||||
Handler: handler,
|
||||
ReadTimeout: time.Duration(*config.Timeout) * time.Second,
|
||||
WriteTimeout: time.Duration(*config.Timeout) * time.Second,
|
||||
ReadTimeout: config.Timeout,
|
||||
WriteTimeout: config.Timeout,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
}
|
||||
|
||||
|
@ -115,19 +56,18 @@ func main() {
|
|||
|
||||
func authProxyHandler(handler http.Handler, config *Config) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if checkWhitelist(handler, r, w, config) {
|
||||
if checkWhitelist(handler, r, w) {
|
||||
return
|
||||
}
|
||||
if checkAuthorizationHeader(handler, r, w, config) {
|
||||
if checkAuthorizationHeader(handler, r, w) {
|
||||
return
|
||||
}
|
||||
redirectIfNoCookie(handler, r, w, config)
|
||||
redirectIfNoCookie(handler, r, w)
|
||||
})
|
||||
}
|
||||
|
||||
func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) bool {
|
||||
if *config.BasicAuth == "" {
|
||||
// Can't auth if we don't have anything to auth against
|
||||
func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.ResponseWriter) bool {
|
||||
if config.BasicAuth == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -140,13 +80,13 @@ func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.Resp
|
|||
log.Println("Received request with basic auth creds")
|
||||
b_creds, _ := base64.StdEncoding.DecodeString(auth_header[6:])
|
||||
creds := string(b_creds)
|
||||
if creds == *config.BasicAuth {
|
||||
if creds == config.BasicAuth {
|
||||
colon_idx := strings.Index(creds, ":")
|
||||
if colon_idx == -1 {
|
||||
return false
|
||||
}
|
||||
username := creds[0:colon_idx]
|
||||
r.Header.Set(*config.UsernameHeader, username)
|
||||
r.Header.Set(config.UsernameHeader, username)
|
||||
r.Header.Del("Authorization")
|
||||
log.Printf("Accepted basic auth creds for %s\n", username)
|
||||
handler.ServeHTTP(w, r)
|
||||
|
@ -159,8 +99,8 @@ func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.Resp
|
|||
return false
|
||||
}
|
||||
|
||||
func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) bool {
|
||||
if r.URL.Path == *(config.Whitelist) {
|
||||
func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter) bool {
|
||||
if r.URL.Path == config.Whitelist {
|
||||
handler.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
@ -168,7 +108,7 @@ func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter
|
|||
return false
|
||||
}
|
||||
|
||||
func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) {
|
||||
func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWriter) {
|
||||
cookie, err := r.Cookie("__discourse_proxy")
|
||||
var username, groups string
|
||||
|
||||
|
@ -177,8 +117,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
|||
}
|
||||
|
||||
if err == nil {
|
||||
r.Header.Set(*config.UsernameHeader, username)
|
||||
r.Header.Set(*config.GroupsHeader, groups)
|
||||
r.Header.Set(config.UsernameHeader, username)
|
||||
r.Header.Set(config.GroupsHeader, groups)
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
@ -188,7 +128,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
|||
sig := query.Get("sig")
|
||||
|
||||
if len(sso) == 0 {
|
||||
url := *config.SsoUri + "/session/sso_provider?" + sso_payload(*config.SsoSecret, *config.ProxyUri, r.URL.String())
|
||||
url := config.SSOURLString + "/session/sso_provider?" + sso_payload(config.SSOSecret, config.ProxyURLString, r.URL.String())
|
||||
http.Redirect(w, r, url, 302)
|
||||
} else {
|
||||
decoded, _ := base64.StdEncoding.DecodeString(sso)
|
||||
|
@ -202,8 +142,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
|||
|
||||
groupsArray := strings.Split(groups[0], ",")
|
||||
|
||||
if len(nonce) > 0 && len(admin) > 0 && len(username) > 0 && (admin[0] == "true" || *config.AllowAll) {
|
||||
returnUrl, err := getReturnUrl(*config.SsoSecret, sso, sig, nonce[0])
|
||||
if len(nonce) > 0 && len(admin) > 0 && len(username) > 0 && (admin[0] == "true" || config.AllowAll) {
|
||||
returnUrl, err := getReturnUrl(config.SSOSecret, sso, sig, nonce[0])
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintf(w, "Invalid request")
|
||||
|
|
Loading…
Reference in New Issue