Merge 319fb64c9a
into 8fe4a0b7c0
This commit is contained in:
commit
8622c7238f
|
@ -26,6 +26,7 @@ type Config struct {
|
||||||
WhitelistPrefix string
|
WhitelistPrefix string
|
||||||
UsernameHeader string
|
UsernameHeader string
|
||||||
GroupsHeader string
|
GroupsHeader string
|
||||||
|
UserIDHeader string
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
SRVAbandonAfter time.Duration
|
SRVAbandonAfter time.Duration
|
||||||
LogRequests bool
|
LogRequests bool
|
||||||
|
@ -93,6 +94,7 @@ func ParseConfig() (*Config, error) {
|
||||||
c.WhitelistPrefix = *rc.WhitelistPrefix
|
c.WhitelistPrefix = *rc.WhitelistPrefix
|
||||||
c.UsernameHeader = *rc.UsernameHeader
|
c.UsernameHeader = *rc.UsernameHeader
|
||||||
c.GroupsHeader = *rc.GroupsHeader
|
c.GroupsHeader = *rc.GroupsHeader
|
||||||
|
c.UserIDHeader = *rc.UserIDHeader
|
||||||
c.Timeout = time.Duration(*rc.Timeout) * time.Second
|
c.Timeout = time.Duration(*rc.Timeout) * time.Second
|
||||||
if *rc.SRVAbandonAfter < 1 {
|
if *rc.SRVAbandonAfter < 1 {
|
||||||
c.SRVAbandonAfter = 0
|
c.SRVAbandonAfter = 0
|
||||||
|
@ -119,6 +121,7 @@ type rawConfig struct {
|
||||||
WhitelistPrefix *string
|
WhitelistPrefix *string
|
||||||
UsernameHeader *string
|
UsernameHeader *string
|
||||||
GroupsHeader *string
|
GroupsHeader *string
|
||||||
|
UserIDHeader *string
|
||||||
Timeout *int
|
Timeout *int
|
||||||
SRVAbandonAfter *int
|
SRVAbandonAfter *int
|
||||||
LogRequests *bool
|
LogRequests *bool
|
||||||
|
@ -138,6 +141,7 @@ func parseRawConfig() *rawConfig {
|
||||||
WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"),
|
WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"),
|
||||||
UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"),
|
UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"),
|
||||||
GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"),
|
GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"),
|
||||||
|
UserIDHeader: flag.String("user-id-header", "", "Request header to pass authenticated user id into (default: disabled)"),
|
||||||
Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"),
|
Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"),
|
||||||
SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."),
|
SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."),
|
||||||
LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"),
|
LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"),
|
||||||
|
|
28
main.go
28
main.go
|
@ -161,15 +161,20 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := r.Cookie(cookieName)
|
cookie, err := r.Cookie(cookieName)
|
||||||
var username, groups string
|
var username, groups, user_id string
|
||||||
|
|
||||||
if err == nil && cookie != nil {
|
if err == nil && cookie != nil {
|
||||||
username, groups, err = parseCookie(cookie.Value, config.CookieSecret)
|
username, groups, user_id, err = parseCookie(cookie.Value, config.CookieSecret)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
r.Header.Set(config.UsernameHeader, username)
|
r.Header.Set(config.UsernameHeader, username)
|
||||||
r.Header.Set(config.GroupsHeader, groups)
|
r.Header.Set(config.GroupsHeader, groups)
|
||||||
|
|
||||||
|
if config.UserIDHeader != "" {
|
||||||
|
r.Header.Set(config.UserIDHeader, user_id)
|
||||||
|
}
|
||||||
|
|
||||||
handler.ServeHTTP(w, r)
|
handler.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -199,6 +204,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
||||||
admin = parsedQuery.Get("admin")
|
admin = parsedQuery.Get("admin")
|
||||||
nonce = parsedQuery.Get("nonce")
|
nonce = parsedQuery.Get("nonce")
|
||||||
groups = NewStringSet(parsedQuery.Get("groups"))
|
groups = NewStringSet(parsedQuery.Get("groups"))
|
||||||
|
user_id = parsedQuery.Get("external_id")
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(nonce) == 0 {
|
if len(nonce) == 0 {
|
||||||
|
@ -231,7 +237,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
||||||
// we have a valid auth
|
// we have a valid auth
|
||||||
expiration := time.Now().Add(reauthorizeInterval)
|
expiration := time.Now().Add(reauthorizeInterval)
|
||||||
|
|
||||||
cookieData := strings.Join([]string{username, strings.Join(groups, "|")}, ",")
|
cookieData := strings.Join([]string{username, strings.Join(groups, "|"), user_id}, ",")
|
||||||
|
cookieData = url.QueryEscape(cookieData)
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: cookieName,
|
Name: cookieName,
|
||||||
Value: signCookie(cookieData, config.CookieSecret),
|
Value: signCookie(cookieData, config.CookieSecret),
|
||||||
|
@ -270,10 +277,11 @@ func signCookie(data, secret string) string {
|
||||||
return data + "," + computeHMAC(data, secret)
|
return data + "," + computeHMAC(data, secret)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseCookie(data, secret string) (username string, groups string, err error) {
|
func parseCookie(data, secret string) (username string, groups string, user_id string, err error) {
|
||||||
err = nil
|
err = nil
|
||||||
username = ""
|
username = ""
|
||||||
groups = ""
|
groups = ""
|
||||||
|
user_id = ""
|
||||||
|
|
||||||
split := strings.Split(data, ",")
|
split := strings.Split(data, ",")
|
||||||
|
|
||||||
|
@ -291,8 +299,16 @@ func parseCookie(data, secret string) (username string, groups string, err error
|
||||||
err = fmt.Errorf("Expecting signature to match")
|
err = fmt.Errorf("Expecting signature to match")
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
username = strings.Split(parsed, ",")[0]
|
parsed, err = url.QueryUnescape(parsed)
|
||||||
groups = strings.Split(parsed, ",")[1]
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
splitted := strings.Split(parsed, ",")
|
||||||
|
username = splitted[0]
|
||||||
|
groups = splitted[1]
|
||||||
|
if len(splitted) >= 3 {
|
||||||
|
user_id = splitted[2]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
29
main_test.go
29
main_test.go
|
@ -182,25 +182,46 @@ func TestAllowedAnon(t *testing.T) {
|
||||||
|
|
||||||
func TestInvalidSecretFails(t *testing.T) {
|
func TestInvalidSecretFails(t *testing.T) {
|
||||||
signed := signCookie("user,group", "secretfoo")
|
signed := signCookie("user,group", "secretfoo")
|
||||||
_, _, parseError := parseCookie(signed, "secretbar")
|
_, _, _, parseError := parseCookie(signed, "secretbar")
|
||||||
|
|
||||||
assert.Error(t, parseError)
|
assert.Error(t, parseError)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidPayloadFails(t *testing.T) {
|
func TestInvalidPayloadFails(t *testing.T) {
|
||||||
signed := signCookie("user,group", "secretfoo") + "garbage"
|
signed := signCookie("user,group", "secretfoo") + "garbage"
|
||||||
_, _, parseError := parseCookie(signed, "secretfoo")
|
_, _, _, parseError := parseCookie(signed, "secretfoo")
|
||||||
|
|
||||||
assert.Error(t, parseError)
|
assert.Error(t, parseError)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidPayload(t *testing.T) {
|
func TestValidPayload(t *testing.T) {
|
||||||
signed := signCookie("user,group", "secretfoo")
|
signed := signCookie("user,group,1", "secretfoo")
|
||||||
username, group, parseError := parseCookie(signed, "secretfoo")
|
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
|
||||||
|
|
||||||
assert.NoError(t, parseError)
|
assert.NoError(t, parseError)
|
||||||
assert.Equal(t, username, "user")
|
assert.Equal(t, username, "user")
|
||||||
assert.Equal(t, group, "group")
|
assert.Equal(t, group, "group")
|
||||||
|
assert.Equal(t, user_id, "1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidPayloadWithoutUserID(t *testing.T) {
|
||||||
|
signed := signCookie("user,group", "secretfoo")
|
||||||
|
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
|
||||||
|
|
||||||
|
assert.NoError(t, parseError)
|
||||||
|
assert.Equal(t, username, "user")
|
||||||
|
assert.Equal(t, group, "group")
|
||||||
|
assert.Equal(t, user_id, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidPayloadWithUnicode(t *testing.T) {
|
||||||
|
signed := signCookie("用户名,群组,2", "secretfoo")
|
||||||
|
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
|
||||||
|
|
||||||
|
assert.NoError(t, parseError)
|
||||||
|
assert.Equal(t, username, "用户名")
|
||||||
|
assert.Equal(t, group, "群组")
|
||||||
|
assert.Equal(t, user_id, "2")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNotWhitelistedPath(t *testing.T) {
|
func TestNotWhitelistedPath(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue