diff --git a/config.go b/config.go index b703acb..eac066e 100644 --- a/config.go +++ b/config.go @@ -20,6 +20,7 @@ type Config struct { SSOSecret string CookieSecret string AllowAll bool + AllowGroups StringSet BasicAuth string Whitelist string UsernameHeader string @@ -85,6 +86,7 @@ func ParseConfig() (*Config, error) { c.SSOSecret = *rc.SSOSecret c.AllowAll = *rc.AllowAll + c.AllowGroups = NewStringSet(*rc.AllowGroups) c.BasicAuth = *rc.BasicAuth c.Whitelist = *rc.Whitelist c.UsernameHeader = *rc.UsernameHeader @@ -109,6 +111,7 @@ type rawConfig struct { SSOURL *string SSOSecret *string AllowAll *bool + AllowGroups *string BasicAuth *string Whitelist *string UsernameHeader *string @@ -126,6 +129,7 @@ func parseRawConfig() *rawConfig { SSOURL: flag.String("sso-url", "", "SSO endpoint. e.g.: http://discourse.forum.com"), SSOSecret: flag.String("sso-secret", "", "SSO secret for origin"), AllowAll: flag.Bool("allow-all", false, "Allow all discourse users (default: admin users only)"), + AllowGroups: flag.String("allow-groups", "", "Allow users belonging to the specified groups, comma delimited (default: no groups are allowed)"), BasicAuth: flag.String("basic-auth", "", "HTTP Basic authentication credentials to let through directly"), Whitelist: flag.String("whitelist", "", "Path which does not require authorization"), UsernameHeader: flag.String("username-header", "Discourse-User-Name", "Request header to pass authenticated username into"), diff --git a/go.mod b/go.mod index 67863b9..147fafa 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,7 @@ go 1.15 require ( github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/namsral/flag v1.7.4-pre - github.com/onsi/ginkgo v1.14.2 - github.com/onsi/gomega v1.10.3 github.com/pborman/uuid v1.2.1 - golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e + github.com/stretchr/testify v1.6.1 + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index d2fc4d8..f2a0f15 100644 --- a/go.sum +++ b/go.sum @@ -1,76 +1,20 @@ -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= -github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/namsral/flag v1.7.4-pre h1:b2ScHhoCUkbsq0d2C15Mv+VU8bl8hAXV8arnWiOHNZs= github.com/namsral/flag v1.7.4-pre/go.mod h1:OXldTctbM6SWH1K899kPZcf65KxJiD7MsceFUpB5yDo= -github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= -github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= -github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA= -github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 h1:wBouT66WTYFXdxfVdz9sVWARVd/2vfGcmI45D2gj45M= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= -golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 3882ef6..1b1dc28 100644 --- a/main.go +++ b/main.go @@ -52,7 +52,7 @@ func main() { go dnssrv.Lookup(context.Background(), 50*time.Second, 10*time.Second, config.SRVAbandonAfter) proxy := &httputil.ReverseProxy{Director: dnssrv.Director} - handler := authProxyHandler(proxy, config) + handler := authProxyHandler(proxy) if config.LogRequests { handler = logHandler(handler) @@ -85,7 +85,7 @@ func main() { log.Fatal(server.Serve(listener)) } -func authProxyHandler(handler http.Handler, config *Config) http.Handler { +func authProxyHandler(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if checkWhitelist(handler, r, w) { return @@ -185,10 +185,10 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr } var ( - username = parsedQuery.Get("username") - admin = parsedQuery.Get("admin") - nonce = parsedQuery.Get("nonce") - groupsArray = strings.Split(parsedQuery.Get("groups"), ",") + username = parsedQuery.Get("username") + admin = parsedQuery.Get("admin") + nonce = parsedQuery.Get("nonce") + groups = NewStringSet(parsedQuery.Get("groups")) ) if len(nonce) == 0 { @@ -204,8 +204,12 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr return } if !(config.AllowAll || admin == "true") { - writeHttpError(http.StatusForbidden) - return + allowed := config.AllowGroups.ContainsAny(groups) + + if !allowed { + writeHttpError(http.StatusForbidden) + return + } } returnUrl, err := getReturnUrl(config.SSOSecret, sso, sig, nonce) @@ -217,7 +221,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr // we have a valid auth expiration := time.Now().Add(reauthorizeInterval) - cookieData := strings.Join([]string{username, strings.Join(groupsArray, "|")}, ",") + cookieData := strings.Join([]string{username, strings.Join(groups, "|")}, ",") http.SetCookie(w, &http.Cookie{ Name: cookieName, Value: signCookie(cookieData, config.CookieSecret), @@ -232,6 +236,7 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr } func getReturnUrl(secret string, payload string, sig string, nonce string) (returnUrl string, err error) { + nonceMutex.Lock() value, ok := nonceCache.Get(nonce) nonceMutex.Unlock() diff --git a/main_test.go b/main_test.go index b29ef5c..32d3ebd 100644 --- a/main_test.go +++ b/main_test.go @@ -1,62 +1,203 @@ package main import ( + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strconv" "testing" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/assert" ) -func TestMain(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "main") +type SSOOptions struct { + URL string + Secret string + Nonce string + Groups string + Admin bool } -var _ = Describe("parseCookie", func() { - var ( - signed, secret string +type SSOOverrideFunc func(*SSOOptions) +type ConfigOverrideFunc func(*Config) - parsedUsername string - parsedGroup string - parseError error +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + +func NewTestConfig() Config { + return Config{ + OriginURL: mustParseURL("http://origin.url"), + ProxyURL: mustParseURL("http://proxy.url"), + SSOURL: mustParseURL("http://sso.url"), + SSOSecret: "secret", + AllowAll: false, + AllowGroups: NewStringSet(""), + BasicAuth: "", + Whitelist: "", + UsernameHeader: "username-header", + GroupsHeader: "groups-header", + Timeout: 10, + SRVAbandonAfter: 600, + LogRequests: false, + } +} + +func NewSSOOptions(url string, secret string) SSOOptions { + return SSOOptions{ + URL: url, + Secret: secret, + Admin: false, + } +} + +func RegisterTestNonce(t *testing.T, options SSOOptions) SSOOptions { + if options.Nonce != "" { + return options + } + options.Nonce = addNonce("http://some.url/") + t.Cleanup(func() { + nonceCache.Clear() + }) + return options +} + +func BuildTestSSOURL(options SSOOptions) string { + innerqs := url.Values{ + "username": []string{"sam"}, + "groups": []string{options.Groups}, + "admin": []string{strconv.FormatBool(options.Admin)}, + "nonce": []string{options.Nonce}, + } + inner := base64.StdEncoding.EncodeToString([]byte(innerqs.Encode())) + + u := mustParseURL(options.URL) + outerqs := u.Query() + outerqs.Set("sso", inner) + outerqs.Set("sig", computeHMAC(inner, options.Secret)) + u.RawQuery = outerqs.Encode() + return u.String() +} + +func GetTestResult(t *testing.T, configOverride ConfigOverrideFunc, ssoOverride SSOOverrideFunc) *http.Response { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "") + }) + + newConfig := NewTestConfig() + + configOverride(&newConfig) + config = &newConfig + + proxy := authProxyHandler(handler) + ts := httptest.NewServer(proxy) + defer ts.Close() + + options := NewSSOOptions(ts.URL, config.SSOSecret) + ssoOverride(&options) + options = RegisterTestNonce(t, options) + + res, _ := http.Get(BuildTestSSOURL(options)) + return res +} + +func TestBadSecret(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowAll = true + }, + func(options *SSOOptions) { + options.Secret = "BAD SECRET" + }, ) - JustBeforeEach(func() { - parsedUsername, parsedGroup, parseError = parseCookie(signed, secret) - }) + assert.Equal(t, 400, res.StatusCode) +} - Context("when verifying with an invalid secret", func() { - BeforeEach(func() { - secret = "secretbar" - signed = signCookie("user,group", "secretfoo") - }) +func TestForbiddenGroup(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("group_a,group_b") + }, + func(options *SSOOptions) { + options.Groups = "group_c,group_d" + }, + ) - It("fails", func() { - Expect(parseError).To(HaveOccurred()) - }) - }) + assert.Equal(t, 403, res.StatusCode) +} - Context("when verifying with an invalid payload", func() { - BeforeEach(func() { - secret = "mysecret" - signed = signCookie("user,group", secret) + "garbage" - }) +func TestAllowedGroup(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("group_a,group_b") + }, + func(options *SSOOptions) { + options.Groups = "group_c,group_a" + }, + ) - It("fails", func() { - Expect(parseError).To(HaveOccurred()) - }) - }) + assert.Equal(t, 200, res.StatusCode) +} - Context("when verifying with a valid payload and secret", func() { - BeforeEach(func() { - secret = "mysecret" - signed = signCookie("user,group", secret) - }) +func TestForbiddenAnon(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("") + config.AllowAll = false + }, + func(options *SSOOptions) { + options.Admin = false + }, + ) - It("returns a user and group", func() { - Expect(parseError).To(Succeed()) - Expect(parsedUsername).To(Equal("user")) - Expect(parsedGroup).To(Equal("group")) - }) - }) -}) + assert.Equal(t, 403, res.StatusCode) +} + +func TestAllowedAnon(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("") + config.AllowAll = true + }, + func(options *SSOOptions) { + options.Admin = false + }, + ) + + assert.Equal(t, 200, res.StatusCode) +} + +func TestInvalidSecretFails(t *testing.T) { + signed := signCookie("user,group", "secretfoo") + _, _, parseError := parseCookie(signed, "secretbar") + + assert.Error(t, parseError) +} + +func TestInvalidPayloadFails(t *testing.T) { + signed := signCookie("user,group", "secretfoo") + "garbage" + _, _, parseError := parseCookie(signed, "secretfoo") + + assert.Error(t, parseError) +} + +func TestValidPayload(t *testing.T) { + signed := signCookie("user,group", "secretfoo") + username, group, parseError := parseCookie(signed, "secretfoo") + + assert.NoError(t, parseError) + assert.Equal(t, username, "user") + assert.Equal(t, group, "group") +} diff --git a/string_set.go b/string_set.go new file mode 100644 index 0000000..a17c72d --- /dev/null +++ b/string_set.go @@ -0,0 +1,32 @@ +package main + +import ( + "strings" +) + +type StringSet []string + +func NewStringSet(s string) StringSet { + if len(s) == 0 { + return []string{} + } + return strings.Split(s, ",") +} + +func (ss StringSet) Contains(needle string) bool { + for _, s := range ss { + if s == needle { + return true + } + } + return false +} + +func (ss StringSet) ContainsAny(needles StringSet) bool { + for _, n := range needles { + if ss.Contains(n) { + return true + } + } + return false +}