pkcs11helper: add a Session abstraction (#4989)

This commit is contained in:
Jacob Hoffman-Andrews 2020-07-29 12:38:45 -07:00 committed by GitHub
parent 09c060f3de
commit 0834ca4a19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 199 additions and 154 deletions

View File

@ -317,9 +317,7 @@ func (fr *failReader) Read([]byte) (int, error) {
// PKCS#11 ECDSA signature format and the RFC 5480 one which is required // PKCS#11 ECDSA signature format and the RFC 5480 one which is required
// for X.509 certificates // for X.509 certificates
type x509Signer struct { type x509Signer struct {
ctx pkcs11helpers.PKCtx session *pkcs11helpers.Session
session pkcs11.SessionHandle
objectHandle pkcs11.ObjectHandle objectHandle pkcs11.ObjectHandle
keyType pkcs11helpers.KeyType keyType pkcs11helpers.KeyType
@ -330,7 +328,7 @@ type x509Signer struct {
// is converted from the PKCS#11 format to the RFC 5480 format. For RSA keys a // is converted from the PKCS#11 format to the RFC 5480 format. For RSA keys a
// conversion step is not needed. // conversion step is not needed.
func (p *x509Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { func (p *x509Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
signature, err := pkcs11helpers.Sign(p.ctx, p.session, p.objectHandle, p.keyType, digest, opts.HashFunc()) signature, err := p.session.Sign(p.objectHandle, p.keyType, digest, opts.HashFunc())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -359,10 +357,10 @@ func (p *x509Signer) Public() crypto.PublicKey {
// having the actual public key object in order to retrieve the private key // having the actual public key object in order to retrieve the private key
// handle. This is because we already have the key pair object ID, and as such // handle. This is because we already have the key pair object ID, and as such
// do not need to query the HSM to retrieve it. // do not need to query the HSM to retrieve it.
func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label string, id []byte) (crypto.Signer, error) { func newSigner(session *pkcs11helpers.Session, label string, id []byte) (crypto.Signer, error) {
// Retrieve the private key handle that will later be used for the certificate // Retrieve the private key handle that will later be used for the certificate
// signing operation // signing operation
privateHandle, err := pkcs11helpers.FindObject(ctx, session, []*pkcs11.Attribute{ privateHandle, err := session.FindObject([]*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PRIVATE_KEY), pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PRIVATE_KEY),
pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), pkcs11.NewAttribute(pkcs11.CKA_LABEL, label),
pkcs11.NewAttribute(pkcs11.CKA_ID, id), pkcs11.NewAttribute(pkcs11.CKA_ID, id),
@ -370,7 +368,7 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve private key handle: %s", err) return nil, fmt.Errorf("failed to retrieve private key handle: %s", err)
} }
attrs, err := ctx.GetAttributeValue(session, privateHandle, []*pkcs11.Attribute{ attrs, err := session.GetAttributeValue(privateHandle, []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil)}, pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil)},
) )
if err != nil { if err != nil {
@ -382,7 +380,7 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri
// Retrieve the public key handle with the same CKA_ID as the private key // Retrieve the public key handle with the same CKA_ID as the private key
// and construct a {rsa,ecdsa}.PublicKey for use in x509.CreateCertificate // and construct a {rsa,ecdsa}.PublicKey for use in x509.CreateCertificate
pubHandle, err := pkcs11helpers.FindObject(ctx, session, []*pkcs11.Attribute{ pubHandle, err := session.FindObject([]*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY),
pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), pkcs11.NewAttribute(pkcs11.CKA_LABEL, label),
pkcs11.NewAttribute(pkcs11.CKA_ID, id), pkcs11.NewAttribute(pkcs11.CKA_ID, id),
@ -397,14 +395,14 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri
// 0x00000000, CKK_RSA // 0x00000000, CKK_RSA
case bytes.Equal(attrs[0].Value, []byte{0, 0, 0, 0, 0, 0, 0, 0}): case bytes.Equal(attrs[0].Value, []byte{0, 0, 0, 0, 0, 0, 0, 0}):
keyType = pkcs11helpers.RSAKey keyType = pkcs11helpers.RSAKey
pub, err = pkcs11helpers.GetRSAPublicKey(ctx, session, pubHandle) pub, err = session.GetRSAPublicKey(pubHandle)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve public key: %s", err) return nil, fmt.Errorf("failed to retrieve public key: %s", err)
} }
// 0x00000003, CKK_ECDSA // 0x00000003, CKK_ECDSA
case bytes.Equal(attrs[0].Value, []byte{3, 0, 0, 0, 0, 0, 0, 0}): case bytes.Equal(attrs[0].Value, []byte{3, 0, 0, 0, 0, 0, 0, 0}):
keyType = pkcs11helpers.ECDSAKey keyType = pkcs11helpers.ECDSAKey
pub, err = pkcs11helpers.GetECDSAPublicKey(ctx, session, pubHandle) pub, err = session.GetECDSAPublicKey(pubHandle)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve public key: %s", err) return nil, fmt.Errorf("failed to retrieve public key: %s", err)
} }
@ -413,7 +411,6 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri
} }
return &x509Signer{ return &x509Signer{
ctx: ctx,
session: session, session: session,
objectHandle: privateHandle, objectHandle: privateHandle,
keyType: keyType, keyType: keyType,

View File

@ -20,7 +20,7 @@ import (
) )
func TestX509Signer(t *testing.T) { func TestX509Signer(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
// test that x509Signer.Sign properly converts the PKCS#11 format signature to // test that x509Signer.Sign properly converts the PKCS#11 format signature to
// the RFC 5480 format signature // the RFC 5480 format signature
@ -51,7 +51,7 @@ func TestX509Signer(t *testing.T) {
return append(rBytes, sBytes...), nil return append(rBytes, sBytes...), nil
} }
digest := sha256.Sum256([]byte("hello")) digest := sha256.Sum256([]byte("hello"))
signer := &x509Signer{ctx: ctx, keyType: pkcs11helpers.ECDSAKey, pub: tk.Public()} signer := &x509Signer{session: s, keyType: pkcs11helpers.ECDSAKey, pub: tk.Public()}
signature, err := signer.Sign(nil, digest[:], crypto.SHA256) signature, err := signer.Sign(nil, digest[:], crypto.SHA256)
test.AssertNotError(t, err, "x509Signer.Sign failed") test.AssertNotError(t, err, "x509Signer.Sign failed")
@ -78,9 +78,9 @@ func TestParseOID(t *testing.T) {
} }
func TestMakeTemplate(t *testing.T) { func TestMakeTemplate(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
profile := &certProfile{} profile := &certProfile{}
randReader := newRandReader(&ctx, 0) randReader := newRandReader(s)
pubKey, err := hex.DecodeString("3059301306072a8648ce3d020106082a8648ce3d03010703420004b06745ef0375c9c54057098f077964e18d3bed0aacd54545b16eab8c539b5768cc1cea93ba56af1e22a7a01c33048c8885ed17c9c55ede70649b707072689f5e") pubKey, err := hex.DecodeString("3059301306072a8648ce3d020106082a8648ce3d03010703420004b06745ef0375c9c54057098f077964e18d3bed0aacd54545b16eab8c539b5768cc1cea93ba56af1e22a7a01c33048c8885ed17c9c55ede70649b707072689f5e")
test.AssertNotError(t, err, "failed to decode test public key") test.AssertNotError(t, err, "failed to decode test public key")
@ -157,14 +157,13 @@ func TestMakeTemplate(t *testing.T) {
} }
func TestMakeTemplateOCSP(t *testing.T) { func TestMakeTemplateOCSP(t *testing.T) {
ctx := pkcs11helpers.MockCtx{ s, ctx := pkcs11helpers.NewSessionWithMock()
GenerateRandomFunc: func(_ pkcs11.SessionHandle, length int) ([]byte, error) { ctx.GenerateRandomFunc = func(_ pkcs11.SessionHandle, length int) ([]byte, error) {
r := make([]byte, length) r := make([]byte, length)
_, err := rand.Read(r) _, err := rand.Read(r)
return r, err return r, err
},
} }
randReader := newRandReader(&ctx, 0) randReader := newRandReader(s)
profile := &certProfile{ profile := &certProfile{
SignatureAlgorithm: "SHA256WithRSA", SignatureAlgorithm: "SHA256WithRSA",
CommonName: "common name", CommonName: "common name",
@ -206,14 +205,13 @@ func TestMakeTemplateOCSP(t *testing.T) {
} }
func TestMakeTemplateCRL(t *testing.T) { func TestMakeTemplateCRL(t *testing.T) {
ctx := pkcs11helpers.MockCtx{ s, ctx := pkcs11helpers.NewSessionWithMock()
GenerateRandomFunc: func(_ pkcs11.SessionHandle, length int) ([]byte, error) { ctx.GenerateRandomFunc = func(_ pkcs11.SessionHandle, length int) ([]byte, error) {
r := make([]byte, length) r := make([]byte, length)
_, err := rand.Read(r) _, err := rand.Read(r)
return r, err return r, err
},
} }
randReader := newRandReader(&ctx, 0) randReader := newRandReader(s)
profile := &certProfile{ profile := &certProfile{
SignatureAlgorithm: "SHA256WithRSA", SignatureAlgorithm: "SHA256WithRSA",
CommonName: "common name", CommonName: "common name",
@ -462,13 +460,13 @@ func TestVerifyProfile(t *testing.T) {
} }
func TestGetKey(t *testing.T) { func TestGetKey(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
// test newSigner fails when pkcs11helpers.FindObject for private key handle fails // test newSigner fails when pkcs11helpers.FindObject for private key handle fails
ctx.FindObjectsInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Attribute) error { ctx.FindObjectsInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Attribute) error {
return errors.New("broken") return errors.New("broken")
} }
_, err := newSigner(ctx, 0, "label", []byte{255, 255}) _, err := newSigner(s, "label", []byte{255, 255})
test.AssertError(t, err, "newSigner didn't fail when pkcs11helpers.FindObject for private key handle failed") test.AssertError(t, err, "newSigner didn't fail when pkcs11helpers.FindObject for private key handle failed")
// test newSigner fails when GetAttributeValue fails // test newSigner fails when GetAttributeValue fails
@ -484,14 +482,14 @@ func TestGetKey(t *testing.T) {
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return nil, errors.New("broken") return nil, errors.New("broken")
} }
_, err = newSigner(ctx, 0, "label", []byte{255, 255}) _, err = newSigner(s, "label", []byte{255, 255})
test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key type failed") test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key type failed")
// test newSigner fails when GetAttributeValue returns no attributes // test newSigner fails when GetAttributeValue returns no attributes
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return nil, nil return nil, nil
} }
_, err = newSigner(ctx, 0, "label", []byte{255, 255}) _, err = newSigner(s, "label", []byte{255, 255})
test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key type returned no attributes") test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key type returned no attributes")
// test newSigner fails when pkcs11helpers.FindObject for public key handle fails // test newSigner fails when pkcs11helpers.FindObject for public key handle fails
@ -504,7 +502,7 @@ func TestGetKey(t *testing.T) {
} }
return nil return nil
} }
_, err = newSigner(ctx, 0, "label", []byte{255, 255}) _, err = newSigner(s, "label", []byte{255, 255})
test.AssertError(t, err, "newSigner didn't fail when pkcs11helpers.FindObject for public key handle failed") test.AssertError(t, err, "newSigner didn't fail when pkcs11helpers.FindObject for public key handle failed")
// test newSigner fails when pkcs11helpers.FindObject for private key returns unknown CKA_KEY_TYPE // test newSigner fails when pkcs11helpers.FindObject for private key returns unknown CKA_KEY_TYPE
@ -514,21 +512,21 @@ func TestGetKey(t *testing.T) {
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{2, 0, 0, 0, 0, 0, 0, 0})}, nil return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{2, 0, 0, 0, 0, 0, 0, 0})}, nil
} }
_, err = newSigner(ctx, 0, "label", []byte{255, 255}) _, err = newSigner(s, "label", []byte{255, 255})
test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key returned unknown key type") test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key returned unknown key type")
// test newSigner fails when GetRSAPublicKey fails // test newSigner fails when GetRSAPublicKey fails
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{0, 0, 0, 0, 0, 0, 0, 0})}, nil return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{0, 0, 0, 0, 0, 0, 0, 0})}, nil
} }
_, err = newSigner(ctx, 0, "label", []byte{255, 255}) _, err = newSigner(s, "label", []byte{255, 255})
test.AssertError(t, err, "newSigner didn't fail when GetRSAPublicKey fails") test.AssertError(t, err, "newSigner didn't fail when GetRSAPublicKey fails")
// test newSigner fails when GetECDSAPublicKey fails // test newSigner fails when GetECDSAPublicKey fails
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{3, 0, 0, 0, 0, 0, 0, 0})}, nil return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{3, 0, 0, 0, 0, 0, 0, 0})}, nil
} }
_, err = newSigner(ctx, 0, "label", []byte{255, 255}) _, err = newSigner(s, "label", []byte{255, 255})
test.AssertError(t, err, "newSigner didn't fail when GetECDSAPublicKey fails") test.AssertError(t, err, "newSigner didn't fail when GetECDSAPublicKey fails")
// test newSigner works when everything... works // test newSigner works when everything... works
@ -548,6 +546,6 @@ func TestGetKey(t *testing.T) {
} }
return returns, nil return returns, nil
} }
_, err = newSigner(ctx, 0, "label", []byte{255, 255}) _, err = newSigner(s, "label", []byte{255, 255})
test.AssertNotError(t, err, "newSigner failed when everything worked properly") test.AssertNotError(t, err, "newSigner failed when everything worked properly")
} }

