267 lines
6.9 KiB
Go
267 lines
6.9 KiB
Go
package jose
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/sha256"
|
|
"crypto/sha512"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"hash"
|
|
"math/big"
|
|
"strings"
|
|
)
|
|
|
|
// JWS

// JwsHeader holds the JOSE header parameters this package understands.
// In a JWS produced by Sign, every one of these fields is serialized into
// the protected header.
type JwsHeader struct {
	// Algorithm is the "alg" parameter, e.g. "RS256". prepareInput, Sign and
	// Verify read its first character as the signature scheme and the
	// characters from index 2 onward as the hash size.
	Algorithm JoseAlgorithm `json:"alg,omitempty"`

	// Nonce is the "nonce" parameter.
	// NOTE(review): presumably an anti-replay nonce issued by the verifying
	// party — confirm against callers.
	Nonce string `json:"nonce,omitempty"`

	// Key is the "jwk" parameter: the signer's public key. Sign fills it in
	// from the private key, and Verify checks the signature against it.
	// NOTE: encoding/json ignores omitempty on struct-typed fields, so "jwk"
	// is always emitted when this header is marshaled.
	Key JsonWebKey `json:"jwk,omitempty"`
}
|
|
|
|
// rawJsonWebSignature and JsonWebSignature have identical fields.
// JsonWebSignature defines its own UnmarshalJSON, which decodes into
// rawJsonWebSignature. Because rawJsonWebSignature is a distinct defined type,
// it does not inherit that method. json.Unmarshal therefore performs only the
// basic field-by-field parse without recursing, and
// JsonWebSignature.UnmarshalJSON then completes the full parse.
type rawJsonWebSignature struct {
	// signed marks a JWS as carrying a signature. Sign and UnmarshalCompact
	// set it, and MarshalCompact refuses to serialize without it. The field
	// is unexported, so encoding/json never reads or writes it.
	signed bool

	// Header is the "header" member, i.e. the unprotected header in the JSON
	// serialization. Values from Protected are decoded on top of it.
	Header JwsHeader `json:"header,omitempty"`

	// Protected holds the protected header's JSON bytes. UnmarshalJSON and
	// UnmarshalCompact decode it into Header.
	Protected JsonBuffer `json:"protected,omitempty"`

	// Payload holds the payload bytes.
	Payload JsonBuffer `json:"payload,omitempty"`

	// Signature holds the signature over the protected header and payload.
	Signature JsonBuffer `json:"signature,omitempty"`
}

// JsonWebSignature is a JSON Web Signature. It shares rawJsonWebSignature's
// fields and adds the parsing, serialization, signing and verification methods.
// NOTE(review): JsonBuffer presumably (un)marshals as base64url text — confirm.
type JsonWebSignature rawJsonWebSignature

