From f2123554cfe0cd43105a80131a09ff99fd07dfcf Mon Sep 17 00:00:00 2001 From: Sam Saffron Date: Tue, 15 Dec 2020 16:18:24 +1100 Subject: [PATCH] FEATURE: add support for allowGroups setting This new setting allows you to add a list of comma delimited groups to the allow list. Previous to this change auth proxy was able to authenticate admins OR all users on the site, without any fidelity of allowing specific groups. This also refactors the tests somewhat and adds a bunch of integration tests to ensure the new setting is properly respected. Co-authored-by: Saj Goonatilleke --- config.go | 4 + go.mod | 5 +- go.sum | 78 +++-------------- main.go | 23 +++-- main_test.go | 229 ++++++++++++++++++++++++++++++++++++++++---------- string_set.go | 32 +++++++ 6 files changed, 248 insertions(+), 123 deletions(-) create mode 100644 string_set.go diff --git a/config.go b/config.go index b703acb..eac066e 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,7 @@ type Config struct { SSOSecret string CookieSecret string AllowAll bool + AllowGroups StringSet BasicAuth string Whitelist string UsernameHeader string @@ -85,6 +86,7 @@ func ParseConfig() (*Config, error) { c.SSOSecret = *rc.SSOSecret c.AllowAll = *rc.AllowAll + c.AllowGroups = NewStringSet(*rc.AllowGroups) c.BasicAuth = *rc.BasicAuth c.Whitelist = *rc.Whitelist c.UsernameHeader = *rc.UsernameHeader @@ -109,6 +111,7 @@ type rawConfig struct { SSOURL *string SSOSecret *string AllowAll *bool + AllowGroups *string BasicAuth *string Whitelist *string UsernameHeader *string @@ -126,6 +129,7 @@ func parseRawConfig() *rawConfig { SSOURL: flag.String("sso-url", "", "SSO endpoint. e.g.: http://discourse.forum.com"), SSOSecret: flag.String("sso-secret", "", "SSO secret for origin"), AllowAll: flag.Bool("allow-all", false, "Allow all discourse users (default: admin users only)"), + AllowGroups: flag.String("allow-groups", "", "Allow users belonging to the specified groups, comma delimited (default: no groups are allowed)"), BasicAuth: flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly"), Whitelist: flag.String("whitelist", "", "Path which does not require authorization"), UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"), diff --git a/go.mod b/go.mod index 67863b9..147fafa 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,7 @@ go 1.15 require ( github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/namsral/flag v1.7.4-pre - github.com/onsi/ginkgo v1.14.2 - github.com/onsi/gomega v1.10.3 github.com/pborman/uuid v1.2.1 - golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e + github.com/stretchr/testify v1.6.1 + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index d2fc4d8..f2a0f15 100644 --- a/go.sum +++ b/go.sum @@ -1,76 +1,20 @@ -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/namsral/flag v1.7.4-pre h1:b2ScHhoCUkbsq0d2C15Mv+VU8bl8hAXV8arnWiOHNZs= github.com/namsral/flag v1.7.4-pre/go.mod h1:OXldTctbM6SWH1K899kPZcf65KxJiD7MsceFUpB5yDo= -github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= -github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA= -github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 h1:wBouT66WTYFXdxfVdz9sVWARVd/2vfGcmI45D2gj45M= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= -golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 3882ef6..1b1dc28 100644 --- a/main.go +++ b/main.go @@ -52,7 +52,7 @@ func main() { go dnssrv.Lookup(context.Background(), 50*time.Second, 10*time.Second, config.SRVAbandonAfter) proxy := &httputil.ReverseProxy{Director: dnssrv.Director} - handler := authProxyHandler(proxy, config) + handler := authProxyHandler(proxy) if config.LogRequests { handler = logHandler(handler) @@ -85,7 +85,7 @@ func main() { log.Fatal(server.Serve(listener)) } -func authProxyHandler(handler http.Handler, config *Config) http.Handler { +func authProxyHandler(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if checkWhitelist(handler, r, w) { return @@ -185,10 +185,10 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr } var ( - username = parsedQuery.Get("username") - admin = parsedQuery.Get("admin") - nonce = parsedQuery.Get("nonce") - groupsArray = strings.Split(parsedQuery.Get("groups"), ",") + username = parsedQuery.Get("username") + admin = parsedQuery.Get("admin") + nonce = parsedQuery.Get("nonce") + groups = NewStringSet(parsedQuery.Get("groups")) ) if len(nonce) == 0 { @@ -204,8 +204,12 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr return } if !(config.AllowAll || admin == "true") { - writeHttpError(http.StatusForbidden) - return + allowed := config.AllowGroups.ContainsAny(groups) + + if !allowed { + writeHttpError(http.StatusForbidden) + return + } } returnUrl, err := getReturnUrl(config.SSOSecret, sso, sig, nonce) @@ -217,7 +221,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr // we have a valid auth expiration := time.Now().Add(reauthorizeInterval) - cookieData := strings.Join([]string{username, strings.Join(groupsArray, "|")}, ",") + cookieData := strings.Join([]string{username, strings.Join(groups, "|")}, ",") http.SetCookie(w, &http.Cookie{ Name: cookieName, Value: signCookie(cookieData, config.CookieSecret), @@ -232,6 +236,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr } func getReturnUrl(secret string, payload string, sig string, nonce string) (returnUrl string, err error) { + nonceMutex.Lock() value, ok := nonceCache.Get(nonce) nonceMutex.Unlock() diff --git a/main_test.go b/main_test.go index b29ef5c..32d3ebd 100644 --- a/main_test.go +++ b/main_test.go @@ -1,62 +1,203 @@ package main import ( + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strconv" "testing" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/assert" ) -func TestMain(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "main") +type SSOOptions struct { + URL string + Secret string + Nonce string + Groups string + Admin bool } -var _ = Describe("parseCookie", func() { - var ( - signed, secret string +type SSOOverrideFunc func(*SSOOptions) +type ConfigOverrideFunc func(*Config) - parsedUsername string - parsedGroup string - parseError error +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + +func NewTestConfig() Config { + return Config{ + OriginURL: mustParseURL("http://origin.url"), + ProxyURL: mustParseURL("http://proxy.url"), + SSOURL: mustParseURL("http://sso.url"), + SSOSecret: "secret", + AllowAll: false, + AllowGroups: NewStringSet(""), + BasicAuth: "", + Whitelist: "", + UsernameHeader: "username-header", + GroupsHeader: "groups-header", + Timeout: 10, + SRVAbandonAfter: 600, + LogRequests: false, + } +} + +func NewSSOOptions(url string, secret string) SSOOptions { + return SSOOptions{ + URL: url, + Secret: secret, + Admin: false, + } +} + +func RegisterTestNonce(t *testing.T, options SSOOptions) SSOOptions { + if options.Nonce != "" { + return options + } + options.Nonce = addNonce("http://some.url/") + t.Cleanup(func() { + nonceCache.Clear() + }) + return options +} + +func BuildTestSSOURL(options SSOOptions) string { + innerqs := url.Values{ + "username": []string{"sam"}, + "groups": []string{options.Groups}, + "admin": []string{strconv.FormatBool(options.Admin)}, + "nonce": []string{options.Nonce}, + } + inner := base64.StdEncoding.EncodeToString([]byte(innerqs.Encode())) + + u := mustParseURL(options.URL) + outerqs := u.Query() + outerqs.Set("sso", inner) + outerqs.Set("sig", computeHMAC(inner, options.Secret)) + u.RawQuery = outerqs.Encode() + return u.String() +} + +func GetTestResult(t *testing.T, configOverride ConfigOverrideFunc, ssoOverride SSOOverrideFunc) *http.Response { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "") + }) + + newConfig := NewTestConfig() + + configOverride(&newConfig) + config = &newConfig + + proxy := authProxyHandler(handler) + ts := httptest.NewServer(proxy) + defer ts.Close() + + options := NewSSOOptions(ts.URL, config.SSOSecret) + ssoOverride(&options) + options = RegisterTestNonce(t, options) + + res, _ := http.Get(BuildTestSSOURL(options)) + return res +} + +func TestBadSecret(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowAll = true + }, + func(options *SSOOptions) { + options.Secret = "BAD SECRET" + }, ) - JustBeforeEach(func() { - parsedUsername, parsedGroup, parseError = parseCookie(signed, secret) - }) + assert.Equal(t, 400, res.StatusCode) +} - Context("when verifying with an invalid secret", func() { - BeforeEach(func() { - secret = "secretbar" - signed = signCookie("user,group", "secretfoo") - }) +func TestForbiddenGroup(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("group_a,group_b") + }, + func(options *SSOOptions) { + options.Groups = "group_c,group_d" + }, + ) - It("fails", func() { - Expect(parseError).To(HaveOccurred()) - }) - }) + assert.Equal(t, 403, res.StatusCode) +} - Context("when verifying with an invalid payload", func() { - BeforeEach(func() { - secret = "mysecret" - signed = signCookie("user,group", secret) + "garbage" - }) +func TestAllowedGroup(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("group_a,group_b") + }, + func(options *SSOOptions) { + options.Groups = "group_c,group_a" + }, + ) - It("fails", func() { - Expect(parseError).To(HaveOccurred()) - }) - }) + assert.Equal(t, 200, res.StatusCode) +} - Context("when verifying with a valid payload and secret", func() { - BeforeEach(func() { - secret = "mysecret" - signed = signCookie("user,group", secret) - }) +func TestForbiddenAnon(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("") + config.AllowAll = false + }, + func(options *SSOOptions) { + options.Admin = false + }, + ) - It("returns a user and group", func() { - Expect(parseError).To(Succeed()) - Expect(parsedUsername).To(Equal("user")) - Expect(parsedGroup).To(Equal("group")) - }) - }) -}) + assert.Equal(t, 403, res.StatusCode) +} + +func TestAllowedAnon(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("") + config.AllowAll = true + }, + func(options *SSOOptions) { + options.Admin = false + }, + ) + + assert.Equal(t, 200, res.StatusCode) +} + +func TestInvalidSecretFails(t *testing.T) { + signed := signCookie("user,group", "secretfoo") + _, _, parseError := parseCookie(signed, "secretbar") + + assert.Error(t, parseError) +} + +func TestInvalidPayloadFails(t *testing.T) { + signed := signCookie("user,group", "secretfoo") + "garbage" + _, _, parseError := parseCookie(signed, "secretfoo") + + assert.Error(t, parseError) +} + +func TestValidPayload(t *testing.T) { + signed := signCookie("user,group", "secretfoo") + username, group, parseError := parseCookie(signed, "secretfoo") + + assert.NoError(t, parseError) + assert.Equal(t, username, "user") + assert.Equal(t, group, "group") +} diff --git a/string_set.go b/string_set.go new file mode 100644 index 0000000..a17c72d --- /dev/null +++ b/string_set.go @@ -0,0 +1,32 @@ +package main + +import ( + "strings" +) + +type StringSet []string + +func NewStringSet(s string) StringSet { + if len(s) == 0 { + return []string{} + } + return strings.Split(s, ",") +} + +func (ss StringSet) Contains(needle string) bool { + for _, s := range ss { + if s == needle { + return true + } + } + return false +} + +func (ss StringSet) ContainsAny(needles StringSet) bool { + for _, n := range needles { + if ss.Contains(n) { + return true + } + } + return false +}