View File

@ -76,12 +76,11 @@ func ecArgs(label string, curve elliptic.Curve, keyID []byte) generateArgs {
// handle, and constructs an ecdsa.PublicKey. It also checks that the key is of // handle, and constructs an ecdsa.PublicKey. It also checks that the key is of
// the correct curve type. // the correct curve type.
func ecPub( func ecPub(
ctx pkcs11helpers.PKCtx, session *pkcs11helpers.Session,
session pkcs11.SessionHandle,
object pkcs11.ObjectHandle, object pkcs11.ObjectHandle,
expectedCurve elliptic.Curve, expectedCurve elliptic.Curve,
) (*ecdsa.PublicKey, error) { ) (*ecdsa.PublicKey, error) {
pubKey, err := pkcs11helpers.GetECDSAPublicKey(ctx, session, object) pubKey, err := session.GetECDSAPublicKey(object)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,9 +96,9 @@ func ecPub(
// private key on the device, specified by the provided object handle, by signing // private key on the device, specified by the provided object handle, by signing
// a nonce generated on the device and verifying the returned signature using the // a nonce generated on the device and verifying the returned signature using the
// public key. // public key.
func ecVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, pub *ecdsa.PublicKey) error { func ecVerify(session *pkcs11helpers.Session, object pkcs11.ObjectHandle, pub *ecdsa.PublicKey) error {
nonce := make([]byte, 4) nonce := make([]byte, 4)
_, err := newRandReader(ctx, session).Read(nonce) _, err := newRandReader(session).Read(nonce)
if err != nil { if err != nil {
return fmt.Errorf("failed to construct nonce: %s", err) return fmt.Errorf("failed to construct nonce: %s", err)
} }
@ -108,7 +107,7 @@ func ecVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs
hashFunc.Write(nonce) hashFunc.Write(nonce)
digest := hashFunc.Sum(nil) digest := hashFunc.Sum(nil)
log.Printf("\tMessage %s hash: %X\n", hashToString[curveToHash[pub.Curve]], digest) log.Printf("\tMessage %s hash: %X\n", hashToString[curveToHash[pub.Curve]], digest)
signature, err := pkcs11helpers.Sign(ctx, session, object, pkcs11helpers.ECDSAKey, digest, curveToHash[pub.Curve]) signature, err := session.Sign(object, pkcs11helpers.ECDSAKey, digest, curveToHash[pub.Curve])
if err != nil { if err != nil {
return err return err
} }
@ -126,31 +125,31 @@ func ecVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs
// specified by curveStr and with the provided label. It returns the public // specified by curveStr and with the provided label. It returns the public
// part of the generated key pair as a ecdsa.PublicKey and the random key ID // part of the generated key pair as a ecdsa.PublicKey and the random key ID
// that the HSM uses to identify the key pair. // that the HSM uses to identify the key pair.
func ecGenerate(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label, curveStr string) (*ecdsa.PublicKey, []byte, error) { func ecGenerate(session *pkcs11helpers.Session, label, curveStr string) (*ecdsa.PublicKey, []byte, error) {
curve, present := stringToCurve[curveStr] curve, present := stringToCurve[curveStr]
if !present { if !present {
return nil, nil, fmt.Errorf("curve %q not supported", curveStr) return nil, nil, fmt.Errorf("curve %q not supported", curveStr)
} }
keyID := make([]byte, 4) keyID := make([]byte, 4)
_, err := newRandReader(ctx, session).Read(keyID) _, err := newRandReader(session).Read(keyID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
log.Printf("Generating ECDSA key with curve %s and ID %x\n", curveStr, keyID) log.Printf("Generating ECDSA key with curve %s and ID %x\n", curveStr, keyID)
args := ecArgs(label, curve, keyID) args := ecArgs(label, curve, keyID)
pub, priv, err := ctx.GenerateKeyPair(session, args.mechanism, args.publicAttrs, args.privateAttrs) pub, priv, err := session.GenerateKeyPair(args.mechanism, args.publicAttrs, args.privateAttrs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
log.Println("Key generated") log.Println("Key generated")
log.Println("Extracting public key") log.Println("Extracting public key")
pk, err := ecPub(ctx, session, pub, curve) pk, err := ecPub(session, pub, curve)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
log.Println("Extracted public key") log.Println("Extracted public key")
log.Println("Verifying public key") log.Println("Verifying public key")
err = ecVerify(ctx, session, priv, pk) err = ecVerify(session, priv, pk)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -13,13 +13,13 @@ import (
) )
func TestECPub(t *testing.T) { func TestECPub(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
// test we fail when pkcs11helpers.GetECDSAPublicKey fails // test we fail when pkcs11helpers.GetECDSAPublicKey fails
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return nil, errors.New("bad!") return nil, errors.New("bad!")
} }
_, err := ecPub(ctx, 0, 0, elliptic.P256()) _, err := ecPub(s, 0, elliptic.P256())
test.AssertError(t, err, "ecPub didn't fail with non-matching curve") test.AssertError(t, err, "ecPub didn't fail with non-matching curve")
test.AssertEquals(t, err.Error(), "Failed to retrieve key attributes: bad!") test.AssertEquals(t, err.Error(), "Failed to retrieve key attributes: bad!")
@ -30,18 +30,17 @@ func TestECPub(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}),
}, nil }, nil
} }
_, err = ecPub(ctx, 0, 0, elliptic.P256()) _, err = ecPub(s, 0, elliptic.P256())
test.AssertError(t, err, "ecPub didn't fail with non-matching curve") test.AssertError(t, err, "ecPub didn't fail with non-matching curve")
} }
func TestECVerify(t *testing.T) { func TestECVerify(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
// test GenerateRandom failing // test GenerateRandom failing
ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
err := ecVerify(ctx, 0, 0, nil) err := ecVerify(s, 0, nil)
test.AssertError(t, err, "ecVerify didn't fail on GenerateRandom error") test.AssertError(t, err, "ecVerify didn't fail on GenerateRandom error")
// test SignInit failing // test SignInit failing
@ -51,7 +50,7 @@ func TestECVerify(t *testing.T) {
ctx.SignInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, pkcs11.ObjectHandle) error { ctx.SignInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, pkcs11.ObjectHandle) error {
return errors.New("yup") return errors.New("yup")
} }
err = ecVerify(ctx, 0, 0, &ecdsa.PublicKey{Curve: elliptic.P256()}) err = ecVerify(s, 0, &ecdsa.PublicKey{Curve: elliptic.P256()})
test.AssertError(t, err, "ecVerify didn't fail on SignInit error") test.AssertError(t, err, "ecVerify didn't fail on SignInit error")
// test Sign failing // test Sign failing
@ -61,7 +60,7 @@ func TestECVerify(t *testing.T) {
ctx.SignFunc = func(pkcs11.SessionHandle, []byte) ([]byte, error) { ctx.SignFunc = func(pkcs11.SessionHandle, []byte) ([]byte, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
err = ecVerify(ctx, 0, 0, &ecdsa.PublicKey{Curve: elliptic.P256()}) err = ecVerify(s, 0, &ecdsa.PublicKey{Curve: elliptic.P256()})
test.AssertError(t, err, "ecVerify didn't fail on Sign error") test.AssertError(t, err, "ecVerify didn't fail on Sign error")
// test signature verification failing // test signature verification failing
@ -70,19 +69,20 @@ func TestECVerify(t *testing.T) {
} }
tk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) tk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
test.AssertNotError(t, err, "ecdsa.GenerateKey failed") test.AssertNotError(t, err, "ecdsa.GenerateKey failed")
err = ecVerify(ctx, 0, 0, &tk.PublicKey) err = ecVerify(s, 0, &tk.PublicKey)
test.AssertError(t, err, "ecVerify didn't fail on signature verification error") test.AssertError(t, err, "ecVerify didn't fail on signature verification error")
// test we don't fail with valid signature // test we don't fail with valid signature
ctx.SignFunc = func(_ pkcs11.SessionHandle, msg []byte) ([]byte, error) { ctx.SignFunc = func(_ pkcs11.SessionHandle, msg []byte) ([]byte, error) {
return ecPKCS11Sign(tk, msg) return ecPKCS11Sign(tk, msg)
} }
err = ecVerify(ctx, 0, 0, &tk.PublicKey) err = ecVerify(s, 0, &tk.PublicKey)
test.AssertNotError(t, err, "ecVerify failed with a valid signature") test.AssertNotError(t, err, "ecVerify failed with a valid signature")
} }
func TestECGenerate(t *testing.T) { func TestECGenerate(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} ctx := pkcs11helpers.MockCtx{}
s := &pkcs11helpers.Session{Module: &ctx, Session: 0}
ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) {
return []byte{1, 2, 3}, nil return []byte{1, 2, 3}, nil
} }
@ -90,14 +90,14 @@ func TestECGenerate(t *testing.T) {
test.AssertNotError(t, err, "Failed to generate a ECDSA test key") test.AssertNotError(t, err, "Failed to generate a ECDSA test key")
// Test ecGenerate fails with unknown curve // Test ecGenerate fails with unknown curve
_, _, err = ecGenerate(ctx, 0, "", "bad-curve") _, _, err = ecGenerate(s, "", "bad-curve")
test.AssertError(t, err, "ecGenerate accepted unknown curve") test.AssertError(t, err, "ecGenerate accepted unknown curve")
// Test ecGenerate fails when GenerateKeyPair fails // Test ecGenerate fails when GenerateKeyPair fails
ctx.GenerateKeyPairFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) { ctx.GenerateKeyPairFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) {
return 0, 0, errors.New("bad") return 0, 0, errors.New("bad")
} }
_, _, err = ecGenerate(ctx, 0, "", "P-256") _, _, err = ecGenerate(s, "", "P-256")
test.AssertError(t, err, "ecGenerate didn't fail on GenerateKeyPair error") test.AssertError(t, err, "ecGenerate didn't fail on GenerateKeyPair error")
// Test ecGenerate fails when ecPub fails // Test ecGenerate fails when ecPub fails
@ -107,7 +107,7 @@ func TestECGenerate(t *testing.T) {
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return nil, errors.New("bad") return nil, errors.New("bad")
} }
_, _, err = ecGenerate(ctx, 0, "", "P-256") _, _, err = ecGenerate(s, "", "P-256")
test.AssertError(t, err, "ecGenerate didn't fail on ecPub error") test.AssertError(t, err, "ecGenerate didn't fail on ecPub error")
// Test ecGenerate fails when ecVerify fails // Test ecGenerate fails when ecVerify fails
@ -120,7 +120,7 @@ func TestECGenerate(t *testing.T) {
ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
_, _, err = ecGenerate(ctx, 0, "", "P-256") _, _, err = ecGenerate(s, "", "P-256")
test.AssertError(t, err, "ecGenerate didn't fail on ecVerify error") test.AssertError(t, err, "ecGenerate didn't fail on ecVerify error")
// Test ecGenerate doesn't fail when everything works // Test ecGenerate doesn't fail when everything works
@ -133,7 +133,7 @@ func TestECGenerate(t *testing.T) {
ctx.SignFunc = func(_ pkcs11.SessionHandle, msg []byte) ([]byte, error) { ctx.SignFunc = func(_ pkcs11.SessionHandle, msg []byte) ([]byte, error) {
return ecPKCS11Sign(priv, msg) return ecPKCS11Sign(priv, msg)
} }
_, _, err = ecGenerate(ctx, 0, "", "P-256") _, _, err = ecGenerate(s, "", "P-256")
test.AssertNotError(t, err, "ecGenerate didn't succeed when everything worked as expected") test.AssertNotError(t, err, "ecGenerate didn't succeed when everything worked as expected")
} }

