This commit is contained in:
Misaka 0x4e21 2025-03-08 06:43:39 +08:00 committed by GitHub
commit 8622c7238f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 10 deletions

View File

@ -26,6 +26,7 @@ type Config struct {
WhitelistPrefix string
UsernameHeader string
GroupsHeader string
UserIDHeader string
Timeout time.Duration
SRVAbandonAfter time.Duration
LogRequests bool
@ -93,6 +94,7 @@ func ParseConfig() (*Config, error) {
c.WhitelistPrefix = *rc.WhitelistPrefix
c.UsernameHeader = *rc.UsernameHeader
c.GroupsHeader = *rc.GroupsHeader
c.UserIDHeader = *rc.UserIDHeader
c.Timeout = time.Duration(*rc.Timeout) * time.Second
if *rc.SRVAbandonAfter < 1 {
c.SRVAbandonAfter = 0
@ -119,6 +121,7 @@ type rawConfig struct {
WhitelistPrefix *string
UsernameHeader *string
GroupsHeader *string
UserIDHeader *string
Timeout *int
SRVAbandonAfter *int
LogRequests *bool
@ -138,6 +141,7 @@ func parseRawConfig() *rawConfig {
WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"),
UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"),
GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"),
UserIDHeader: flag.String("user-id-header", "", "Request header to pass authenticated user id into (default: disabled)"),
Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"),
SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."),
LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"),

28
main.go
View File

@ -161,15 +161,20 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
}
cookie, err := r.Cookie(cookieName)
var username, groups string
var username, groups, user_id string
if err == nil && cookie != nil {
username, groups, err = parseCookie(cookie.Value, config.CookieSecret)
username, groups, user_id, err = parseCookie(cookie.Value, config.CookieSecret)
}
if err == nil {
r.Header.Set(config.UsernameHeader, username)
r.Header.Set(config.GroupsHeader, groups)
if config.UserIDHeader != "" {
r.Header.Set(config.UserIDHeader, user_id)
}
handler.ServeHTTP(w, r)
return
}
@ -199,6 +204,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
admin = parsedQuery.Get("admin")
nonce = parsedQuery.Get("nonce")
groups = NewStringSet(parsedQuery.Get("groups"))
user_id = parsedQuery.Get("external_id")
)
if len(nonce) == 0 {
@ -231,7 +237,8 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
// we have a valid auth
expiration := time.Now().Add(reauthorizeInterval)
cookieData := strings.Join([]string{username, strings.Join(groups, "|")}, ",")
cookieData := strings.Join([]string{username, strings.Join(groups, "|"), user_id}, ",")
cookieData = url.QueryEscape(cookieData)
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: signCookie(cookieData, config.CookieSecret),
@ -270,10 +277,11 @@ func signCookie(data, secret string) string {
return data + "," + computeHMAC(data, secret)
}
func parseCookie(data, secret string) (username string, groups string, err error) {
func parseCookie(data, secret string) (username string, groups string, user_id string, err error) {
err = nil
username = ""
groups = ""
user_id = ""
split := strings.Split(data, ",")
@ -291,8 +299,16 @@ func parseCookie(data, secret string) (username string, groups string, err error
err = fmt.Errorf("Expecting signature to match")
return
} else {
username = strings.Split(parsed, ",")[0]
groups = strings.Split(parsed, ",")[1]
parsed, err = url.QueryUnescape(parsed)
if err != nil {
return
}
splitted := strings.Split(parsed, ",")
username = splitted[0]
groups = splitted[1]
if len(splitted) >= 3 {
user_id = splitted[2]
}
}
return

View File

@ -182,25 +182,46 @@ func TestAllowedAnon(t *testing.T) {
func TestInvalidSecretFails(t *testing.T) {
signed := signCookie("user,group", "secretfoo")
_, _, parseError := parseCookie(signed, "secretbar")
_, _, _, parseError := parseCookie(signed, "secretbar")
assert.Error(t, parseError)
}
func TestInvalidPayloadFails(t *testing.T) {
signed := signCookie("user,group", "secretfoo") + "garbage"
_, _, parseError := parseCookie(signed, "secretfoo")
_, _, _, parseError := parseCookie(signed, "secretfoo")
assert.Error(t, parseError)
}
func TestValidPayload(t *testing.T) {
signed := signCookie("user,group", "secretfoo")
username, group, parseError := parseCookie(signed, "secretfoo")
signed := signCookie("user,group,1", "secretfoo")
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
assert.NoError(t, parseError)
assert.Equal(t, username, "user")
assert.Equal(t, group, "group")
assert.Equal(t, user_id, "1")
}
func TestValidPayloadWithoutUserID(t *testing.T) {
signed := signCookie("user,group", "secretfoo")
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
assert.NoError(t, parseError)
assert.Equal(t, username, "user")
assert.Equal(t, group, "group")
assert.Equal(t, user_id, "")
}
func TestValidPayloadWithUnicode(t *testing.T) {
signed := signCookie("用户名,群组,2", "secretfoo")
username, group, user_id, parseError := parseCookie(signed, "secretfoo")
assert.NoError(t, parseError)
assert.Equal(t, username, "用户名")
assert.Equal(t, group, "群组")
assert.Equal(t, user_id, "2")
}
func TestNotWhitelistedPath(t *testing.T) {