diff --git a/config.go b/config.go index dd7811e..cb2a096 100644 --- a/config.go +++ b/config.go @@ -26,6 +26,7 @@ type Config struct { WhitelistPrefix string UsernameHeader string GroupsHeader string + UserIDHeader string Timeout time.Duration SRVAbandonAfter time.Duration LogRequests bool @@ -93,6 +94,7 @@ func ParseConfig() (*Config, error) { c.WhitelistPrefix = *rc.WhitelistPrefix c.UsernameHeader = *rc.UsernameHeader c.GroupsHeader = *rc.GroupsHeader + c.UserIDHeader = *rc.UserIDHeader c.Timeout = time.Duration(*rc.Timeout) * time.Second if *rc.SRVAbandonAfter < 1 { c.SRVAbandonAfter = 0 @@ -119,6 +121,7 @@ type rawConfig struct { WhitelistPrefix *string UsernameHeader *string GroupsHeader *string + UserIDHeader *string Timeout *int SRVAbandonAfter *int LogRequests *bool @@ -138,6 +141,7 @@ func parseRawConfig() *rawConfig { WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"), UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"), GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"), + UserIDHeader: flag.String("user-id-header", "", "Request header to pass authenticated user id into (default: disabled)"), Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"), SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."), LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"), diff --git a/main.go b/main.go index d625d14..0680027 100644 --- a/main.go +++ b/main.go @@ -161,15 +161,20 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr } cookie, err := r.Cookie(cookieName) - var username, groups string + var username, groups, user_id string if err == nil && cookie != nil { - username, groups, err = parseCookie(cookie.Value, config.CookieSecret) + username, groups, user_id, err = parseCookie(cookie.Value, config.CookieSecret) } if err == nil { r.Header.Set(config.UsernameHeader, username) r.Header.Set(config.GroupsHeader, groups) + + if config.UserIDHeader != "" { + r.Header.Set(config.UserIDHeader, user_id) + } + handler.ServeHTTP(w, r) return } @@ -199,6 +204,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr admin = parsedQuery.Get("admin") nonce = parsedQuery.Get("nonce") groups = NewStringSet(parsedQuery.Get("groups")) + user_id = parsedQuery.Get("external_id") ) if len(nonce) == 0 { @@ -231,7 +237,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr // we have a valid auth expiration := time.Now().Add(reauthorizeInterval) - cookieData := strings.Join([]string{username, strings.Join(groups, "|")}, ",") + cookieData := strings.Join([]string{username, strings.Join(groups, "|"), user_id}, ",") http.SetCookie(w, &http.Cookie{ Name: cookieName, Value: signCookie(cookieData, config.CookieSecret), @@ -270,10 +276,11 @@ func signCookie(data, secret string) string { return data + "," + computeHMAC(data, secret) } -func parseCookie(data, secret string) (username string, groups string, err error) { +func parseCookie(data, secret string) (username string, groups string, user_id string, err error) { err = nil username = "" groups = "" + user_id = "" split := strings.Split(data, ",") @@ -291,8 +298,12 @@ func parseCookie(data, secret string) (username string, groups string, err error err = fmt.Errorf("Expecting signature to match") return } else { - username = strings.Split(parsed, ",")[0] - groups = strings.Split(parsed, ",")[1] + splitted := strings.Split(parsed, ",") + username = splitted[0] + groups = splitted[1] + if len(splitted) >= 3 { + user_id = splitted[2] + } } return diff --git a/main_test.go b/main_test.go index 86d0794..4e90122 100644 --- a/main_test.go +++ b/main_test.go @@ -182,25 +182,36 @@ func TestAllowedAnon(t *testing.T) { func TestInvalidSecretFails(t *testing.T) { signed := signCookie("user,group", "secretfoo") - _, _, parseError := parseCookie(signed, "secretbar") + _, _, _, parseError := parseCookie(signed, "secretbar") assert.Error(t, parseError) } func TestInvalidPayloadFails(t *testing.T) { signed := signCookie("user,group", "secretfoo") + "garbage" - _, _, parseError := parseCookie(signed, "secretfoo") + _, _, _, parseError := parseCookie(signed, "secretfoo") assert.Error(t, parseError) } func TestValidPayload(t *testing.T) { - signed := signCookie("user,group", "secretfoo") - username, group, parseError := parseCookie(signed, "secretfoo") + signed := signCookie("user,group,1", "secretfoo") + username, group, user_id, parseError := parseCookie(signed, "secretfoo") assert.NoError(t, parseError) assert.Equal(t, username, "user") assert.Equal(t, group, "group") + assert.Equal(t, user_id, "1") +} + +func TestValidPayloadWithoutUserID(t *testing.T) { + signed := signCookie("user,group", "secretfoo") + username, group, user_id, parseError := parseCookie(signed, "secretfoo") + + assert.NoError(t, parseError) + assert.Equal(t, username, "user") + assert.Equal(t, group, "group") + assert.Equal(t, user_id, "") } func TestNotWhitelistedPath(t *testing.T) {