View File

@ -13,19 +13,15 @@ import (
) )
type hsmRandReader struct { type hsmRandReader struct {
ctx pkcs11helpers.PKCtx *pkcs11helpers.Session
session pkcs11.SessionHandle
} }
func newRandReader(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle) *hsmRandReader { func newRandReader(session *pkcs11helpers.Session) *hsmRandReader {
return &hsmRandReader{ return &hsmRandReader{session}
ctx: ctx,
session: session,
}
} }
func (hrr hsmRandReader) Read(p []byte) (n int, err error) { func (hrr hsmRandReader) Read(p []byte) (n int, err error) {
r, err := hrr.ctx.GenerateRandom(hrr.session, len(p)) r, err := hrr.Module.GenerateRandom(hrr.Session.Session, len(p))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -53,8 +49,8 @@ type keyInfo struct {
id []byte id []byte
} }
func generateKey(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label string, outputPath string, config keyGenConfig) (*keyInfo, error) { func generateKey(session *pkcs11helpers.Session, label string, outputPath string, config keyGenConfig) (*keyInfo, error) {
_, err := pkcs11helpers.FindObject(ctx, session, []*pkcs11.Attribute{}) _, err := session.FindObject([]*pkcs11.Attribute{})
if err != pkcs11helpers.ErrNoObject { if err != pkcs11helpers.ErrNoObject {
return nil, fmt.Errorf("expected no objects in slot for key storage. got error: %s", err) return nil, fmt.Errorf("expected no objects in slot for key storage. got error: %s", err)
} }
@ -63,12 +59,12 @@ func generateKey(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label st
var keyID []byte var keyID []byte
switch config.Type { switch config.Type {
case "rsa": case "rsa":
pubKey, keyID, err = rsaGenerate(ctx, session, label, config.RSAModLength, rsaExp) pubKey, keyID, err = rsaGenerate(session, label, config.RSAModLength, rsaExp)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate RSA key pair: %s", err) return nil, fmt.Errorf("failed to generate RSA key pair: %s", err)
} }
case "ecdsa": case "ecdsa":
pubKey, keyID, err = ecGenerate(ctx, session, label, config.ECDSACurve) pubKey, keyID, err = ecGenerate(session, label, config.ECDSACurve)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate ECDSA key pair: %s", err) return nil, fmt.Errorf("failed to generate ECDSA key pair: %s", err)
} }

View File

@ -61,8 +61,9 @@ func TestGenerateKeyRSA(t *testing.T) {
// Chop of the hash identifier and feed back into rsa.SignPKCS1v15 // Chop of the hash identifier and feed back into rsa.SignPKCS1v15
return rsa.SignPKCS1v15(rand.Reader, rsaPriv, crypto.SHA256, msg[19:]) return rsa.SignPKCS1v15(rand.Reader, rsaPriv, crypto.SHA256, msg[19:])
} }
s := &pkcs11helpers.Session{Module: &ctx, Session: 0}
keyPath := path.Join(tmp, "test-rsa-key.pem") keyPath := path.Join(tmp, "test-rsa-key.pem")
keyInfo, err := generateKey(ctx, 0, "", keyPath, keyGenConfig{ keyInfo, err := generateKey(s, "", keyPath, keyGenConfig{
Type: "rsa", Type: "rsa",
RSAModLength: 1024, RSAModLength: 1024,
}) })
@ -93,7 +94,8 @@ func TestGenerateKeyEC(t *testing.T) {
return ecPKCS11Sign(ecPriv, msg) return ecPKCS11Sign(ecPriv, msg)
} }
keyPath := path.Join(tmp, "test-ecdsa-key.pem") keyPath := path.Join(tmp, "test-ecdsa-key.pem")
keyInfo, err := generateKey(ctx, 0, "", keyPath, keyGenConfig{ s := &pkcs11helpers.Session{Module: &ctx, Session: 0}
keyInfo, err := generateKey(s, "", keyPath, keyGenConfig{
Type: "ecdsa", Type: "ecdsa",
ECDSACurve: "P-256", ECDSACurve: "P-256",
}) })
@ -116,7 +118,8 @@ func TestGenerateKeySlotHasSomething(t *testing.T) {
return []pkcs11.ObjectHandle{1}, false, nil return []pkcs11.ObjectHandle{1}, false, nil
} }
keyPath := path.Join(tmp, "should-not-exist.pem") keyPath := path.Join(tmp, "should-not-exist.pem")
_, err = generateKey(ctx, 0, "", keyPath, keyGenConfig{ s := &pkcs11helpers.Session{Module: &ctx, Session: 0}
_, err = generateKey(s, "", keyPath, keyGenConfig{
Type: "ecdsa", Type: "ecdsa",
ECDSACurve: "P-256", ECDSACurve: "P-256",
}) })