// No special MarshalJSON handling is needed. Sign copies every header field
// into the protected header, so marshaling emits those fields under both
// "header" and "protected". That duplication is harmless, because on decode
// UnmarshalJSON lets the protected values overwrite the unprotected ones.
// func (jws JsonWebSignature) MarshalJSON() ([]byte, error) {}
|
|
|
|
// On unmarshal, copy protected header fields to protected
|
|
func (jws *JsonWebSignature) UnmarshalJSON(data []byte) error {
|
|
var raw rawJsonWebSignature
|
|
err := json.Unmarshal(data, &raw)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Copy over simple fields
|
|
jws.Header = raw.Header
|
|
jws.Protected = raw.Protected
|
|
jws.Payload = raw.Payload
|
|
jws.Signature = raw.Signature
|
|
|
|
if len(jws.Protected) > 0 {
|
|
// This overwrites fields in jwk.Header if there is a conflict
|
|
err = json.Unmarshal(jws.Protected, &jws.Header)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Check that required fields are present
|
|
if len(jws.Signature) == 0 || len(jws.Payload) == 0 {
|
|
return errors.New("JWS missing required fields")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (jws JsonWebSignature) MarshalCompact() ([]byte, error) {
|
|
if !jws.signed {
|
|
return []byte{}, errors.New("Cannot marshal unsigned JWS")
|
|
}
|
|
|
|
return []byte(B64enc(jws.Protected) + "." + B64enc(jws.Payload) + "." + B64enc(jws.Signature)), nil
|
|
}
|
|
|
|
func UnmarshalCompact(data []byte) (JsonWebSignature, error) {
|
|
jws := JsonWebSignature{}
|
|
parts := strings.Split(string(data), ".")
|
|
if len(parts) != 3 {
|
|
return jws, errors.New("Mal-formed compact JWS")
|
|
}
|
|
|
|
// Decode simple fields
|
|
var err error
|
|
jws.Protected, err = B64dec(parts[0])
|
|
if err != nil {
|
|
return jws, err
|
|
}
|
|
jws.Payload, err = B64dec(parts[1])
|
|
if err != nil {
|
|
return jws, err
|
|
}
|
|
jws.Signature, err = B64dec(parts[2])
|
|
if err != nil {
|
|
return jws, err
|
|
}
|
|
|
|
// Populate header from protected
|
|
err = json.Unmarshal(jws.Protected, &jws.Header)
|
|
if err != nil {
|
|
return jws, err
|
|
}
|
|
|
|
jws.signed = true
|
|
return jws, nil
|
|
}
|
|
|
|
func prepareInput(jws JsonWebSignature) (crypto.Hash, []byte, error) {
|
|
input := []byte(B64enc(jws.Protected) + "." + B64enc(jws.Payload))
|
|
zeroh := crypto.Hash(0)
|
|
zerob := []byte{}
|
|
|
|
// TODO: Check for valid algorithm
|
|
|
|
// Hash the payload
|
|
hashAlg := string(jws.Header.Algorithm[2:])
|
|
var hashID crypto.Hash
|
|
var hash hash.Hash
|
|
switch hashAlg {
|
|
case "256":
|
|
hashID = crypto.SHA256
|
|
hash = sha256.New()
|
|
case "384":
|
|
hashID = crypto.SHA384
|
|
hash = sha512.New384()
|
|
case "512":
|
|
hashID = crypto.SHA512
|
|
hash = sha512.New()
|
|
default:
|
|
return zeroh, zerob, errors.New("Invalid hash length " + hashAlg)
|
|
}
|
|
hash.Write(input)
|
|
inputHash := hash.Sum(nil)
|
|
|
|
return hashID, inputHash, nil
|
|
}
|
|
|
|
func Sign(alg JoseAlgorithm, privateKey interface{}, payload []byte) (JsonWebSignature, error) {
|
|
zero := JsonWebSignature{}
|
|
|
|
// Create a working JWS
|
|
jws := JsonWebSignature{Payload: payload}
|
|
jws.Header.Algorithm = alg
|
|
|
|
// Cast the private key to the appropriate type, and
|
|
// add the corresponding public key to the header
|
|
var rsaPriv *rsa.PrivateKey
|
|
var ecPriv *ecdsa.PrivateKey
|
|
switch privateKey := privateKey.(type) {
|
|
case rsa.PrivateKey:
|
|
rsaPriv = &privateKey
|
|
jws.Header.Key = JsonWebKey{KeyType: KeyTypeRSA, Rsa: &rsaPriv.PublicKey}
|
|
case ecdsa.PrivateKey:
|
|
ecPriv = &privateKey
|
|
jws.Header.Key = JsonWebKey{KeyType: KeyTypeEC, Ec: &ecPriv.PublicKey}
|
|
default:
|
|
return zero, errors.New(fmt.Sprintf("Unsupported key type for %+v\n", privateKey))
|
|
}
|
|
|
|
// Base64-encode the header -> protected
|
|
// NOTE: This implies that unprotected headers are not supported
|
|
protected, err := json.Marshal(jws.Header)
|
|
if err != nil {
|
|
return zero, err
|
|
}
|
|
jws.Protected = protected
|
|
|
|
// Compute the signature input
|
|
hashID, inputHash, err := prepareInput(jws)
|
|
if err != nil {
|
|
return zero, err
|
|
}
|
|
|
|
// Sign
|
|
// TODO: Check that key type is compatible
|
|
var sig []byte
|
|
switch jws.Header.Algorithm[:1] {
|
|
case "R":
|
|
if rsaPriv == nil {
|
|
return zero, errors.New(fmt.Sprintf("Algorithm %s requres RSA private key", jws.Header.Algorithm))
|
|
}
|
|
sig, err = rsa.SignPKCS1v15(rand.Reader, rsaPriv, hashID, inputHash)
|
|
case "P":
|
|
if rsaPriv == nil {
|
|
return zero, errors.New(fmt.Sprintf("Algorithm %s requres RSA private key", jws.Header.Algorithm))
|
|
}
|
|
// Contrary to docs, you can't pass a nil instead of the PSSOptions; You'll
|
|
// get a nil dereference.
|
|
sig, err = rsa.SignPSS(rand.Reader, rsaPriv, hashID, inputHash, &rsa.PSSOptions{})
|
|
case "E":
|
|
if ecPriv == nil {
|
|
return zero, errors.New(fmt.Sprintf("Algorithm %s requres EC private key", jws.Header.Algorithm))
|
|
}
|
|
r, s, err := ecdsa.Sign(rand.Reader, ecPriv, inputHash)
|
|
if err == nil {
|
|
sig = concatRS(r, s)
|
|
}
|
|
default:
|
|
return zero, errors.New("Invalid signature algorithm " + string(jws.Header.Algorithm[:1]))
|
|
}
|
|
|
|
if err != nil {
|
|
return zero, err
|
|
}
|
|
jws.Signature = sig
|
|
jws.signed = true
|
|
|
|
return jws, nil
|
|
}
|
|
|
|
// concatRS serializes an ECDSA signature (r, s) as the big-endian bytes of r
// followed by the big-endian bytes of s. The shorter value is left-padded
// with zero bytes so that both halves have the same length.
//
// NOTE(review): each half is only as wide as the longer of r and s, not the
// fixed curve-size width RFC 7518 §3.4 requires for JWS ECDSA signatures.
func concatRS(r, s *big.Int) []byte {
	rBytes := r.Bytes()
	sBytes := s.Bytes()

	width := len(rBytes)
	if len(sBytes) > width {
		width = len(sBytes)
	}

	// Right-align each value within its half; make() supplies the zero padding.
	out := make([]byte, 2*width)
	copy(out[width-len(rBytes):width], rBytes)
	copy(out[2*width-len(sBytes):], sBytes)
	return out
}
|
|
|
|
func (jws *JsonWebSignature) Verify() error {
|
|
hashID, inputHash, err := prepareInput(*jws)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sig := jws.Signature
|
|
|
|
// Check the signature, branching from the first character in the alg value
|
|
// For example: "RS256" => "R" => PKCS1v15
|
|
switch jws.Header.Algorithm[:1] {
|
|
case "R":
|
|
if jws.Header.Key.Rsa == nil {
|
|
return errors.New(fmt.Sprintf("Algorithm %s requires RSA key", jws.Header.Algorithm))
|
|
}
|
|
return rsa.VerifyPKCS1v15(jws.Header.Key.Rsa, hashID, inputHash, sig)
|
|
case "P":
|
|
if jws.Header.Key.Rsa == nil {
|
|
return errors.New(fmt.Sprintf("Algorithm %s requires RSA key", jws.Header.Algorithm))
|
|
}
|
|
return rsa.VerifyPSS(jws.Header.Key.Rsa, hashID, inputHash, sig, nil)
|
|
case "E":
|
|
if jws.Header.Key.Ec == nil {
|
|
return errors.New(fmt.Sprintf("Algorithm %s requires EC key", jws.Header.Algorithm))
|
|
}
|
|
intlen := len(sig) / 2
|
|
rBytes, sBytes := sig[:intlen], sig[intlen:]
|
|
r, s := big.NewInt(0), big.NewInt(0)
|
|
r.SetBytes(rBytes)
|
|
s.SetBytes(sBytes)
|
|
if ecdsa.Verify(jws.Header.Key.Ec, inputHash, r, s) {
|
|
return nil
|
|
} else {
|
|
return errors.New("ECDSA signature validation failed")
|
|
}
|
|
default:
|
|
return errors.New("Invalid signature algorithm " + string(jws.Header.Algorithm[:1]))
|
|
}
|
|
}
|