This commit is contained in:
Misaka 0x4e21 2025-03-08 06:43:39 +08:00 committed by GitHub
commit 8622c7238f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 10 deletions

View File

@ -26,6 +26,7 @@ type Config struct {
WhitelistPrefix string WhitelistPrefix string
UsernameHeader string UsernameHeader string
GroupsHeader string GroupsHeader string
UserIDHeader string
Timeout time.Duration Timeout time.Duration
SRVAbandonAfter time.Duration SRVAbandonAfter time.Duration
LogRequests bool LogRequests bool
@ -93,6 +94,7 @@ func ParseConfig() (*Config, error) {
c.WhitelistPrefix = *rc.WhitelistPrefix c.WhitelistPrefix = *rc.WhitelistPrefix
c.UsernameHeader = *rc.UsernameHeader c.UsernameHeader = *rc.UsernameHeader
c.GroupsHeader = *rc.GroupsHeader c.GroupsHeader = *rc.GroupsHeader
c.UserIDHeader = *rc.UserIDHeader
c.Timeout = time.Duration(*rc.Timeout) * time.Second c.Timeout = time.Duration(*rc.Timeout) * time.Second
if *rc.SRVAbandonAfter < 1 { if *rc.SRVAbandonAfter < 1 {
c.SRVAbandonAfter = 0 c.SRVAbandonAfter = 0
@ -119,6 +121,7 @@ type rawConfig struct {
WhitelistPrefix *string WhitelistPrefix *string
UsernameHeader *string UsernameHeader *string
GroupsHeader *string GroupsHeader *string
UserIDHeader *string
Timeout *int Timeout *int
SRVAbandonAfter *int SRVAbandonAfter *int
LogRequests *bool LogRequests *bool
@ -138,6 +141,7 @@ func parseRawConfig() *rawConfig {
WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"), WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"),
UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"), UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"),
GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"), GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"),
UserIDHeader: flag.String("user-id-header", "", "Request header to pass authenticated user id into (default: disabled)"),
Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"), Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"),
SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."), SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."),
LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"), LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"),

28
main.go
View File

@ -161,15 +161,20 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
} }
cookie, err := r.Cookie(cookieName) cookie, err := r.Cookie(cookieName)
var username, groups string var username, groups, user_id string
if err == nil && cookie != nil { if err == nil && cookie != nil {
username, groups, err = parseCookie(cookie.Value, config.CookieSecret) username, groups, user_id, err = parseCookie(cookie.Value, config.CookieSecret)
} }
if err == nil { if err == nil {
r.Header.Set(config.UsernameHeader, username) r.Header.Set(config.UsernameHeader, username)
r.Header.Set(config.GroupsHeader, groups) r.Header.Set(config.GroupsHeader, groups)
if config.UserIDHeader != "" {
r.Header.Set(config.UserIDHeader, user_id)
}
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
return return
} }
@ -199,6 +204,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
admin = parsedQuery.Get("admin") admin = parsedQuery.Get("admin")
nonce = parsedQuery.Get("nonce") nonce = parsedQuery.Get("nonce")
groups = NewStringSet(parsedQuery.Get("groups")) groups = NewStringSet(parsedQuery.Get("groups"))
user_id = parsedQuery.Get("external_id")
) )
if len(nonce) == 0 { if len(nonce) == 0 {
@ -231,7 +237,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
// we have a valid auth // we have a valid auth
expiration := time.Now().Add(reauthorizeInterval) expiration := time.Now().Add(reauthorizeInterval)
cookieData := strings.Join([]string{username, strings.Join(groups, "|")}, ",") cookieData := strings.Join([]string{username, strings.Join(groups, "|"), user_id}, ",")
cookieData = url.QueryEscape(cookieData)
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: cookieName, Name: cookieName,
Value: signCookie(cookieData, config.CookieSecret), Value: signCookie(cookieData, config.CookieSecret),
@ -270,10 +277,11 @@ func signCookie(data, secret string) string {
return data + "," + computeHMAC(data, secret) return data + "," + computeHMAC(data, secret)
} }
func parseCookie(data, secret string) (username string, groups string, err error) { func parseCookie(data, secret string) (username string, groups string, user_id string, err error) {
err = nil err = nil
username = "" username = ""
groups = "" groups = ""
user_id = ""
split := strings.Split(data, ",") split := strings.Split(data, ",")
@ -291,8 +299,16 @@ func parseCookie(data, secret string) (username string, groups string, err error
err = fmt.Errorf("Expecting signature to match") err = fmt.Errorf("Expecting signature to match")
return return
} else { } else {
username = strings.Split(parsed, ",")[0] parsed, err = url.QueryUnescape(parsed)
groups = strings.Split(parsed, ",")[1] if err != nil {
return
}
splitted := strings.Split(parsed, ",")
username = splitted[0]
groups = splitted[1]
if len(splitted) >= 3 {
user_id = splitted[2]
}
} }
return return

View File

@ -182,25 +182,46 @@ func TestAllowedAnon(t *testing.T) {
func TestInvalidSecretFails(t *testing.T) { func TestInvalidSecretFails(t *testing.T) {
signed := signCookie("user,group", "secretfoo") signed := signCookie("user,group", "secretfoo")
_, _, parseError := parseCookie(signed, "secretbar") _, _, _, parseError := parseCookie(signed, "secretbar")
assert.Error(t, parseError) assert.Error(t, parseError)
} }
func TestInvalidPayloadFails(t *testing.T) { func TestInvalidPayloadFails(t *testing.T) {
signed := signCookie("user,group", "secretfoo") + "garbage" signed := signCookie("user,group", "secretfoo") + "garbage"
_, _, parseError := parseCookie(signed, "secretfoo") _, _, _, parseError := parseCookie(signed, "secretfoo")
assert.Error(t, parseError) assert.Error(t, parseError)
} }
func TestValidPayload(t *testing.T) { func TestValidPayload(t *testing.T) {
signed := signCookie("user,group", "secretfoo") signed := signCookie("user,group,1", "secretfoo")
username, group, parseError := parseCookie(signed, "secretfoo") username, group, user_id, parseError := parseCookie(signed, "secretfoo")
assert.NoError(t, parseError) assert.NoError(t, parseError)
assert.Equal(t, username, "user") assert.Equal(t, username, "user")
assert.Equal(t, group, "group") assert.Equal(t, group, "group")
assert.Equal(t, user_id, "1")
}
func TestValidPayloadWithoutUserID(t *testing.T) {
signed := signCookie("user,group", "secretfoo")
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
assert.NoError(t, parseError)
assert.Equal(t, username, "user")
assert.Equal(t, group, "group")
assert.Equal(t, user_id, "")
}
func TestValidPayloadWithUnicode(t *testing.T) {
signed := signCookie("用户名,群组,2", "secretfoo")
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
assert.NoError(t, parseError)
assert.Equal(t, username, "用户名")
assert.Equal(t, group, "群组")
assert.Equal(t, user_id, "2")
} }
func TestNotWhitelistedPath(t *testing.T) { func TestNotWhitelistedPath(t *testing.T) {