View File

@ -342,7 +342,7 @@ func equalPubKeys(a, b interface{}) bool {
} }
func openSigner(cfg PKCS11SigningConfig, issuer *x509.Certificate) (crypto.Signer, *hsmRandReader, error) { func openSigner(cfg PKCS11SigningConfig, issuer *x509.Certificate) (crypto.Signer, *hsmRandReader, error) {
ctx, session, err := pkcs11helpers.Initialize(cfg.Module, cfg.SigningSlot, cfg.PIN) session, err := pkcs11helpers.Initialize(cfg.Module, cfg.SigningSlot, cfg.PIN)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", return nil, nil, fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s",
cfg.SigningSlot, err) cfg.SigningSlot, err)
@ -352,7 +352,7 @@ func openSigner(cfg PKCS11SigningConfig, issuer *x509.Certificate) (crypto.Signe
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to decode key-id: %s", err) return nil, nil, fmt.Errorf("failed to decode key-id: %s", err)
} }
signer, err := newSigner(ctx, session, cfg.SigningLabel, keyID) signer, err := newSigner(session, cfg.SigningLabel, keyID)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to retrieve private key handle: %s", err) return nil, nil, fmt.Errorf("failed to retrieve private key handle: %s", err)
} }
@ -360,7 +360,7 @@ func openSigner(cfg PKCS11SigningConfig, issuer *x509.Certificate) (crypto.Signe
return nil, nil, fmt.Errorf("signer pubkey did not match issuer pubkey") return nil, nil, fmt.Errorf("signer pubkey did not match issuer pubkey")
} }
log.Println("Retrieved private key handle") log.Println("Retrieved private key handle")
return signer, newRandReader(ctx, session), nil return signer, newRandReader(session), nil
} }
func signAndWriteCert(tbs, issuer *x509.Certificate, subjectPubKey crypto.PublicKey, signer crypto.Signer, certPath string) error { func signAndWriteCert(tbs, issuer *x509.Certificate, subjectPubKey crypto.PublicKey, signer crypto.Signer, certPath string) error {
@ -405,20 +405,20 @@ func rootCeremony(configBytes []byte) error {
if err := config.validate(); err != nil { if err := config.validate(); err != nil {
return fmt.Errorf("failed to validate config: %s", err) return fmt.Errorf("failed to validate config: %s", err)
} }
ctx, session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN) session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN)
if err != nil { if err != nil {
return fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", config.PKCS11.StoreSlot, err) return fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", config.PKCS11.StoreSlot, err)
} }
log.Printf("Opened PKCS#11 session for slot %d\n", config.PKCS11.StoreSlot) log.Printf("Opened PKCS#11 session for slot %d\n", config.PKCS11.StoreSlot)
keyInfo, err := generateKey(ctx, session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key) keyInfo, err := generateKey(session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key)
if err != nil { if err != nil {
return err return err
} }
signer, err := newSigner(ctx, session, config.PKCS11.StoreLabel, keyInfo.id) signer, err := newSigner(session, config.PKCS11.StoreLabel, keyInfo.id)
if err != nil { if err != nil {
return fmt.Errorf("failed to retrieve signer: %s", err) return fmt.Errorf("failed to retrieve signer: %s", err)
} }
template, err := makeTemplate(newRandReader(ctx, session), &config.CertProfile, keyInfo.der, rootCert) template, err := makeTemplate(newRandReader(session), &config.CertProfile, keyInfo.der, rootCert)
if err != nil { if err != nil {
return fmt.Errorf("failed to create certificate profile: %s", err) return fmt.Errorf("failed to create certificate profile: %s", err)
} }
@ -486,12 +486,12 @@ func keyCeremony(configBytes []byte) error {
if err := config.validate(); err != nil { if err := config.validate(); err != nil {
return fmt.Errorf("failed to validate config: %s", err) return fmt.Errorf("failed to validate config: %s", err)
} }
ctx, session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN) session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN)
if err != nil { if err != nil {
return fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", config.PKCS11.StoreSlot, err) return fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", config.PKCS11.StoreSlot, err)
} }
log.Printf("Opened PKCS#11 session for slot %d\n", config.PKCS11.StoreSlot) log.Printf("Opened PKCS#11 session for slot %d\n", config.PKCS11.StoreSlot)
if _, err = generateKey(ctx, session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key); err != nil { if _, err = generateKey(session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key); err != nil {
return err return err
} }

