From a9aa54d567c68fc18b839f0234205802ac80a4d6 Mon Sep 17 00:00:00 2001 From: OsamaSayegh Date: Tue, 16 Mar 2021 03:16:58 +0300 Subject: [PATCH] Commit 1 --- README.md | 7 +++ config.go | 8 +++ go.mod | 3 +- go.sum | 79 +++++++++++++++++++++++++++ main.go | 83 +++++++++++++++++++--------- main_test.go | 149 ++++++++++++++++++++++++++++++++++++++++++++++++++- stores.go | 118 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 419 insertions(+), 28 deletions(-) create mode 100644 stores.go diff --git a/README.md b/README.md index a0d1914..64909de 100644 --- a/README.md +++ b/README.md @@ -54,3 +54,10 @@ docker run discourse/auth-proxy ``` Running will display configuration instructions + +Development +=== + +1. Install Golang: https://golang.org/doc/install. +2. You need a Redis server running locally and listening on `127.0.0.1:6379`. If you're on Ubuntu, you can install Redis with this command: `sudo apt-get install redis-server`. +3. Run `go test` to run the tests suite. diff --git a/config.go b/config.go index eac066e..944a9b9 100644 --- a/config.go +++ b/config.go @@ -28,6 +28,8 @@ type Config struct { Timeout time.Duration SRVAbandonAfter time.Duration LogRequests bool + RedisAddress string + RedisPassword string } func ParseConfig() (*Config, error) { @@ -100,6 +102,8 @@ func ParseConfig() (*Config, error) { c.LogRequests = *rc.LogRequests c.CookieSecret = uuid.New() + c.RedisAddress = *rc.RedisAddress + c.RedisPassword = *rc.RedisPassword return c, nil } @@ -119,6 +123,8 @@ type rawConfig struct { Timeout *int SRVAbandonAfter *int LogRequests *bool + RedisAddress *string + RedisPassword *string } func parseRawConfig() *rawConfig { @@ -137,6 +143,8 @@ func parseRawConfig() *rawConfig { Timeout: flag.Int("timeout", 10, "Read/write timeout (seconds)"), SRVAbandonAfter: flag.Int("dns-srv-abandon-after", 600, "Abandon DNS SRV discovery if origin RRs do not appear within this time (seconds). When negative, attempt SRV lookups indefinitely."), LogRequests: flag.Bool("log-requests", false, "Log all requests to standard error"), + RedisAddress: flag.String("redis-address", "", "Address of a Redis server which auth-proxy will use to store nonces. e.g.: 127.0.0.1:6379. Optional."), + RedisPassword: flag.String("redis-password", "", "Password of the Redis server. Optional."), } flag.Parse() return c diff --git a/go.mod b/go.mod index 147fafa..71f6cf0 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,10 @@ module github.com/discourse/discourse-auth-proxy go 1.15 require ( + github.com/go-redis/redis/v8 v8.7.1 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/namsral/flag v1.7.4-pre github.com/pborman/uuid v1.2.1 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index f2a0f15..cc70cbf 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,39 @@ +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis/v8 v8.7.1 h1:8IYi6RO83fNcG5amcUUYTN/qH2h4OjZHlim3KWGFSsA= +github.com/go-redis/redis/v8 v8.7.1/go.mod h1:BRxHBWn3pO3CfjyX6vAoyeRmCquvxr6QG+2onGV2gYs= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/namsral/flag v1.7.4-pre h1:b2ScHhoCUkbsq0d2C15Mv+VU8bl8hAXV8arnWiOHNZs= github.com/namsral/flag v1.7.4-pre/go.mod h1:OXldTctbM6SWH1K899kPZcf65KxJiD7MsceFUpB5yDo= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.15.0/go.mod h1:hF8qUzuuC8DJGygJH3726JnCZX4MYbRB8yFfISqnKUg= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.10.5/go.mod h1:gza4q3jKQJijlu05nKWRCW/GavJumGt8aNRxWg7mt48= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -13,8 +41,59 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.opentelemetry.io/otel v0.18.0 h1:d5Of7+Zw4ANFOJB+TIn2K3QWsgS2Ht7OU9DqZHI6qu8= +go.opentelemetry.io/otel v0.18.0/go.mod h1:PT5zQj4lTsR1YeARt8YNKcFb88/c2IKoSABK9mX0r78= +go.opentelemetry.io/otel/metric v0.18.0 h1:yuZCmY9e1ZTaMlZXLrrbAPmYW6tW1A5ozOZeOYGaTaY= +go.opentelemetry.io/otel/metric v0.18.0/go.mod h1:kEH2QtzAyBy3xDVQfGZKIcok4ZZFvd5xyKPfPcuK6pE= +go.opentelemetry.io/otel/oteltest v0.18.0/go.mod h1:NyierCU3/G8DLTva7KRzGii2fdxdR89zXKH1bNWY7Bo= +go.opentelemetry.io/otel/trace v0.18.0 h1:ilCfc/fptVKaDMK1vWk0elxpolurJbEgey9J6g6s+wk= +go.opentelemetry.io/otel/trace v0.18.0/go.mod h1:FzdUu3BPwZSZebfQ1vl5/tAa8LyMLXSJN57AXIt/iDk= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 1b1dc28..32ff8f0 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ import ( "github.com/golang/groupcache/lru" "github.com/pborman/uuid" + "github.com/go-redis/redis/v8" "github.com/discourse/discourse-auth-proxy/internal/httpproxy" ) @@ -30,8 +31,7 @@ var ( config *Config - nonceCache = lru.New(20) - nonceMutex = &sync.Mutex{} + storageInstance CacheStore ) const ( @@ -48,6 +48,9 @@ func main() { } } + setupStorage(config) + GetSetNXCookieSecretIfRedis(config) + dnssrv := httpproxy.NewDNSSRVBackend(config.OriginURL) go dnssrv.Lookup(context.Background(), 50*time.Second, 10*time.Second, config.SRVAbandonAfter) proxy := &httputil.ReverseProxy{Director: dnssrv.Director} @@ -169,7 +172,12 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr sig := query.Get("sig") if len(sso) == 0 { - url := config.SSOURLString + "/session/sso_provider?" + sso_payload(config.SSOSecret, config.ProxyURLString, r.URL.String()).Encode() + payload, err := ssoPayload(config.SSOSecret, config.ProxyURLString, r.URL.String()) + if err != nil { + fail("An error occurred when generating SSO payload: %s", err) + return + } + url := config.SSOURLString + "/session/sso_provider?" + payload.Encode() http.Redirect(w, r, url, 302) } else { decoded, err := base64.StdEncoding.DecodeString(sso) @@ -236,20 +244,10 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr } func getReturnUrl(secret string, payload string, sig string, nonce string) (returnUrl string, err error) { - - nonceMutex.Lock() - value, ok := nonceCache.Get(nonce) - nonceMutex.Unlock() - if !ok { - err = fmt.Errorf("nonce not found: %s", nonce) - return + returnUrl, err = storageInstance.GetAndDeleteNonce(nonce) + if err != nil { + return "", err } - - returnUrl = value.(string) - nonceMutex.Lock() - nonceCache.Remove(nonce) - nonceMutex.Unlock() - if computeHMAC(payload, secret) != sig { err = errors.New("signature is invalid") } @@ -288,25 +286,30 @@ func parseCookie(data, secret string) (username string, groups string, err error return } -// sso_payload takes the SSO secret and the two redirection URLs, stores the +// ssoPayload takes the SSO secret and the two redirection URLs, stores the // returnUrl in the nonce cache, and returns a partial URL querystring. -func sso_payload(secret string, return_sso_url string, returnUrl string) url.Values { - result := "return_sso_url=" + url.QueryEscape(return_sso_url) + url.QueryEscape(returnUrl) + "&nonce=" + url.QueryEscape(addNonce(returnUrl)) +func ssoPayload(secret string, return_sso_url string, returnUrl string) (url.Values, error) { + guid, err := addNonce(returnUrl) + if err != nil { + return url.Values{}, err + } + result := "return_sso_url=" + url.QueryEscape(return_sso_url) + url.QueryEscape(returnUrl) + "&nonce=" + url.QueryEscape(guid) payload := base64.StdEncoding.EncodeToString([]byte(result)) return url.Values{ "sso": []string{payload}, "sig": []string{computeHMAC(payload, secret)}, - } + }, nil } // addNonce takes a return URL and returns a nonce associated to that URL. -func addNonce(returnUrl string) string { +func addNonce(returnUrl string) (string, error) { guid := uuid.New() - nonceMutex.Lock() - nonceCache.Add(guid, returnUrl) - nonceMutex.Unlock() - return guid + err := storageInstance.AddNonce(guid, returnUrl) + if err != nil { + return "", err + } + return guid, nil } // computeHMAC implements the Discourse SSO protocol, returning a hex string. @@ -316,3 +319,33 @@ func computeHMAC(message string, secret string) string { h.Write([]byte(message)) return hex.EncodeToString(h.Sum(nil)) } + +func setupStorage(config *Config) { + if config.RedisAddress != "" { + client := redis.NewClient(&redis.Options{ + Addr: config.RedisAddress, + Password: config.RedisPassword, + }) + storageInstance = &RedisStore{ + Redis: client, + Namespace: "_discourse-auth-proxy_", + } + } else { + storageInstance = &MemoryStore{ + Mutex: &sync.Mutex{}, + Cache: lru.New(20), + } + } +} + +func GetSetNXCookieSecretIfRedis(config *Config) { + redisStore, ok := storageInstance.(*RedisStore) + if ok { + secret, err := redisStore.GetSetNXCookieSecret() + if err != nil { + fmt.Printf("Failed to get cookie secret from redis. Error: %s\n", err) + } else { + config.CookieSecret = secret + } + } +} diff --git a/main_test.go b/main_test.go index 32d3ebd..fdb3fb1 100644 --- a/main_test.go +++ b/main_test.go @@ -8,8 +8,12 @@ import ( "net/url" "strconv" "testing" + "sync" + "os" "github.com/stretchr/testify/assert" + "github.com/golang/groupcache/lru" + "github.com/go-redis/redis/v8" ) type SSOOptions struct { @@ -23,6 +27,26 @@ type SSOOptions struct { type SSOOverrideFunc func(*SSOOptions) type ConfigOverrideFunc func(*Config) +func NewRedisStore() *RedisStore { + return &RedisStore{ + Redis: redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + }), + Namespace: "_discourse-auth-proxy-test_", + } +} + +func NewMemoryStore() *MemoryStore { + return &MemoryStore{ + Mutex: &sync.Mutex{}, + Cache: lru.New(20), + } +} + +var redisStore = NewRedisStore() +var memoryStore = NewMemoryStore() +var stores = [2]CacheStore{ memoryStore, redisStore } + func mustParseURL(s string) *url.URL { u, err := url.Parse(s) if err != nil { @@ -46,6 +70,8 @@ func NewTestConfig() Config { Timeout: 10, SRVAbandonAfter: 600, LogRequests: false, + RedisAddress: "", + RedisPassword: "", } } @@ -61,9 +87,11 @@ func RegisterTestNonce(t *testing.T, options SSOOptions) SSOOptions { if options.Nonce != "" { return options } - options.Nonce = addNonce("http://some.url/") + nonce, err := addNonce("http://some.url/") + assert.NoError(t, err) + options.Nonce = nonce t.Cleanup(func() { - nonceCache.Clear() + storageInstance.Clear() }) return options } @@ -94,6 +122,7 @@ func GetTestResult(t *testing.T, configOverride ConfigOverrideFunc, ssoOverride configOverride(&newConfig) config = &newConfig + setupStorage(config) proxy := authProxyHandler(handler) ts := httptest.NewServer(proxy) @@ -179,6 +208,21 @@ func TestAllowedAnon(t *testing.T) { assert.Equal(t, 200, res.StatusCode) } +func TestExpiredNonce(t *testing.T) { + res := GetTestResult( + t, + func(config *Config) { + config.AllowGroups = NewStringSet("") + config.AllowAll = true + }, + func(options *SSOOptions) { + options.Admin = false + options.Nonce = "somenonexistentnonce" + }, + ) + assert.Equal(t, 400, res.StatusCode) +} + func TestInvalidSecretFails(t *testing.T) { signed := signCookie("user,group", "secretfoo") _, _, parseError := parseCookie(signed, "secretbar") @@ -201,3 +245,104 @@ func TestValidPayload(t *testing.T) { assert.Equal(t, username, "user") assert.Equal(t, group, "group") } + +func TestStoresAddNonceMethod(t *testing.T) { + for _, store := range stores { + nonce := "this-is-a-test-nonce" + err := store.AddNonce(nonce, "auth proxy hello world") + assert.NoError(t, err) + val, err := store.GetAndDeleteNonce(nonce) + assert.NoError(t, err) + assert.Equal(t, "auth proxy hello world", val) + } +} + +func TestStoresGetAndDeleteMethod(t *testing.T) { + for _, store := range stores { + nonce := "this-is-a-test-nonce" + err := store.AddNonce(nonce, "auth proxy hello world") + assert.NoError(t, err) + val, err := store.GetAndDeleteNonce(nonce) + assert.NoError(t, err) + assert.Equal(t, "auth proxy hello world", val) + val, err = store.GetAndDeleteNonce(nonce) + assert.Error(t, err) + assert.Equal( + t, + fmt.Sprintf("[%T] nonce not found: this-is-a-test-nonce", store), + fmt.Sprintf("%s", err), + ) + assert.Equal(t, "", val) + } +} + +func TestRedisStorePrefix(t *testing.T) { + assert.Equal(t, "_discourse-auth-proxy-test_osama", redisStore.Prefix("osama")) + assert.Equal(t, "_discourse-auth-proxy-test_DiSCoUrsE", redisStore.Prefix("DiSCoUrsE")) +} + +func TestRedisGetSetNXCookieSecret(t *testing.T) { + secret, err := redisStore.GetSetNXCookieSecret() + assert.NoError(t, err) + assert.Equal(t, 36, len(secret)) + secret2, err := redisStore.GetSetNXCookieSecret() + assert.NoError(t, err) + assert.Equal(t, secret, secret2) +} + +func TestSetupStorage(t *testing.T) { + c := NewTestConfig() + + c.RedisAddress = "127.0.0.1:6379" + setupStorage(&c) + _, ok := storageInstance.(*RedisStore) + assert.Equal(t, true, ok) + + c.RedisPassword = "somesecretpa$$word" + setupStorage(&c) + _, ok = storageInstance.(*RedisStore) + assert.Equal(t, true, ok) + + c.RedisPassword = "" + c.RedisAddress = "" + setupStorage(&c) + _, ok = storageInstance.(*MemoryStore) + assert.Equal(t, true, ok) +} + +func TestGetSetNXCookieSecretIfRedis(t *testing.T) { + c := NewTestConfig() + c.CookieSecret = "secret1" + + c.RedisAddress = "127.0.0.1:6379" + setupStorage(&c) + GetSetNXCookieSecretIfRedis(&c) + assert.NotEqual(t, "secret1", c.CookieSecret) + assert.Equal(t, 36, len(c.CookieSecret)) + secret2 := c.CookieSecret + + c = NewTestConfig() + c.CookieSecret = "secret3" + setupStorage(&c) + GetSetNXCookieSecretIfRedis(&c) + assert.Equal(t, "secret3", c.CookieSecret) + + c = NewTestConfig() + c.CookieSecret = "secret4" + c.RedisAddress = "127.0.0.1:6379" + setupStorage(&c) + GetSetNXCookieSecretIfRedis(&c) + assert.Equal(t, secret2, c.CookieSecret) +} + +func TestMain(m *testing.M) { + for _, s := range stores { + err := s.Clear() + if err != nil { + fmt.Printf("%s", err) + os.Exit(1) + } + } + code := m.Run() + os.Exit(code) +} diff --git a/stores.go b/stores.go new file mode 100644 index 0000000..fb7e93d --- /dev/null +++ b/stores.go @@ -0,0 +1,118 @@ +package main + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/golang/groupcache/lru" + "github.com/go-redis/redis/v8" + "github.com/pborman/uuid" +) + +var bgCtx = context.Background() + +type CacheStore interface { + AddNonce(nonce string, val string) error + GetAndDeleteNonce(nonce string) (string, error) + Clear() error // used in tests only +} + +type MemoryStore struct { + Mutex *sync.Mutex + Cache *lru.Cache +} + +func (store *MemoryStore) AddNonce(nonce string, val string) error { + store.Mutex.Lock() + store.Cache.Add(nonce, val) + store.Mutex.Unlock() + return nil +} + +func (store *MemoryStore) GetAndDeleteNonce(nonce string) (val string, err error) { + store.Mutex.Lock() + _val, ok := store.Cache.Get(nonce) + store.Mutex.Unlock() + if !ok { + err = fmt.Errorf("[%T] nonce not found: %s", store, nonce) + return "", err + } + val = _val.(string) + store.Mutex.Lock() + store.Cache.Remove(nonce) + store.Mutex.Unlock() + return val, nil +} + +// used in tests only +func (store *MemoryStore) Clear() error { + store.Mutex.Lock() + store.Cache.Clear() + store.Mutex.Unlock() + return nil +} + +type RedisStore struct { + Redis *redis.Client + Namespace string +} + +func (store *RedisStore) AddNonce(nonce string, val string) error { + err := store.Redis.SetEX(bgCtx, store.Prefix(nonce), val, 600 * time.Second).Err() + if err != nil { + return err + } + return nil +} + +func (store *RedisStore) GetAndDeleteNonce(nonce string) (val string, err error) { + prefixedKey := store.Prefix(nonce) + val, err = store.Redis.Get(bgCtx, prefixedKey).Result() + if err != nil { + return "", fmt.Errorf("[%T] nonce not found: %s", store, nonce) + } + err = store.Redis.Del(bgCtx, prefixedKey).Err() + if err != nil { + return "", err + } + return val, nil +} + +// used in tests only +func (store *RedisStore) Clear() error { + keys, err := store.Redis.Keys(bgCtx, store.Namespace + "*").Result() + if err != nil { + return err + } + for _, key := range keys { + err := store.Redis.Del(bgCtx, key).Err() + if err != nil { + return err + } + } + return nil +} + +func (store *RedisStore) Prefix(in string) string { + return store.Namespace + in +} + +func (store *RedisStore) GetSetNXCookieSecret() (string, error) { + prefixedKey := store.Prefix("cookie-secret-uuid") + secret := uuid.New() + ourSecretSet, err := store.Redis.SetNX(bgCtx, prefixedKey, secret, 0).Result() + if err != nil { + return "", err + } + if ourSecretSet { + return secret, nil + } + + secret, err = store.Redis.Get(bgCtx, prefixedKey).Result() + if err != nil { + return "", err + } + return secret, nil +}