FIX: Properly query-escape URLs, several other cleanups (#7)
Use Form.Get() instead of [0]. Move important constants to constants. Document some functions. Eliminate a single-use one-line closure. Avoid bare return when reasonable.
This commit is contained in:
parent
f81d3bb030
commit
1cb59fc2ce
47
main.go
47
main.go
|
@ -33,6 +33,11 @@ var (
|
||||||
nonceMutex = &sync.Mutex{}
|
nonceMutex = &sync.Mutex{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cookieName = "__discourse_proxy"
|
||||||
|
reauthorizeInterval = 365 * 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
{
|
{
|
||||||
var err error
|
var err error
|
||||||
|
@ -139,15 +144,12 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
||||||
writeHttpError := func(code int) {
|
writeHttpError := func(code int) {
|
||||||
http.Error(w, http.StatusText(code), code)
|
http.Error(w, http.StatusText(code), code)
|
||||||
}
|
}
|
||||||
writeClientError := func() {
|
|
||||||
writeHttpError(http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
fail := func(format string, v ...interface{}) {
|
fail := func(format string, v ...interface{}) {
|
||||||
logger.Printf(format, v...)
|
logger.Printf(format, v...)
|
||||||
writeClientError()
|
writeHttpError(http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := r.Cookie("__discourse_proxy")
|
cookie, err := r.Cookie(cookieName)
|
||||||
var username, groups string
|
var username, groups string
|
||||||
|
|
||||||
if err == nil && cookie != nil {
|
if err == nil && cookie != nil {
|
||||||
|
@ -182,11 +184,10 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
username = parsedQuery["username"]
|
username = parsedQuery.Get("username")
|
||||||
groups = parsedQuery["groups"]
|
admin = parsedQuery.Get("admin")
|
||||||
admin = parsedQuery["admin"]
|
nonce = parsedQuery.Get("nonce")
|
||||||
nonce = parsedQuery["nonce"]
|
groupsArray = strings.Split(parsedQuery.Get("groups"), ",")
|
||||||
groupsArray = strings.Split(groups[0], ",")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(nonce) == 0 {
|
if len(nonce) == 0 {
|
||||||
|
@ -201,23 +202,28 @@ func redirectIfNoCookie(handler http.Handler, r *http.Request, w http.ResponseWr
|
||||||
fail("incomplete payload from sso provider: missing admin")
|
fail("incomplete payload from sso provider: missing admin")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !(config.AllowAll || admin[0] == "true") {
|
if !(config.AllowAll || admin == "true") {
|
||||||
writeHttpError(http.StatusForbidden)
|
writeHttpError(http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
returnUrl, err := getReturnUrl(config.SSOSecret, sso, sig, nonce[0])
|
returnUrl, err := getReturnUrl(config.SSOSecret, sso, sig, nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fail("failed to build return URL: %s", err)
|
fail("failed to build return URL: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// we have a valid auth
|
// we have a valid auth
|
||||||
expiration := time.Now().Add(365 * 24 * time.Hour)
|
expiration := time.Now().Add(reauthorizeInterval)
|
||||||
|
|
||||||
cookieData := strings.Join([]string{username[0], strings.Join(groupsArray, "|")}, ",")
|
cookieData := strings.Join([]string{username, strings.Join(groupsArray, "|")}, ",")
|
||||||
cookie := http.Cookie{Name: "__discourse_proxy", Value: signCookie(cookieData, config.CookieSecret), Expires: expiration, HttpOnly: true, Path: "/"}
|
http.SetCookie(w, &http.Cookie{
|
||||||
http.SetCookie(w, &cookie)
|
Name: cookieName,
|
||||||
|
Value: signCookie(cookieData, config.CookieSecret),
|
||||||
|
Expires: expiration,
|
||||||
|
HttpOnly: true,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
|
||||||
// works around weird safari stuff
|
// works around weird safari stuff
|
||||||
fmt.Fprintf(w, "<html><head></head><body><script>window.location = '%v'</script></body>", returnUrl)
|
fmt.Fprintf(w, "<html><head></head><body><script>window.location = '%v'</script></body>", returnUrl)
|
||||||
|
@ -241,7 +247,7 @@ func getReturnUrl(secret string, payload string, sig string, nonce string) (retu
|
||||||
if computeHMAC(payload, secret) != sig {
|
if computeHMAC(payload, secret) != sig {
|
||||||
err = errors.New("signature is invalid")
|
err = errors.New("signature is invalid")
|
||||||
}
|
}
|
||||||
return
|
return returnUrl, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func sameHost(handler http.Handler) http.Handler {
|
func sameHost(handler http.Handler) http.Handler {
|
||||||
|
@ -283,13 +289,17 @@ func parseCookie(data, secret string) (username string, groups string, err error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sso_payload takes the SSO secret and the two redirection URLs, stores the
|
||||||
|
// returnUrl in the nonce cache, and returns a partial URL querystring.
|
||||||
func sso_payload(secret string, return_sso_url string, returnUrl string) string {
|
func sso_payload(secret string, return_sso_url string, returnUrl string) string {
|
||||||
result := "return_sso_url=" + return_sso_url + returnUrl + "&nonce=" + addNonce(returnUrl)
|
result := "return_sso_url=" + url.QueryEscape(return_sso_url) + url.QueryEscape(returnUrl) + "&nonce=" + url.QueryEscape(addNonce(returnUrl))
|
||||||
payload := base64.StdEncoding.EncodeToString([]byte(result))
|
payload := base64.StdEncoding.EncodeToString([]byte(result))
|
||||||
|
|
||||||
|
// payload, computeHMAC already query-safe
|
||||||
return "sso=" + payload + "&sig=" + computeHMAC(payload, secret)
|
return "sso=" + payload + "&sig=" + computeHMAC(payload, secret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addNonce takes a return URL and returns a nonce associated to that URL.
|
||||||
func addNonce(returnUrl string) string {
|
func addNonce(returnUrl string) string {
|
||||||
guid := uuid.New()
|
guid := uuid.New()
|
||||||
nonceMutex.Lock()
|
nonceMutex.Lock()
|
||||||
|
@ -298,6 +308,7 @@ func addNonce(returnUrl string) string {
|
||||||
return guid
|
return guid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// computeHMAC implements the Discourse SSO protocol, returning a hex string.
|
||||||
func computeHMAC(message string, secret string) string {
|
func computeHMAC(message string, secret string) string {
|
||||||
key := []byte(secret)
|
key := []byte(secret)
|
||||||
h := hmac.New(sha256.New, key)
|
h := hmac.New(sha256.New, key)
|
||||||
|
|
Loading…
Reference in New Issue