View File

@ -54,8 +54,8 @@ func rsaArgs(label string, modulusLen, exponent uint, keyID []byte) generateArgs
// handle, and constructs a rsa.PublicKey. It also checks that the key has the // handle, and constructs a rsa.PublicKey. It also checks that the key has the
// correct length modulus and that the public exponent is what was requested in // correct length modulus and that the public exponent is what was requested in
// the public key template. // the public key template.
func rsaPub(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, modulusLen, exponent uint) (*rsa.PublicKey, error) { func rsaPub(session *pkcs11helpers.Session, object pkcs11.ObjectHandle, modulusLen, exponent uint) (*rsa.PublicKey, error) {
pubKey, err := pkcs11helpers.GetRSAPublicKey(ctx, session, object) pubKey, err := session.GetRSAPublicKey(object)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -74,16 +74,16 @@ func rsaPub(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11
// private key on the device, specified by the provided object handle, by signing // private key on the device, specified by the provided object handle, by signing
// a nonce generated on the device and verifying the returned signature using the // a nonce generated on the device and verifying the returned signature using the
// public key. // public key.
func rsaVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, pub *rsa.PublicKey) error { func rsaVerify(session *pkcs11helpers.Session, object pkcs11.ObjectHandle, pub *rsa.PublicKey) error {
nonce := make([]byte, 4) nonce := make([]byte, 4)
_, err := newRandReader(ctx, session).Read(nonce) _, err := newRandReader(session).Read(nonce)
if err != nil { if err != nil {
return fmt.Errorf("Failed to retrieve nonce: %s", err) return fmt.Errorf("Failed to retrieve nonce: %s", err)
} }
log.Printf("\tConstructed nonce: %d (%X)\n", big.NewInt(0).SetBytes(nonce), nonce) log.Printf("\tConstructed nonce: %d (%X)\n", big.NewInt(0).SetBytes(nonce), nonce)
digest := sha256.Sum256(nonce) digest := sha256.Sum256(nonce)
log.Printf("\tMessage SHA-256 hash: %X\n", digest) log.Printf("\tMessage SHA-256 hash: %X\n", digest)
signature, err := pkcs11helpers.Sign(ctx, session, object, pkcs11helpers.RSAKey, digest[:], crypto.SHA256) signature, err := session.Sign(object, pkcs11helpers.RSAKey, digest[:], crypto.SHA256)
if err != nil { if err != nil {
return fmt.Errorf("Failed to sign data: %s", err) return fmt.Errorf("Failed to sign data: %s", err)
} }
@ -100,27 +100,27 @@ func rsaVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkc
// specified by modulusLen and with the exponent specified by pubExponent. // specified by modulusLen and with the exponent specified by pubExponent.
// It returns the public part of the generated key pair as a rsa.PublicKey // It returns the public part of the generated key pair as a rsa.PublicKey
// and the random key ID that the HSM uses to identify the key pair. // and the random key ID that the HSM uses to identify the key pair.
func rsaGenerate(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label string, modulusLen, pubExponent uint) (*rsa.PublicKey, []byte, error) { func rsaGenerate(session *pkcs11helpers.Session, label string, modulusLen, pubExponent uint) (*rsa.PublicKey, []byte, error) {
keyID := make([]byte, 4) keyID := make([]byte, 4)
_, err := newRandReader(ctx, session).Read(keyID) _, err := newRandReader(session).Read(keyID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
log.Printf("Generating RSA key with %d bit modulus and public exponent %d and ID %x\n", modulusLen, pubExponent, keyID) log.Printf("Generating RSA key with %d bit modulus and public exponent %d and ID %x\n", modulusLen, pubExponent, keyID)
args := rsaArgs(label, modulusLen, pubExponent, keyID) args := rsaArgs(label, modulusLen, pubExponent, keyID)
pub, priv, err := ctx.GenerateKeyPair(session, args.mechanism, args.publicAttrs, args.privateAttrs) pub, priv, err := session.GenerateKeyPair(args.mechanism, args.publicAttrs, args.privateAttrs)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
log.Println("Key generated") log.Println("Key generated")
log.Println("Extracting public key") log.Println("Extracting public key")
pk, err := rsaPub(ctx, session, pub, modulusLen, pubExponent) pk, err := rsaPub(session, pub, modulusLen, pubExponent)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
log.Println("Extracted public key") log.Println("Extracted public key")
log.Println("Verifying public key") log.Println("Verifying public key")
err = rsaVerify(ctx, session, priv, pk) err = rsaVerify(session, priv, pk)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -14,7 +14,7 @@ import (
) )
func TestRSAPub(t *testing.T) { func TestRSAPub(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
// test we fail to construct key with non-matching exp // test we fail to construct key with non-matching exp
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
@ -23,7 +23,7 @@ func TestRSAPub(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}),
}, nil }, nil
} }
_, err := rsaPub(ctx, 0, 0, 0, 255) _, err := rsaPub(s, 0, 0, 255)
test.AssertError(t, err, "rsaPub didn't fail with non-matching exp") test.AssertError(t, err, "rsaPub didn't fail with non-matching exp")
// test we fail to construct key with non-matching modulus // test we fail to construct key with non-matching modulus
@ -33,7 +33,7 @@ func TestRSAPub(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}),
}, nil }, nil
} }
_, err = rsaPub(ctx, 0, 0, 16, 65537) _, err = rsaPub(s, 0, 16, 65537)
test.AssertError(t, err, "rsaPub didn't fail with non-matching modulus size") test.AssertError(t, err, "rsaPub didn't fail with non-matching modulus size")
// test we don't fail with the correct attributes // test we don't fail with the correct attributes
@ -43,18 +43,18 @@ func TestRSAPub(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}),
}, nil }, nil
} }
_, err = rsaPub(ctx, 0, 0, 8, 65537) _, err = rsaPub(s, 0, 8, 65537)
test.AssertNotError(t, err, "rsaPub failed with valid attributes") test.AssertNotError(t, err, "rsaPub failed with valid attributes")
} }
func TestRSAVerify(t *testing.T) { func TestRSAVerify(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
// test GenerateRandom failing // test GenerateRandom failing
ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
err := rsaVerify(ctx, 0, 0, nil) err := rsaVerify(s, 0, nil)
test.AssertError(t, err, "rsaVerify didn't fail on GenerateRandom error") test.AssertError(t, err, "rsaVerify didn't fail on GenerateRandom error")
// test SignInit failing // test SignInit failing
@ -64,7 +64,7 @@ func TestRSAVerify(t *testing.T) {
ctx.SignInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, pkcs11.ObjectHandle) error { ctx.SignInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, pkcs11.ObjectHandle) error {
return errors.New("yup") return errors.New("yup")
} }
err = rsaVerify(ctx, 0, 0, nil) err = rsaVerify(s, 0, nil)
test.AssertError(t, err, "rsaVerify didn't fail on SignInit error") test.AssertError(t, err, "rsaVerify didn't fail on SignInit error")
// test Sign failing // test Sign failing
@ -77,7 +77,7 @@ func TestRSAVerify(t *testing.T) {
ctx.SignFunc = func(pkcs11.SessionHandle, []byte) ([]byte, error) { ctx.SignFunc = func(pkcs11.SessionHandle, []byte) ([]byte, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
err = rsaVerify(ctx, 0, 0, nil) err = rsaVerify(s, 0, nil)
test.AssertError(t, err, "rsaVerify didn't fail on Sign error") test.AssertError(t, err, "rsaVerify didn't fail on Sign error")
// test signature verification failing // test signature verification failing
@ -86,7 +86,7 @@ func TestRSAVerify(t *testing.T) {
} }
tk, err := rsa.GenerateKey(rand.Reader, 1024) tk, err := rsa.GenerateKey(rand.Reader, 1024)
test.AssertNotError(t, err, "rsa.GenerateKey failed") test.AssertNotError(t, err, "rsa.GenerateKey failed")
err = rsaVerify(ctx, 0, 0, &tk.PublicKey) err = rsaVerify(s, 0, &tk.PublicKey)
test.AssertError(t, err, "rsaVerify didn't fail on signature verification error") test.AssertError(t, err, "rsaVerify didn't fail on signature verification error")
// test we don't fail with valid signature // test we don't fail with valid signature
@ -94,12 +94,12 @@ func TestRSAVerify(t *testing.T) {
// Chop of the hash identifier and feed back into rsa.SignPKCS1v15 // Chop of the hash identifier and feed back into rsa.SignPKCS1v15
return rsa.SignPKCS1v15(rand.Reader, tk, crypto.SHA256, msg[19:]) return rsa.SignPKCS1v15(rand.Reader, tk, crypto.SHA256, msg[19:])
} }
err = rsaVerify(ctx, 0, 0, &tk.PublicKey) err = rsaVerify(s, 0, &tk.PublicKey)
test.AssertNotError(t, err, "rsaVerify failed with a valid signature") test.AssertNotError(t, err, "rsaVerify failed with a valid signature")
} }
func TestRSAGenerate(t *testing.T) { func TestRSAGenerate(t *testing.T) {
ctx := pkcs11helpers.MockCtx{} s, ctx := pkcs11helpers.NewSessionWithMock()
ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) {
return []byte{1, 2, 3}, nil return []byte{1, 2, 3}, nil
} }
@ -111,7 +111,7 @@ func TestRSAGenerate(t *testing.T) {
ctx.GenerateKeyPairFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) { ctx.GenerateKeyPairFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) {
return 0, 0, errors.New("bad") return 0, 0, errors.New("bad")
} }
_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) _, _, err = rsaGenerate(s, "", 1024, 65537)
test.AssertError(t, err, "rsaGenerate didn't fail on GenerateKeyPair error") test.AssertError(t, err, "rsaGenerate didn't fail on GenerateKeyPair error")
// Test rsaGenerate fails when rsaPub fails // Test rsaGenerate fails when rsaPub fails
@ -121,7 +121,7 @@ func TestRSAGenerate(t *testing.T) {
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return nil, errors.New("bad") return nil, errors.New("bad")
} }
_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) _, _, err = rsaGenerate(s, "", 1024, 65537)
test.AssertError(t, err, "rsaGenerate didn't fail on rsaPub error") test.AssertError(t, err, "rsaGenerate didn't fail on rsaPub error")
// Test rsaGenerate fails when rsaVerify fails // Test rsaGenerate fails when rsaVerify fails
@ -134,7 +134,7 @@ func TestRSAGenerate(t *testing.T) {
ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) _, _, err = rsaGenerate(s, "", 1024, 65537)
test.AssertError(t, err, "rsaGenerate didn't fail on rsaVerify error") test.AssertError(t, err, "rsaGenerate didn't fail on rsaVerify error")
// Test rsaGenerate doesn't fail when everything works // Test rsaGenerate doesn't fail when everything works
@ -148,6 +148,6 @@ func TestRSAGenerate(t *testing.T) {
// Chop of the hash identifier and feed back into rsa.SignPKCS1v15 // Chop of the hash identifier and feed back into rsa.SignPKCS1v15
return rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA256, msg[19:]) return rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA256, msg[19:])
} }
_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) _, _, err = rsaGenerate(s, "", 1024, 65537)
test.AssertNotError(t, err, "rsaGenerate didn't succeed when everything worked as expected") test.AssertNotError(t, err, "rsaGenerate didn't succeed when everything worked as expected")
} }

