diff --git a/Dockerfile b/Dockerfile index 1112fb7..1eebaf0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ FROM golang:alpine as builder RUN apk add git WORKDIR /go/src/github.com/discourse/discourse-auth-proxy COPY internal ./internal/ -COPY main.go . +COPY *.go ./ RUN go get && go build FROM alpine:latest diff --git a/config.go b/config.go new file mode 100644 index 0000000..498e183 --- /dev/null +++ b/config.go @@ -0,0 +1,133 @@ +package main + +import ( + "fmt" + "net/url" + "os" + "time" + + "github.com/namsral/flag" + "github.com/pborman/uuid" +) + +type Config struct { + OriginURL *url.URL + ProxyURL *url.URL + ProxyURLString string // memoised - derives from ProxyURL + ListenAddr string + SSOURL *url.URL + SSOURLString string // memoised - derives from SSOURL + SSOSecret string + CookieSecret string + AllowAll bool + BasicAuth string + Whitelist string + UsernameHeader string + GroupsHeader string + Timeout time.Duration +} + +func ParseConfig() (*Config, error) { + missing := func(name string) error { + return fmt.Errorf("missing mandatory flag: %s", name) + } + + rc := parseRawConfig() + + if *rc.OriginURL == "" { + return nil, missing("origin-url") + } + if *rc.ProxyURL == "" { + return nil, missing("proxy-url") + } + if *rc.SSOURL == "" { + return nil, missing("sso-url") + } + if *rc.SSOSecret == "" { + return nil, missing("sso-secret") + } + + c := &Config{} + { + u, err := url.Parse(*rc.OriginURL) + if err != nil { + return nil, fmt.Errorf("invalid origin URL: %s", rc.OriginURL) + } + c.OriginURL = u + } + { + u, err := url.Parse(*rc.ProxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %s", rc.ProxyURL) + } + c.ProxyURL = u + c.ProxyURLString = u.String() + } + { + u, err := url.Parse(*rc.SSOURL) + if err != nil { + return nil, fmt.Errorf("invalid SSO URL - should point at Discourse site with enable sso: %s", rc.SSOURL) + } + c.SSOURL = u + c.SSOURLString = u.String() + } + + if *rc.ListenURL == "" { + c.ListenAddr = c.ProxyURL.Host + } else { + c.ListenAddr = *rc.ListenURL + } + + c.SSOSecret = *rc.SSOSecret + c.AllowAll = *rc.AllowAll + c.BasicAuth = *rc.BasicAuth + c.Whitelist = *rc.Whitelist + c.UsernameHeader = *rc.UsernameHeader + c.GroupsHeader = *rc.GroupsHeader + c.Timeout = time.Duration(*rc.Timeout) * time.Second + + c.CookieSecret = uuid.New() + + return c, nil +} + +type rawConfig struct { + OriginURL *string + ProxyURL *string + ListenURL *string // Not actually a URL. This is a TCP socket address, i.e.: 'host:port'. + SSOURL *string + SSOSecret *string + AllowAll *bool + BasicAuth *string + Whitelist *string + UsernameHeader *string + GroupsHeader *string + Timeout *int +} + +func parseRawConfig() *rawConfig { + c := &rawConfig{ + OriginURL: flag.String("origin-url", "", "origin to proxy eg: http://localhost:2002"), + ProxyURL: flag.String("proxy-url", "", "outer url of this host eg: http://secrets.example.com"), + ListenURL: flag.String("listen-url", "", "url to listen on eg: localhost:2001. leave blank to set equal to proxy-url"), + SSOURL: flag.String("sso-url", "", "SSO endpoint eg: http://discourse.forum.com"), + SSOSecret: flag.String("sso-secret", "", "SSO secret for origin"), + AllowAll: flag.Bool("allow-all", false, "allow all discourse users (default: admin users only)"), + BasicAuth: flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly"), + Whitelist: flag.String("whitelist", "", "Path which does not require authorization"), + UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"), + GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"), + Timeout: flag.Int("timeout", 10, "Read/write timeout"), + } + flag.Parse() + return c +} + +func usage(err error) { + flag.Usage() + if err != nil { + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, err) + } + os.Exit(2) +} diff --git a/main.go b/main.go index 4e1b19f..b5826c6 100644 --- a/main.go +++ b/main.go @@ -11,102 +11,43 @@ import ( "net/http" "net/http/httputil" "net/url" - "os" "strings" "sync" "time" "github.com/golang/groupcache/lru" - "github.com/namsral/flag" "github.com/pborman/uuid" "github.com/discourse/discourse-auth-proxy/internal/httpproxy" ) -var nonceCache = lru.New(20) -var nonceMutex = &sync.Mutex{} +var ( + config *Config -type Config struct { - ListenUri *string - ProxyUri *string - OriginUri *string - SsoSecret *string - SsoUri *string - BasicAuth *string - UsernameHeader *string - GroupsHeader *string - Timeout *int - CookieSecret string - AllowAll *bool - Whitelist *string -} + nonceCache = lru.New(20) + nonceMutex = &sync.Mutex{} +) func main() { - config := new(Config) - - config.ListenUri = flag.String("listen-url", "", "uri to listen on eg: localhost:2001. leave blank to set equal to proxy-url") - config.ProxyUri = flag.String("proxy-url", "", "outer url of this host eg: http://secrets.example.com") - config.OriginUri = flag.String("origin-url", "", "origin to proxy eg: http://localhost:2002") - config.SsoSecret = flag.String("sso-secret", "", "SSO secret for origin") - config.SsoUri = flag.String("sso-url", "", "SSO endpoint eg: http://discourse.forum.com") - config.AllowAll = flag.Bool("allow-all", false, "allow all discourse users (default: admin users only)") - config.BasicAuth = flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly") - config.UsernameHeader = flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into") - config.GroupsHeader = flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into") - config.Timeout = flag.Int("timeout", 10, "Read/write timeout") - config.Whitelist = flag.String("whitelist", "", "Path which does not require authorization") - - flag.Parse() - - originUrl, err := url.Parse(*config.OriginUri) - - if err != nil { - flag.Usage() - log.Fatal("invalid origin url") + { + var err error + config, err = ParseConfig() + if err != nil { + usage(err) + } } - _, err = url.Parse(*config.SsoUri) - - if err != nil { - flag.Usage() - log.Fatal("invalid sso url, should point at Discourse site with enable sso") - } - - proxyUrl, err2 := url.Parse(*config.ProxyUri) - - if err2 != nil { - flag.Usage() - log.Fatal("invalid proxy uri") - } - - if *config.ListenUri == "" { - log.Println("Defaulting to listening on the proxy url") - *config.ListenUri = proxyUrl.Host - } - - if *config.ProxyUri == "" || *config.OriginUri == "" || *config.SsoSecret == "" || *config.SsoUri == "" || *config.ListenUri == "" { - flag.Usage() - os.Exit(1) - return - } - - if *config.BasicAuth != "" { - log.Println("Enabling basic auth support") - } - - config.CookieSecret = uuid.New() - - dnssrv := httpproxy.NewDNSSRVBackend(originUrl) + dnssrv := httpproxy.NewDNSSRVBackend(config.OriginURL) go dnssrv.Lookup(context.Background(), 50*time.Second, 10*time.Second, 5*time.Minute) proxy := &httputil.ReverseProxy{Director: dnssrv.Director} handler := authProxyHandler(proxy, config) server := &http.Server{ - Addr: *config.ListenUri, + Addr: config.ListenAddr, Handler: handler, - ReadTimeout: time.Duration(*config.Timeout) * time.Second, - WriteTimeout: time.Duration(*config.Timeout) * time.Second, + ReadTimeout: config.Timeout, + WriteTimeout: config.Timeout, MaxHeaderBytes: 1 << 20, } @@ -115,19 +56,18 @@ func main() { func authProxyHandler(handler http.Handler, config *Config) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if checkWhitelist(handler, r, w, config) { + if checkWhitelist(handler, r, w) { return } - if checkAuthorizationHeader(handler, r, w, config) { + if checkAuthorizationHeader(handler, r, w) { return } - redirectIfNoCookie(handler, r, w, config) + redirectIfNoCookie(handler, r, w) }) } -func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) bool { - if *config.BasicAuth == "" { - // Can't auth if we don't have anything to auth against +func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.ResponseWriter) bool { + if config.BasicAuth == "" { return false } @@ -140,13 +80,13 @@ func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.Resp log.Println("Received request with basic auth creds") b_creds, _ := base64.StdEncoding.DecodeString(auth_header[6:]) creds := string(b_creds) - if creds == *config.BasicAuth { + if creds == config.BasicAuth { colon_idx := strings.Index(creds, ":") if colon_idx == -1 { return false } username := creds[0:colon_idx] - r.Header.Set(*config.UsernameHeader, username) + r.Header.Set(config.UsernameHeader, username) r.Header.Del("Authorization") log.Printf("Accepted basic auth creds for %s\n", username) handler.ServeHTTP(w, r) @@ -159,8 +99,8 @@ func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.Resp return false } -func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) bool { - if r.URL.Path == *(config.Whitelist) { +func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter) bool { + if r.URL.Path == config.Whitelist { handler.ServeHTTP(w, r) return true } @@ -168,7 +108,7 @@ func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter return false } -func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWriter, config *Config) { +func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWriter) { cookie, err := r.Cookie("__discourse_proxy") var username, groups string @@ -177,8 +117,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr } if err == nil { - r.Header.Set(*config.UsernameHeader, username) - r.Header.Set(*config.GroupsHeader, groups) + r.Header.Set(config.UsernameHeader, username) + r.Header.Set(config.GroupsHeader, groups) handler.ServeHTTP(w, r) return } @@ -188,7 +128,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr sig := query.Get("sig") if len(sso) == 0 { - url := *config.SsoUri + "/session/sso_provider?" + sso_payload(*config.SsoSecret, *config.ProxyUri, r.URL.String()) + url := config.SSOURLString + "/session/sso_provider?" + sso_payload(config.SSOSecret, config.ProxyURLString, r.URL.String()) http.Redirect(w, r, url, 302) } else { decoded, _ := base64.StdEncoding.DecodeString(sso) @@ -202,8 +142,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr groupsArray := strings.Split(groups[0], ",") - if len(nonce) > 0 && len(admin) > 0 && len(username) > 0 && (admin[0] == "true" || *config.AllowAll) { - returnUrl, err := getReturnUrl(*config.SsoSecret, sso, sig, nonce[0]) + if len(nonce) > 0 && len(admin) > 0 && len(username) > 0 && (admin[0] == "true" || config.AllowAll) { + returnUrl, err := getReturnUrl(config.SSOSecret, sso, sig, nonce[0]) if err != nil { fmt.Fprintf(w, "Invalid request")