diff --git a/config.go b/config.go index eac066e..dd7811e 100644 --- a/config.go +++ b/config.go @@ -23,6 +23,7 @@ type Config struct { AllowGroups StringSet BasicAuth string Whitelist string + WhitelistPrefix string UsernameHeader string GroupsHeader string Timeout time.Duration @@ -89,6 +90,7 @@ func ParseConfig() (*Config, error) { c.AllowGroups = NewStringSet(*rc.AllowGroups) c.BasicAuth = *rc.BasicAuth c.Whitelist = *rc.Whitelist + c.WhitelistPrefix = *rc.WhitelistPrefix c.UsernameHeader = *rc.UsernameHeader c.GroupsHeader = *rc.GroupsHeader c.Timeout = time.Duration(*rc.Timeout) * time.Second @@ -114,6 +116,7 @@ type rawConfig struct { AllowGroups *string BasicAuth *string Whitelist *string + WhitelistPrefix *string UsernameHeader *string GroupsHeader *string Timeout *int @@ -132,6 +135,7 @@ func parseRawConfig() *rawConfig { AllowGroups: flag.String("allow-groups", "", "Allow users belonging to the specified groups, comma delimited (default: no groups are allowed)"), BasicAuth: flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly"), Whitelist: flag.String("whitelist", "", "Path which does not require authorization"), + WhitelistPrefix: flag.String("whitelist-prefix", "", "Prefix for paths which do not require authorization"), UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"), GroupsHeader: flag.String("groups-header", "Discourse-User-Groups", "Request header to pass authenticated groups into"), Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"), diff --git a/main.go b/main.go index 1b1dc28..d625d14 100644 --- a/main.go +++ b/main.go @@ -128,12 +128,22 @@ func checkAuthorizationHeader(handler http.Handler, r *http.Request, w http.Resp return false } -func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter) bool { - if config.Whitelist == "" { +func allowedByWhiteList(c *Config, p string) bool { + if c.Whitelist == "" && c.WhitelistPrefix == "" { return false } - if r.URL.Path == config.Whitelist { + prefixAllowed := len(c.WhitelistPrefix) > 0 && strings.HasPrefix(p, c.WhitelistPrefix) + + if p == c.Whitelist || prefixAllowed { + return true + } + + return false +} + +func checkWhitelist(handler http.Handler, r *http.Request, w http.ResponseWriter) bool { + if allowedByWhiteList(config, r.URL.Path) { handler.ServeHTTP(w, r) return true } diff --git a/main_test.go b/main_test.go index 32d3ebd..86d0794 100644 --- a/main_test.go +++ b/main_test.go @@ -41,6 +41,7 @@ func NewTestConfig() Config { AllowGroups: NewStringSet(""), BasicAuth: "", Whitelist: "", + WhitelistPrefix: "", UsernameHeader: "username-header", GroupsHeader: "groups-header", Timeout: 10, @@ -201,3 +202,27 @@ func TestValidPayload(t *testing.T) { assert.Equal(t, username, "user") assert.Equal(t, group, "group") } + +func TestNotWhitelistedPath(t *testing.T) { + c := NewTestConfig() + c.Whitelist = "" + res := allowedByWhiteList(&c, "/some_path") + + assert.Equal(t, false, res) +} + +func TestWhitelistedPath(t *testing.T) { + c := NewTestConfig() + c.Whitelist = "/some_path" + res := allowedByWhiteList(&c, "/some_path") + + assert.Equal(t, true, res) +} + +func TestWhitelistedPrefixPath(t *testing.T) { + c := NewTestConfig() + c.WhitelistPrefix = "/prefix/" + res := allowedByWhiteList(&c, "/prefix/some_path") + + assert.Equal(t, true, res) +}