View File

@ -24,32 +24,47 @@ type PKCtx interface {
FindObjectsFinal(sh pkcs11.SessionHandle) error FindObjectsFinal(sh pkcs11.SessionHandle) error
} }
func Initialize(module string, slot uint, pin string) (PKCtx, pkcs11.SessionHandle, error) { // Session represents a session with a given PKCS#11 module. It is not safe for
// concurrent access.
type Session struct {
Module PKCtx
Session pkcs11.SessionHandle
}
func Initialize(module string, slot uint, pin string) (*Session, error) {
ctx := pkcs11.New(module) ctx := pkcs11.New(module)
if ctx == nil { if ctx == nil {
return nil, 0, errors.New("failed to load module") return nil, errors.New("failed to load module")
} }
err := ctx.Initialize() err := ctx.Initialize()
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("couldn't initialize context: %s", err) return nil, fmt.Errorf("couldn't initialize context: %s", err)
} }
session, err := ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) session, err := ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("couldn't open session: %s", err) return nil, fmt.Errorf("couldn't open session: %s", err)
} }
err = ctx.Login(session, pkcs11.CKU_USER, pin) err = ctx.Login(session, pkcs11.CKU_USER, pin)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("couldn't login: %s", err) return nil, fmt.Errorf("couldn't login: %s", err)
} }
return ctx, session, nil return &Session{ctx, session}, nil
} }
func GetRSAPublicKey(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle) (*rsa.PublicKey, error) { func (s *Session) GetAttributeValue(object pkcs11.ObjectHandle, attributes []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return s.Module.GetAttributeValue(s.Session, object, attributes)
}
func (s *Session) GenerateKeyPair(m []*pkcs11.Mechanism, pubAttrs []*pkcs11.Attribute, privAttrs []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) {
return s.Module.GenerateKeyPair(s.Session, m, pubAttrs, privAttrs)
}
func (s *Session) GetRSAPublicKey(object pkcs11.ObjectHandle) (*rsa.PublicKey, error) {
// Retrieve the public exponent and modulus for the public key // Retrieve the public exponent and modulus for the public key
attrs, err := ctx.GetAttributeValue(session, object, []*pkcs11.Attribute{ attrs, err := s.Module.GetAttributeValue(s.Session, object, []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, nil), pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, nil),
pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil), pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil),
}) })
@ -86,9 +101,9 @@ var oidDERToCurve = map[string]elliptic.Curve{
"06052B81040023": elliptic.P521(), "06052B81040023": elliptic.P521(),
} }
func GetECDSAPublicKey(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle) (*ecdsa.PublicKey, error) { func (s *Session) GetECDSAPublicKey(object pkcs11.ObjectHandle) (*ecdsa.PublicKey, error) {
// Retrieve the curve and public point for the generated public key // Retrieve the curve and public point for the generated public key
attrs, err := ctx.GetAttributeValue(session, object, []*pkcs11.Attribute{ attrs, err := s.Module.GetAttributeValue(s.Session, object, []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, nil), pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, nil),
pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil),
}) })
@ -152,7 +167,7 @@ var hashIdentifiers = map[crypto.Hash][]byte{
crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40}, crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40},
} }
func Sign(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, keyType KeyType, digest []byte, hash crypto.Hash) ([]byte, error) { func (s *Session) Sign(object pkcs11.ObjectHandle, keyType KeyType, digest []byte, hash crypto.Hash) ([]byte, error) {
if len(digest) != hash.Size() { if len(digest) != hash.Size() {
return nil, errors.New("digest length doesn't match hash length") return nil, errors.New("digest length doesn't match hash length")
} }
@ -170,11 +185,11 @@ func Sign(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, k
mech[0] = pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil) mech[0] = pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil)
} }
err := ctx.SignInit(session, mech, object) err := s.Module.SignInit(s.Session, mech, object)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize signing operation: %s", err) return nil, fmt.Errorf("failed to initialize signing operation: %s", err)
} }
signature, err := ctx.Sign(session, digest) signature, err := s.Module.Sign(s.Session, digest)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to sign data: %s", err) return nil, fmt.Errorf("failed to sign data: %s", err)
} }
@ -187,15 +202,15 @@ var ErrNoObject = errors.New("no objects found matching provided template")
// FindObject looks up a PKCS#11 object handle based on the provided template. // FindObject looks up a PKCS#11 object handle based on the provided template.
// In the case where zero or more than one objects are found to match the // In the case where zero or more than one objects are found to match the
// template an error is returned. // template an error is returned.
func FindObject(ctx PKCtx, session pkcs11.SessionHandle, tmpl []*pkcs11.Attribute) (pkcs11.ObjectHandle, error) { func (s *Session) FindObject(tmpl []*pkcs11.Attribute) (pkcs11.ObjectHandle, error) {
if err := ctx.FindObjectsInit(session, tmpl); err != nil { if err := s.Module.FindObjectsInit(s.Session, tmpl); err != nil {
return 0, err return 0, err
} }
handles, _, err := ctx.FindObjects(session, 2) handles, _, err := s.Module.FindObjects(s.Session, 2)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if err := ctx.FindObjectsFinal(session); err != nil { if err := s.Module.FindObjectsFinal(s.Session); err != nil {
return 0, err return 0, err
} }
if len(handles) == 0 { if len(handles) == 0 {
@ -207,6 +222,15 @@ func FindObject(ctx PKCtx, session pkcs11.SessionHandle, tmpl []*pkcs11.Attribut
return handles[0], nil return handles[0], nil
} }
func NewMock() *MockCtx {
return &MockCtx{}
}
func NewSessionWithMock() (*Session, *MockCtx) {
ctx := NewMock()
return &Session{ctx, 0}, ctx
}
type MockCtx struct { type MockCtx struct {
GenerateKeyPairFunc func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) GenerateKeyPairFunc func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error)
GetAttributeValueFunc func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) GetAttributeValueFunc func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error)

View File

@ -10,20 +10,21 @@ import (
) )
func TestGetECDSAPublicKey(t *testing.T) { func TestGetECDSAPublicKey(t *testing.T) {
ctx := MockCtx{} ctx := &MockCtx{}
s := &Session{ctx, 0}
// test attribute retrieval failing // test attribute retrieval failing
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
_, err := GetECDSAPublicKey(ctx, 0, 0) _, err := s.GetECDSAPublicKey(0)
test.AssertError(t, err, "ecPub didn't fail on GetAttributeValue error") test.AssertError(t, err, "ecPub didn't fail on GetAttributeValue error")
// test we fail to construct key with missing params and point // test we fail to construct key with missing params and point
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return []*pkcs11.Attribute{}, nil return []*pkcs11.Attribute{}, nil
} }
_, err = GetECDSAPublicKey(ctx, 0, 0) _, err = s.GetECDSAPublicKey(0)
test.AssertError(t, err, "ecPub didn't fail with empty attribute list") test.AssertError(t, err, "ecPub didn't fail with empty attribute list")
// test we fail to construct key with unknown curve // test we fail to construct key with unknown curve
@ -32,7 +33,7 @@ func TestGetECDSAPublicKey(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, []byte{1, 2, 3}), pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, []byte{1, 2, 3}),
}, nil }, nil
} }
_, err = GetECDSAPublicKey(ctx, 0, 0) _, err = s.GetECDSAPublicKey(0)
test.AssertError(t, err, "ecPub didn't fail with unknown curve") test.AssertError(t, err, "ecPub didn't fail with unknown curve")
// test we fail to construct key with invalid EC point (invalid encoding) // test we fail to construct key with invalid EC point (invalid encoding)
@ -42,7 +43,7 @@ func TestGetECDSAPublicKey(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{255}), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{255}),
}, nil }, nil
} }
_, err = GetECDSAPublicKey(ctx, 0, 0) _, err = s.GetECDSAPublicKey(0)
test.AssertError(t, err, "ecPub didn't fail with invalid EC point (invalid encoding)") test.AssertError(t, err, "ecPub didn't fail with invalid EC point (invalid encoding)")
// test we fail to construct key with invalid EC point (empty octet string) // test we fail to construct key with invalid EC point (empty octet string)
@ -52,7 +53,7 @@ func TestGetECDSAPublicKey(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 0}), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 0}),
}, nil }, nil
} }
_, err = GetECDSAPublicKey(ctx, 0, 0) _, err = s.GetECDSAPublicKey(0)
test.AssertError(t, err, "ecPub didn't fail with invalid EC point (empty octet string)") test.AssertError(t, err, "ecPub didn't fail with invalid EC point (empty octet string)")
// test we fail to construct key with invalid EC point (octet string, invalid contents) // test we fail to construct key with invalid EC point (octet string, invalid contents)
@ -62,7 +63,7 @@ func TestGetECDSAPublicKey(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 4, 4, 1, 2, 3}), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 4, 4, 1, 2, 3}),
}, nil }, nil
} }
_, err = GetECDSAPublicKey(ctx, 0, 0) _, err = s.GetECDSAPublicKey(0)
test.AssertError(t, err, "ecPub didn't fail with invalid EC point (octet string, invalid contents)") test.AssertError(t, err, "ecPub didn't fail with invalid EC point (octet string, invalid contents)")
// test we don't fail with the correct attributes (traditional encoding) // test we don't fail with the correct attributes (traditional encoding)
@ -72,7 +73,7 @@ func TestGetECDSAPublicKey(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}),
}, nil }, nil
} }
_, err = GetECDSAPublicKey(ctx, 0, 0) _, err = s.GetECDSAPublicKey(0)
test.AssertNotError(t, err, "ecPub failed with valid attributes (traditional encoding)") test.AssertNotError(t, err, "ecPub failed with valid attributes (traditional encoding)")
// test we don't fail with the correct attributes (non-traditional encoding) // test we don't fail with the correct attributes (non-traditional encoding)
@ -82,25 +83,26 @@ func TestGetECDSAPublicKey(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 57, 4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}), pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 57, 4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}),
}, nil }, nil
} }
_, err = GetECDSAPublicKey(ctx, 0, 0) _, err = s.GetECDSAPublicKey(0)
test.AssertNotError(t, err, "ecPub failed with valid attributes (non-traditional encoding)") test.AssertNotError(t, err, "ecPub failed with valid attributes (non-traditional encoding)")
} }
func TestRSAPublicKey(t *testing.T) { func TestRSAPublicKey(t *testing.T) {
ctx := MockCtx{} ctx := &MockCtx{}
s := &Session{ctx, 0}
// test attribute retrieval failing // test attribute retrieval failing
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return nil, errors.New("yup") return nil, errors.New("yup")
} }
_, err := GetRSAPublicKey(ctx, 0, 0) _, err := s.GetRSAPublicKey(0)
test.AssertError(t, err, "rsaPub didn't fail on GetAttributeValue error") test.AssertError(t, err, "rsaPub didn't fail on GetAttributeValue error")
// test we fail to construct key with missing modulus and exp // test we fail to construct key with missing modulus and exp
ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) {
return []*pkcs11.Attribute{}, nil return []*pkcs11.Attribute{}, nil
} }
_, err = GetRSAPublicKey(ctx, 0, 0) _, err = s.GetRSAPublicKey(0)
test.AssertError(t, err, "rsaPub didn't fail with empty attribute list") test.AssertError(t, err, "rsaPub didn't fail with empty attribute list")
// test we don't fail with the correct attributes // test we don't fail with the correct attributes
@ -110,27 +112,48 @@ func TestRSAPublicKey(t *testing.T) {
pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}),
}, nil }, nil
} }
_, err = GetRSAPublicKey(ctx, 0, 0) _, err = s.GetRSAPublicKey(0)
test.AssertNotError(t, err, "rsaPub failed with valid attributes") test.AssertNotError(t, err, "rsaPub failed with valid attributes")
} }
func findObjectsFinalOK(pkcs11.SessionHandle) error {
return nil
}
func findObjectsInitOK(pkcs11.SessionHandle, []*pkcs11.Attribute) error { func findObjectsInitOK(pkcs11.SessionHandle, []*pkcs11.Attribute) error {
return nil return nil
} }
func findObjectsOK(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) {
return []pkcs11.ObjectHandle{1}, false, nil
}
func findObjectsFinalOK(pkcs11.SessionHandle) error {
return nil
}
func newMock() *MockCtx {
return &MockCtx{
FindObjectsInitFunc: findObjectsInitOK,
FindObjectsFunc: findObjectsOK,
FindObjectsFinalFunc: findObjectsFinalOK,
}
}
func newSessionWithMock() (*Session, *MockCtx) {
ctx := newMock()
return &Session{ctx, 0}, ctx
}
func TestFindObjectFailsOnFailedInit(t *testing.T) { func TestFindObjectFailsOnFailedInit(t *testing.T) {
ctx := MockCtx{} ctx := MockCtx{}
ctx.FindObjectsFinalFunc = findObjectsFinalOK ctx.FindObjectsFinalFunc = findObjectsFinalOK
ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) {
return []pkcs11.ObjectHandle{1}, false, nil
}
// test FindObject fails when FindObjectsInit fails // test FindObject fails when FindObjectsInit fails
ctx.FindObjectsInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Attribute) error { ctx.FindObjectsInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Attribute) error {
return errors.New("broken") return errors.New("broken")
} }
_, err := FindObject(ctx, 0, nil) s := &Session{ctx, 0}
_, err := s.FindObject(nil)
test.AssertError(t, err, "FindObject didn't fail when FindObjectsInit failed") test.AssertError(t, err, "FindObject didn't fail when FindObjectsInit failed")
} }
@ -143,7 +166,8 @@ func TestFindObjectFailsOnFailedFindObjects(t *testing.T) {
ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) {
return nil, false, errors.New("broken") return nil, false, errors.New("broken")
} }
_, err := FindObject(ctx, 0, nil) s := &Session{ctx, 0}
_, err := s.FindObject(nil)
test.AssertError(t, err, "FindObject didn't fail when FindObjects failed") test.AssertError(t, err, "FindObject didn't fail when FindObjects failed")
} }
@ -156,7 +180,8 @@ func TestFindObjectFailsOnNoHandles(t *testing.T) {
ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) {
return []pkcs11.ObjectHandle{}, false, nil return []pkcs11.ObjectHandle{}, false, nil
} }
_, err := FindObject(ctx, 0, nil) s := &Session{ctx, 0}
_, err := s.FindObject(nil)
test.AssertEquals(t, err, ErrNoObject) test.AssertEquals(t, err, ErrNoObject)
} }
@ -169,7 +194,8 @@ func TestFindObjectFailsOnMultipleHandles(t *testing.T) {
ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) {
return []pkcs11.ObjectHandle{1, 2, 3}, false, nil return []pkcs11.ObjectHandle{1, 2, 3}, false, nil
} }
_, err := FindObject(ctx, 0, nil) s := &Session{ctx, 0}
_, err := s.FindObject(nil)
test.AssertError(t, err, "FindObject didn't fail when FindObjects returns multiple handles") test.AssertError(t, err, "FindObject didn't fail when FindObjects returns multiple handles")
test.Assert(t, strings.HasPrefix(err.Error(), "too many objects"), "FindObject failed with wrong error") test.Assert(t, strings.HasPrefix(err.Error(), "too many objects"), "FindObject failed with wrong error")
} }
@ -185,20 +211,22 @@ func TestFindObjectFailsOnFinalizeFailure(t *testing.T) {
ctx.FindObjectsFinalFunc = func(pkcs11.SessionHandle) error { ctx.FindObjectsFinalFunc = func(pkcs11.SessionHandle) error {
return errors.New("broken") return errors.New("broken")
} }
_, err := FindObject(ctx, 0, nil) s := &Session{ctx, 0}
_, err := s.FindObject(nil)
test.AssertError(t, err, "FindObject didn't fail when FindObjectsFinal fails") test.AssertError(t, err, "FindObject didn't fail when FindObjectsFinal fails")
} }
func TestFindObjectSucceeds(t *testing.T) { func TestFindObjectSucceeds(t *testing.T) {
ctx := MockCtx{} ctx := MockCtx{}
ctx.FindObjectsInitFunc = findObjectsInitOK ctx.FindObjectsInitFunc = findObjectsInitOK
ctx.FindObjectsFinalFunc = findObjectsFinalOK ctx.FindObjectsFinalFunc = findObjectsFinalOK
ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) {
return []pkcs11.ObjectHandle{1}, false, nil return []pkcs11.ObjectHandle{1}, false, nil
} }
s := &Session{ctx, 0}
// test FindObject works // test FindObject works
handle, err := FindObject(ctx, 0, nil) handle, err := s.FindObject(nil)
test.AssertNotError(t, err, "FindObject failed when everything worked as expected") test.AssertNotError(t, err, "FindObject failed when everything worked as expected")
test.AssertEquals(t, handle, pkcs11.ObjectHandle(1)) test.AssertEquals(t, handle, pkcs11.ObjectHandle(1))
} }