Merge 319fb64c9a
into 8fe4a0b7c0
This commit is contained in:
commit
8622c7238f
|
@ -26,6 +26,7 @@ type Config struct {
|
|||
WhitelistPrefix string
|
||||
UsernameHeader string
|
||||
GroupsHeader string
|
||||
UserIDHeader string
|
||||
Timeout time.Duration
|
||||
SRVAbandonAfter time.Duration
|
||||
LogRequests bool
|
||||
|
@ -93,6 +94,7 @@ func ParseConfig() (*Config, error) {
|
|||
c.WhitelistPrefix = *rc.WhitelistPrefix
|
||||
c.UsernameHeader = *rc.UsernameHeader
|
||||
c.GroupsHeader = *rc.GroupsHeader
|
||||
c.UserIDHeader = *rc.UserIDHeader
|
||||
c.Timeout = time.Duration(*rc.Timeout) * time.Second
|
||||
if *rc.SRVAbandonAfter < 1 {
|
||||
c.SRVAbandonAfter = 0
|
||||
|
@ -119,6 +121,7 @@ type rawConfig struct {
|
|||
WhitelistPrefix *string
|
||||
UsernameHeader *string
|
||||
GroupsHeader *string
|
||||
UserIDHeader *string
|
||||
Timeout *int
|
||||
SRVAbandonAfter *int
|
||||
LogRequests *bool
|
||||
|
@ -138,6 +141,7 @@ func parseRawConfig() *rawConfig {
|
|||
WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"),
|
||||
UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"),
|
||||
GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"),
|
||||
UserIDHeader: flag.String("user-id-header", "", "Request header to pass authenticated user id into (default: disabled)"),
|
||||
Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"),
|
||||
SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."),
|
||||
LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"),
|
||||
|
|
28
main.go
28
main.go
|
@ -161,15 +161,20 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
|||
}
|
||||
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
var username, groups string
|
||||
var username, groups, user_id string
|
||||
|
||||
if err == nil && cookie != nil {
|
||||
username, groups, err = parseCookie(cookie.Value, config.CookieSecret)
|
||||
username, groups, user_id, err = parseCookie(cookie.Value, config.CookieSecret)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
r.Header.Set(config.UsernameHeader, username)
|
||||
r.Header.Set(config.GroupsHeader, groups)
|
||||
|
||||
if config.UserIDHeader != "" {
|
||||
r.Header.Set(config.UserIDHeader, user_id)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
@ -199,6 +204,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
|||
admin = parsedQuery.Get("admin")
|
||||
nonce = parsedQuery.Get("nonce")
|
||||
groups = NewStringSet(parsedQuery.Get("groups"))
|
||||
user_id = parsedQuery.Get("external_id")
|
||||
)
|
||||
|
||||
if len(nonce) == 0 {
|
||||
|
@ -231,7 +237,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
|||
// we have a valid auth
|
||||
expiration := time.Now().Add(reauthorizeInterval)
|
||||
|
||||
cookieData := strings.Join([]string{username, strings.Join(groups, "|")}, ",")
|
||||
cookieData := strings.Join([]string{username, strings.Join(groups, "|"), user_id}, ",")
|
||||
cookieData = url.QueryEscape(cookieData)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: signCookie(cookieData, config.CookieSecret),
|
||||
|
@ -270,10 +277,11 @@ func signCookie(data, secret string) string {
|
|||
return data + "," + computeHMAC(data, secret)
|
||||
}
|
||||
|
||||
func parseCookie(data, secret string) (username string, groups string, err error) {
|
||||
func parseCookie(data, secret string) (username string, groups string, user_id string, err error) {
|
||||
err = nil
|
||||
username = ""
|
||||
groups = ""
|
||||
user_id = ""
|
||||
|
||||
split := strings.Split(data, ",")
|
||||
|
||||
|
@ -291,8 +299,16 @@ func parseCookie(data, secret string) (username string, groups string, err error
|
|||
err = fmt.Errorf("Expecting signature to match")
|
||||
return
|
||||
} else {
|
||||
username = strings.Split(parsed, ",")[0]
|
||||
groups = strings.Split(parsed, ",")[1]
|
||||
parsed, err = url.QueryUnescape(parsed)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
splitted := strings.Split(parsed, ",")
|
||||
username = splitted[0]
|
||||
groups = splitted[1]
|
||||
if len(splitted) >= 3 {
|
||||
user_id = splitted[2]
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
|
|
29
main_test.go
29
main_test.go
|
@ -182,25 +182,46 @@ func TestAllowedAnon(t *testing.T) {
|
|||
|
||||
func TestInvalidSecretFails(t *testing.T) {
|
||||
signed := signCookie("user,group", "secretfoo")
|
||||
_, _, parseError := parseCookie(signed, "secretbar")
|
||||
_, _, _, parseError := parseCookie(signed, "secretbar")
|
||||
|
||||
assert.Error(t, parseError)
|
||||
}
|
||||
|
||||
func TestInvalidPayloadFails(t *testing.T) {
|
||||
signed := signCookie("user,group", "secretfoo") + "garbage"
|
||||
_, _, parseError := parseCookie(signed, "secretfoo")
|
||||
_, _, _, parseError := parseCookie(signed, "secretfoo")
|
||||
|
||||
assert.Error(t, parseError)
|
||||
}
|
||||
|
||||
func TestValidPayload(t *testing.T) {
|
||||
signed := signCookie("user,group", "secretfoo")
|
||||
username, group, parseError := parseCookie(signed, "secretfoo")
|
||||
signed := signCookie("user,group,1", "secretfoo")
|
||||
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
|
||||
|
||||
assert.NoError(t, parseError)
|
||||
assert.Equal(t, username, "user")
|
||||
assert.Equal(t, group, "group")
|
||||
assert.Equal(t, user_id, "1")
|
||||
}
|
||||
|
||||
func TestValidPayloadWithoutUserID(t *testing.T) {
|
||||
signed := signCookie("user,group", "secretfoo")
|
||||
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
|
||||
|
||||
assert.NoError(t, parseError)
|
||||
assert.Equal(t, username, "user")
|
||||
assert.Equal(t, group, "group")
|
||||
assert.Equal(t, user_id, "")
|
||||
}
|
||||
|
||||
func TestValidPayloadWithUnicode(t *testing.T) {
|
||||
signed := signCookie("用户名,群组,2", "secretfoo")
|
||||
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
|
||||
|
||||
assert.NoError(t, parseError)
|
||||
assert.Equal(t, username, "用户名")
|
||||
assert.Equal(t, group, "群组")
|
||||
assert.Equal(t, user_id, "2")
|
||||
}
|
||||
|
||||
func TestNotWhitelistedPath(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue