pkcs11helper: add a Session abstraction (#4989)
This commit is contained in:
		
							parent
							
								
									09c060f3de
								
							
						
					
					
						commit
						0834ca4a19
					
				|  | @ -317,9 +317,7 @@ func (fr *failReader) Read([]byte) (int, error) { | |||
| // PKCS#11 ECDSA signature format and the RFC 5480 one which is required
 | ||||
| // for X.509 certificates
 | ||||
| type x509Signer struct { | ||||
| 	ctx pkcs11helpers.PKCtx | ||||
| 
 | ||||
| 	session      pkcs11.SessionHandle | ||||
| 	session      *pkcs11helpers.Session | ||||
| 	objectHandle pkcs11.ObjectHandle | ||||
| 	keyType      pkcs11helpers.KeyType | ||||
| 
 | ||||
|  | @ -330,7 +328,7 @@ type x509Signer struct { | |||
| // is converted from the PKCS#11 format to the RFC 5480 format. For RSA keys a
 | ||||
| // conversion step is not needed.
 | ||||
| func (p *x509Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { | ||||
| 	signature, err := pkcs11helpers.Sign(p.ctx, p.session, p.objectHandle, p.keyType, digest, opts.HashFunc()) | ||||
| 	signature, err := p.session.Sign(p.objectHandle, p.keyType, digest, opts.HashFunc()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -359,10 +357,10 @@ func (p *x509Signer) Public() crypto.PublicKey { | |||
| // having the actual public key object in order to retrieve the private key
 | ||||
| // handle. This is because we already have the key pair object ID, and as such
 | ||||
| // do not need to query the HSM to retrieve it.
 | ||||
| func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label string, id []byte) (crypto.Signer, error) { | ||||
| func newSigner(session *pkcs11helpers.Session, label string, id []byte) (crypto.Signer, error) { | ||||
| 	// Retrieve the private key handle that will later be used for the certificate
 | ||||
| 	// signing operation
 | ||||
| 	privateHandle, err := pkcs11helpers.FindObject(ctx, session, []*pkcs11.Attribute{ | ||||
| 	privateHandle, err := session.FindObject([]*pkcs11.Attribute{ | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PRIVATE_KEY), | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_ID, id), | ||||
|  | @ -370,7 +368,7 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri | |||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to retrieve private key handle: %s", err) | ||||
| 	} | ||||
| 	attrs, err := ctx.GetAttributeValue(session, privateHandle, []*pkcs11.Attribute{ | ||||
| 	attrs, err := session.GetAttributeValue(privateHandle, []*pkcs11.Attribute{ | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, nil)}, | ||||
| 	) | ||||
| 	if err != nil { | ||||
|  | @ -382,7 +380,7 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri | |||
| 
 | ||||
| 	// Retrieve the public key handle with the same CKA_ID as the private key
 | ||||
| 	// and construct a {rsa,ecdsa}.PublicKey for use in x509.CreateCertificate
 | ||||
| 	pubHandle, err := pkcs11helpers.FindObject(ctx, session, []*pkcs11.Attribute{ | ||||
| 	pubHandle, err := session.FindObject([]*pkcs11.Attribute{ | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY), | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_LABEL, label), | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_ID, id), | ||||
|  | @ -397,14 +395,14 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri | |||
| 	// 0x00000000, CKK_RSA
 | ||||
| 	case bytes.Equal(attrs[0].Value, []byte{0, 0, 0, 0, 0, 0, 0, 0}): | ||||
| 		keyType = pkcs11helpers.RSAKey | ||||
| 		pub, err = pkcs11helpers.GetRSAPublicKey(ctx, session, pubHandle) | ||||
| 		pub, err = session.GetRSAPublicKey(pubHandle) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to retrieve public key: %s", err) | ||||
| 		} | ||||
| 	// 0x00000003, CKK_ECDSA
 | ||||
| 	case bytes.Equal(attrs[0].Value, []byte{3, 0, 0, 0, 0, 0, 0, 0}): | ||||
| 		keyType = pkcs11helpers.ECDSAKey | ||||
| 		pub, err = pkcs11helpers.GetECDSAPublicKey(ctx, session, pubHandle) | ||||
| 		pub, err = session.GetECDSAPublicKey(pubHandle) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to retrieve public key: %s", err) | ||||
| 		} | ||||
|  | @ -413,7 +411,6 @@ func newSigner(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label stri | |||
| 	} | ||||
| 
 | ||||
| 	return &x509Signer{ | ||||
| 		ctx:          ctx, | ||||
| 		session:      session, | ||||
| 		objectHandle: privateHandle, | ||||
| 		keyType:      keyType, | ||||
|  |  | |||
|  | @ -20,7 +20,7 @@ import ( | |||
| ) | ||||
| 
 | ||||
| func TestX509Signer(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 
 | ||||
| 	// test that x509Signer.Sign properly converts the PKCS#11 format signature to
 | ||||
| 	// the RFC 5480 format signature
 | ||||
|  | @ -51,7 +51,7 @@ func TestX509Signer(t *testing.T) { | |||
| 		return append(rBytes, sBytes...), nil | ||||
| 	} | ||||
| 	digest := sha256.Sum256([]byte("hello")) | ||||
| 	signer := &x509Signer{ctx: ctx, keyType: pkcs11helpers.ECDSAKey, pub: tk.Public()} | ||||
| 	signer := &x509Signer{session: s, keyType: pkcs11helpers.ECDSAKey, pub: tk.Public()} | ||||
| 	signature, err := signer.Sign(nil, digest[:], crypto.SHA256) | ||||
| 	test.AssertNotError(t, err, "x509Signer.Sign failed") | ||||
| 
 | ||||
|  | @ -78,9 +78,9 @@ func TestParseOID(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestMakeTemplate(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 	profile := &certProfile{} | ||||
| 	randReader := newRandReader(&ctx, 0) | ||||
| 	randReader := newRandReader(s) | ||||
| 
 | ||||
| 	pubKey, err := hex.DecodeString("3059301306072a8648ce3d020106082a8648ce3d03010703420004b06745ef0375c9c54057098f077964e18d3bed0aacd54545b16eab8c539b5768cc1cea93ba56af1e22a7a01c33048c8885ed17c9c55ede70649b707072689f5e") | ||||
| 	test.AssertNotError(t, err, "failed to decode test public key") | ||||
|  | @ -157,14 +157,13 @@ func TestMakeTemplate(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestMakeTemplateOCSP(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{ | ||||
| 		GenerateRandomFunc: func(_ pkcs11.SessionHandle, length int) ([]byte, error) { | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 	ctx.GenerateRandomFunc = func(_ pkcs11.SessionHandle, length int) ([]byte, error) { | ||||
| 		r := make([]byte, length) | ||||
| 		_, err := rand.Read(r) | ||||
| 		return r, err | ||||
| 		}, | ||||
| 	} | ||||
| 	randReader := newRandReader(&ctx, 0) | ||||
| 	randReader := newRandReader(s) | ||||
| 	profile := &certProfile{ | ||||
| 		SignatureAlgorithm: "SHA256WithRSA", | ||||
| 		CommonName:         "common name", | ||||
|  | @ -206,14 +205,13 @@ func TestMakeTemplateOCSP(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestMakeTemplateCRL(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{ | ||||
| 		GenerateRandomFunc: func(_ pkcs11.SessionHandle, length int) ([]byte, error) { | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 	ctx.GenerateRandomFunc = func(_ pkcs11.SessionHandle, length int) ([]byte, error) { | ||||
| 		r := make([]byte, length) | ||||
| 		_, err := rand.Read(r) | ||||
| 		return r, err | ||||
| 		}, | ||||
| 	} | ||||
| 	randReader := newRandReader(&ctx, 0) | ||||
| 	randReader := newRandReader(s) | ||||
| 	profile := &certProfile{ | ||||
| 		SignatureAlgorithm: "SHA256WithRSA", | ||||
| 		CommonName:         "common name", | ||||
|  | @ -462,13 +460,13 @@ func TestVerifyProfile(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestGetKey(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 
 | ||||
| 	// test newSigner fails when pkcs11helpers.FindObject for private key handle fails
 | ||||
| 	ctx.FindObjectsInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Attribute) error { | ||||
| 		return errors.New("broken") | ||||
| 	} | ||||
| 	_, err := newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err := newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertError(t, err, "newSigner didn't fail when pkcs11helpers.FindObject for private key handle failed") | ||||
| 
 | ||||
| 	// test newSigner fails when GetAttributeValue fails
 | ||||
|  | @ -484,14 +482,14 @@ func TestGetKey(t *testing.T) { | |||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return nil, errors.New("broken") | ||||
| 	} | ||||
| 	_, err = newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err = newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key type failed") | ||||
| 
 | ||||
| 	// test newSigner fails when GetAttributeValue returns no attributes
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	_, err = newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err = newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key type returned no attributes") | ||||
| 
 | ||||
| 	// test newSigner fails when pkcs11helpers.FindObject for public key handle fails
 | ||||
|  | @ -504,7 +502,7 @@ func TestGetKey(t *testing.T) { | |||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
| 	_, err = newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err = newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertError(t, err, "newSigner didn't fail when pkcs11helpers.FindObject for public key handle failed") | ||||
| 
 | ||||
| 	// test newSigner fails when pkcs11helpers.FindObject for private key returns unknown CKA_KEY_TYPE
 | ||||
|  | @ -514,21 +512,21 @@ func TestGetKey(t *testing.T) { | |||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{2, 0, 0, 0, 0, 0, 0, 0})}, nil | ||||
| 	} | ||||
| 	_, err = newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err = newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertError(t, err, "newSigner didn't fail when GetAttributeValue for private key returned unknown key type") | ||||
| 
 | ||||
| 	// test newSigner fails when GetRSAPublicKey fails
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{0, 0, 0, 0, 0, 0, 0, 0})}, nil | ||||
| 	} | ||||
| 	_, err = newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err = newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertError(t, err, "newSigner didn't fail when GetRSAPublicKey fails") | ||||
| 
 | ||||
| 	// test newSigner fails when GetECDSAPublicKey fails
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, []byte{3, 0, 0, 0, 0, 0, 0, 0})}, nil | ||||
| 	} | ||||
| 	_, err = newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err = newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertError(t, err, "newSigner didn't fail when GetECDSAPublicKey fails") | ||||
| 
 | ||||
| 	// test newSigner works when everything... works
 | ||||
|  | @ -548,6 +546,6 @@ func TestGetKey(t *testing.T) { | |||
| 		} | ||||
| 		return returns, nil | ||||
| 	} | ||||
| 	_, err = newSigner(ctx, 0, "label", []byte{255, 255}) | ||||
| 	_, err = newSigner(s, "label", []byte{255, 255}) | ||||
| 	test.AssertNotError(t, err, "newSigner failed when everything worked properly") | ||||
| } | ||||
|  |  | |||
|  | @ -76,12 +76,11 @@ func ecArgs(label string, curve elliptic.Curve, keyID []byte) generateArgs { | |||
| // handle, and constructs an ecdsa.PublicKey. It also checks that the key is of
 | ||||
| // the correct curve type.
 | ||||
| func ecPub( | ||||
| 	ctx pkcs11helpers.PKCtx, | ||||
| 	session pkcs11.SessionHandle, | ||||
| 	session *pkcs11helpers.Session, | ||||
| 	object pkcs11.ObjectHandle, | ||||
| 	expectedCurve elliptic.Curve, | ||||
| ) (*ecdsa.PublicKey, error) { | ||||
| 	pubKey, err := pkcs11helpers.GetECDSAPublicKey(ctx, session, object) | ||||
| 	pubKey, err := session.GetECDSAPublicKey(object) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -97,9 +96,9 @@ func ecPub( | |||
| // private key on the device, specified by the provided object handle, by signing
 | ||||
| // a nonce generated on the device and verifying the returned signature using the
 | ||||
| // public key.
 | ||||
| func ecVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, pub *ecdsa.PublicKey) error { | ||||
| func ecVerify(session *pkcs11helpers.Session, object pkcs11.ObjectHandle, pub *ecdsa.PublicKey) error { | ||||
| 	nonce := make([]byte, 4) | ||||
| 	_, err := newRandReader(ctx, session).Read(nonce) | ||||
| 	_, err := newRandReader(session).Read(nonce) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to construct nonce: %s", err) | ||||
| 	} | ||||
|  | @ -108,7 +107,7 @@ func ecVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs | |||
| 	hashFunc.Write(nonce) | ||||
| 	digest := hashFunc.Sum(nil) | ||||
| 	log.Printf("\tMessage %s hash: %X\n", hashToString[curveToHash[pub.Curve]], digest) | ||||
| 	signature, err := pkcs11helpers.Sign(ctx, session, object, pkcs11helpers.ECDSAKey, digest, curveToHash[pub.Curve]) | ||||
| 	signature, err := session.Sign(object, pkcs11helpers.ECDSAKey, digest, curveToHash[pub.Curve]) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | @ -126,31 +125,31 @@ func ecVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs | |||
| // specified by curveStr and with the provided label. It returns the public
 | ||||
| // part of the generated key pair as a ecdsa.PublicKey and the random key ID
 | ||||
| // that the HSM uses to identify the key pair.
 | ||||
| func ecGenerate(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label, curveStr string) (*ecdsa.PublicKey, []byte, error) { | ||||
| func ecGenerate(session *pkcs11helpers.Session, label, curveStr string) (*ecdsa.PublicKey, []byte, error) { | ||||
| 	curve, present := stringToCurve[curveStr] | ||||
| 	if !present { | ||||
| 		return nil, nil, fmt.Errorf("curve %q not supported", curveStr) | ||||
| 	} | ||||
| 	keyID := make([]byte, 4) | ||||
| 	_, err := newRandReader(ctx, session).Read(keyID) | ||||
| 	_, err := newRandReader(session).Read(keyID) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	log.Printf("Generating ECDSA key with curve %s and ID %x\n", curveStr, keyID) | ||||
| 	args := ecArgs(label, curve, keyID) | ||||
| 	pub, priv, err := ctx.GenerateKeyPair(session, args.mechanism, args.publicAttrs, args.privateAttrs) | ||||
| 	pub, priv, err := session.GenerateKeyPair(args.mechanism, args.publicAttrs, args.privateAttrs) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	log.Println("Key generated") | ||||
| 	log.Println("Extracting public key") | ||||
| 	pk, err := ecPub(ctx, session, pub, curve) | ||||
| 	pk, err := ecPub(session, pub, curve) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	log.Println("Extracted public key") | ||||
| 	log.Println("Verifying public key") | ||||
| 	err = ecVerify(ctx, session, priv, pk) | ||||
| 	err = ecVerify(session, priv, pk) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  |  | |||
|  | @ -13,13 +13,13 @@ import ( | |||
| ) | ||||
| 
 | ||||
| func TestECPub(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 
 | ||||
| 	// test we fail when pkcs11helpers.GetECDSAPublicKey fails
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return nil, errors.New("bad!") | ||||
| 	} | ||||
| 	_, err := ecPub(ctx, 0, 0, elliptic.P256()) | ||||
| 	_, err := ecPub(s, 0, elliptic.P256()) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail with non-matching curve") | ||||
| 	test.AssertEquals(t, err.Error(), "Failed to retrieve key attributes: bad!") | ||||
| 
 | ||||
|  | @ -30,18 +30,17 @@ func TestECPub(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = ecPub(ctx, 0, 0, elliptic.P256()) | ||||
| 	_, err = ecPub(s, 0, elliptic.P256()) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail with non-matching curve") | ||||
| } | ||||
| 
 | ||||
| func TestECVerify(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 
 | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 	// test GenerateRandom failing
 | ||||
| 	ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	err := ecVerify(ctx, 0, 0, nil) | ||||
| 	err := ecVerify(s, 0, nil) | ||||
| 	test.AssertError(t, err, "ecVerify didn't fail on GenerateRandom error") | ||||
| 
 | ||||
| 	// test SignInit failing
 | ||||
|  | @ -51,7 +50,7 @@ func TestECVerify(t *testing.T) { | |||
| 	ctx.SignInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, pkcs11.ObjectHandle) error { | ||||
| 		return errors.New("yup") | ||||
| 	} | ||||
| 	err = ecVerify(ctx, 0, 0, &ecdsa.PublicKey{Curve: elliptic.P256()}) | ||||
| 	err = ecVerify(s, 0, &ecdsa.PublicKey{Curve: elliptic.P256()}) | ||||
| 	test.AssertError(t, err, "ecVerify didn't fail on SignInit error") | ||||
| 
 | ||||
| 	// test Sign failing
 | ||||
|  | @ -61,7 +60,7 @@ func TestECVerify(t *testing.T) { | |||
| 	ctx.SignFunc = func(pkcs11.SessionHandle, []byte) ([]byte, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	err = ecVerify(ctx, 0, 0, &ecdsa.PublicKey{Curve: elliptic.P256()}) | ||||
| 	err = ecVerify(s, 0, &ecdsa.PublicKey{Curve: elliptic.P256()}) | ||||
| 	test.AssertError(t, err, "ecVerify didn't fail on Sign error") | ||||
| 
 | ||||
| 	// test signature verification failing
 | ||||
|  | @ -70,19 +69,20 @@ func TestECVerify(t *testing.T) { | |||
| 	} | ||||
| 	tk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | ||||
| 	test.AssertNotError(t, err, "ecdsa.GenerateKey failed") | ||||
| 	err = ecVerify(ctx, 0, 0, &tk.PublicKey) | ||||
| 	err = ecVerify(s, 0, &tk.PublicKey) | ||||
| 	test.AssertError(t, err, "ecVerify didn't fail on signature verification error") | ||||
| 
 | ||||
| 	// test we don't fail with valid signature
 | ||||
| 	ctx.SignFunc = func(_ pkcs11.SessionHandle, msg []byte) ([]byte, error) { | ||||
| 		return ecPKCS11Sign(tk, msg) | ||||
| 	} | ||||
| 	err = ecVerify(ctx, 0, 0, &tk.PublicKey) | ||||
| 	err = ecVerify(s, 0, &tk.PublicKey) | ||||
| 	test.AssertNotError(t, err, "ecVerify failed with a valid signature") | ||||
| } | ||||
| 
 | ||||
| func TestECGenerate(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s := &pkcs11helpers.Session{Module: &ctx, Session: 0} | ||||
| 	ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { | ||||
| 		return []byte{1, 2, 3}, nil | ||||
| 	} | ||||
|  | @ -90,14 +90,14 @@ func TestECGenerate(t *testing.T) { | |||
| 	test.AssertNotError(t, err, "Failed to generate a ECDSA test key") | ||||
| 
 | ||||
| 	// Test ecGenerate fails with unknown curve
 | ||||
| 	_, _, err = ecGenerate(ctx, 0, "", "bad-curve") | ||||
| 	_, _, err = ecGenerate(s, "", "bad-curve") | ||||
| 	test.AssertError(t, err, "ecGenerate accepted unknown curve") | ||||
| 
 | ||||
| 	// Test ecGenerate fails when GenerateKeyPair fails
 | ||||
| 	ctx.GenerateKeyPairFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) { | ||||
| 		return 0, 0, errors.New("bad") | ||||
| 	} | ||||
| 	_, _, err = ecGenerate(ctx, 0, "", "P-256") | ||||
| 	_, _, err = ecGenerate(s, "", "P-256") | ||||
| 	test.AssertError(t, err, "ecGenerate didn't fail on GenerateKeyPair error") | ||||
| 
 | ||||
| 	// Test ecGenerate fails when ecPub fails
 | ||||
|  | @ -107,7 +107,7 @@ func TestECGenerate(t *testing.T) { | |||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return nil, errors.New("bad") | ||||
| 	} | ||||
| 	_, _, err = ecGenerate(ctx, 0, "", "P-256") | ||||
| 	_, _, err = ecGenerate(s, "", "P-256") | ||||
| 	test.AssertError(t, err, "ecGenerate didn't fail on ecPub error") | ||||
| 
 | ||||
| 	// Test ecGenerate fails when ecVerify fails
 | ||||
|  | @ -120,7 +120,7 @@ func TestECGenerate(t *testing.T) { | |||
| 	ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	_, _, err = ecGenerate(ctx, 0, "", "P-256") | ||||
| 	_, _, err = ecGenerate(s, "", "P-256") | ||||
| 	test.AssertError(t, err, "ecGenerate didn't fail on ecVerify error") | ||||
| 
 | ||||
| 	// Test ecGenerate doesn't fail when everything works
 | ||||
|  | @ -133,7 +133,7 @@ func TestECGenerate(t *testing.T) { | |||
| 	ctx.SignFunc = func(_ pkcs11.SessionHandle, msg []byte) ([]byte, error) { | ||||
| 		return ecPKCS11Sign(priv, msg) | ||||
| 	} | ||||
| 	_, _, err = ecGenerate(ctx, 0, "", "P-256") | ||||
| 	_, _, err = ecGenerate(s, "", "P-256") | ||||
| 	test.AssertNotError(t, err, "ecGenerate didn't succeed when everything worked as expected") | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -13,19 +13,15 @@ import ( | |||
| ) | ||||
| 
 | ||||
| type hsmRandReader struct { | ||||
| 	ctx     pkcs11helpers.PKCtx | ||||
| 	session pkcs11.SessionHandle | ||||
| 	*pkcs11helpers.Session | ||||
| } | ||||
| 
 | ||||
| func newRandReader(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle) *hsmRandReader { | ||||
| 	return &hsmRandReader{ | ||||
| 		ctx:     ctx, | ||||
| 		session: session, | ||||
| 	} | ||||
| func newRandReader(session *pkcs11helpers.Session) *hsmRandReader { | ||||
| 	return &hsmRandReader{session} | ||||
| } | ||||
| 
 | ||||
| func (hrr hsmRandReader) Read(p []byte) (n int, err error) { | ||||
| 	r, err := hrr.ctx.GenerateRandom(hrr.session, len(p)) | ||||
| 	r, err := hrr.Module.GenerateRandom(hrr.Session.Session, len(p)) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | @ -53,8 +49,8 @@ type keyInfo struct { | |||
| 	id  []byte | ||||
| } | ||||
| 
 | ||||
| func generateKey(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label string, outputPath string, config keyGenConfig) (*keyInfo, error) { | ||||
| 	_, err := pkcs11helpers.FindObject(ctx, session, []*pkcs11.Attribute{}) | ||||
| func generateKey(session *pkcs11helpers.Session, label string, outputPath string, config keyGenConfig) (*keyInfo, error) { | ||||
| 	_, err := session.FindObject([]*pkcs11.Attribute{}) | ||||
| 	if err != pkcs11helpers.ErrNoObject { | ||||
| 		return nil, fmt.Errorf("expected no objects in slot for key storage. got error: %s", err) | ||||
| 	} | ||||
|  | @ -63,12 +59,12 @@ func generateKey(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label st | |||
| 	var keyID []byte | ||||
| 	switch config.Type { | ||||
| 	case "rsa": | ||||
| 		pubKey, keyID, err = rsaGenerate(ctx, session, label, config.RSAModLength, rsaExp) | ||||
| 		pubKey, keyID, err = rsaGenerate(session, label, config.RSAModLength, rsaExp) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to generate RSA key pair: %s", err) | ||||
| 		} | ||||
| 	case "ecdsa": | ||||
| 		pubKey, keyID, err = ecGenerate(ctx, session, label, config.ECDSACurve) | ||||
| 		pubKey, keyID, err = ecGenerate(session, label, config.ECDSACurve) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("failed to generate ECDSA key pair: %s", err) | ||||
| 		} | ||||
|  |  | |||
|  | @ -61,8 +61,9 @@ func TestGenerateKeyRSA(t *testing.T) { | |||
| 		// Chop of the hash identifier and feed back into rsa.SignPKCS1v15
 | ||||
| 		return rsa.SignPKCS1v15(rand.Reader, rsaPriv, crypto.SHA256, msg[19:]) | ||||
| 	} | ||||
| 	s := &pkcs11helpers.Session{Module: &ctx, Session: 0} | ||||
| 	keyPath := path.Join(tmp, "test-rsa-key.pem") | ||||
| 	keyInfo, err := generateKey(ctx, 0, "", keyPath, keyGenConfig{ | ||||
| 	keyInfo, err := generateKey(s, "", keyPath, keyGenConfig{ | ||||
| 		Type:         "rsa", | ||||
| 		RSAModLength: 1024, | ||||
| 	}) | ||||
|  | @ -93,7 +94,8 @@ func TestGenerateKeyEC(t *testing.T) { | |||
| 		return ecPKCS11Sign(ecPriv, msg) | ||||
| 	} | ||||
| 	keyPath := path.Join(tmp, "test-ecdsa-key.pem") | ||||
| 	keyInfo, err := generateKey(ctx, 0, "", keyPath, keyGenConfig{ | ||||
| 	s := &pkcs11helpers.Session{Module: &ctx, Session: 0} | ||||
| 	keyInfo, err := generateKey(s, "", keyPath, keyGenConfig{ | ||||
| 		Type:       "ecdsa", | ||||
| 		ECDSACurve: "P-256", | ||||
| 	}) | ||||
|  | @ -116,7 +118,8 @@ func TestGenerateKeySlotHasSomething(t *testing.T) { | |||
| 		return []pkcs11.ObjectHandle{1}, false, nil | ||||
| 	} | ||||
| 	keyPath := path.Join(tmp, "should-not-exist.pem") | ||||
| 	_, err = generateKey(ctx, 0, "", keyPath, keyGenConfig{ | ||||
| 	s := &pkcs11helpers.Session{Module: &ctx, Session: 0} | ||||
| 	_, err = generateKey(s, "", keyPath, keyGenConfig{ | ||||
| 		Type:       "ecdsa", | ||||
| 		ECDSACurve: "P-256", | ||||
| 	}) | ||||
|  |  | |||
|  | @ -342,7 +342,7 @@ func equalPubKeys(a, b interface{}) bool { | |||
| } | ||||
| 
 | ||||
| func openSigner(cfg PKCS11SigningConfig, issuer *x509.Certificate) (crypto.Signer, *hsmRandReader, error) { | ||||
| 	ctx, session, err := pkcs11helpers.Initialize(cfg.Module, cfg.SigningSlot, cfg.PIN) | ||||
| 	session, err := pkcs11helpers.Initialize(cfg.Module, cfg.SigningSlot, cfg.PIN) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", | ||||
| 			cfg.SigningSlot, err) | ||||
|  | @ -352,7 +352,7 @@ func openSigner(cfg PKCS11SigningConfig, issuer *x509.Certificate) (crypto.Signe | |||
| 	if err != nil { | ||||
| 		return nil, nil, fmt.Errorf("failed to decode key-id: %s", err) | ||||
| 	} | ||||
| 	signer, err := newSigner(ctx, session, cfg.SigningLabel, keyID) | ||||
| 	signer, err := newSigner(session, cfg.SigningLabel, keyID) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, fmt.Errorf("failed to retrieve private key handle: %s", err) | ||||
| 	} | ||||
|  | @ -360,7 +360,7 @@ func openSigner(cfg PKCS11SigningConfig, issuer *x509.Certificate) (crypto.Signe | |||
| 		return nil, nil, fmt.Errorf("signer pubkey did not match issuer pubkey") | ||||
| 	} | ||||
| 	log.Println("Retrieved private key handle") | ||||
| 	return signer, newRandReader(ctx, session), nil | ||||
| 	return signer, newRandReader(session), nil | ||||
| } | ||||
| 
 | ||||
| func signAndWriteCert(tbs, issuer *x509.Certificate, subjectPubKey crypto.PublicKey, signer crypto.Signer, certPath string) error { | ||||
|  | @ -405,20 +405,20 @@ func rootCeremony(configBytes []byte) error { | |||
| 	if err := config.validate(); err != nil { | ||||
| 		return fmt.Errorf("failed to validate config: %s", err) | ||||
| 	} | ||||
| 	ctx, session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN) | ||||
| 	session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", config.PKCS11.StoreSlot, err) | ||||
| 	} | ||||
| 	log.Printf("Opened PKCS#11 session for slot %d\n", config.PKCS11.StoreSlot) | ||||
| 	keyInfo, err := generateKey(ctx, session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key) | ||||
| 	keyInfo, err := generateKey(session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	signer, err := newSigner(ctx, session, config.PKCS11.StoreLabel, keyInfo.id) | ||||
| 	signer, err := newSigner(session, config.PKCS11.StoreLabel, keyInfo.id) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to retrieve signer: %s", err) | ||||
| 	} | ||||
| 	template, err := makeTemplate(newRandReader(ctx, session), &config.CertProfile, keyInfo.der, rootCert) | ||||
| 	template, err := makeTemplate(newRandReader(session), &config.CertProfile, keyInfo.der, rootCert) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to create certificate profile: %s", err) | ||||
| 	} | ||||
|  | @ -486,12 +486,12 @@ func keyCeremony(configBytes []byte) error { | |||
| 	if err := config.validate(); err != nil { | ||||
| 		return fmt.Errorf("failed to validate config: %s", err) | ||||
| 	} | ||||
| 	ctx, session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN) | ||||
| 	session, err := pkcs11helpers.Initialize(config.PKCS11.Module, config.PKCS11.StoreSlot, config.PKCS11.PIN) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to setup session and PKCS#11 context for slot %d: %s", config.PKCS11.StoreSlot, err) | ||||
| 	} | ||||
| 	log.Printf("Opened PKCS#11 session for slot %d\n", config.PKCS11.StoreSlot) | ||||
| 	if _, err = generateKey(ctx, session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key); err != nil { | ||||
| 	if _, err = generateKey(session, config.PKCS11.StoreLabel, config.Outputs.PublicKeyPath, config.Key); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
|  |  | |||
|  | @ -54,8 +54,8 @@ func rsaArgs(label string, modulusLen, exponent uint, keyID []byte) generateArgs | |||
| // handle, and constructs a rsa.PublicKey. It also checks that the key has the
 | ||||
| // correct length modulus and that the public exponent is what was requested in
 | ||||
| // the public key template.
 | ||||
| func rsaPub(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, modulusLen, exponent uint) (*rsa.PublicKey, error) { | ||||
| 	pubKey, err := pkcs11helpers.GetRSAPublicKey(ctx, session, object) | ||||
| func rsaPub(session *pkcs11helpers.Session, object pkcs11.ObjectHandle, modulusLen, exponent uint) (*rsa.PublicKey, error) { | ||||
| 	pubKey, err := session.GetRSAPublicKey(object) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -74,16 +74,16 @@ func rsaPub(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11 | |||
| // private key on the device, specified by the provided object handle, by signing
 | ||||
| // a nonce generated on the device and verifying the returned signature using the
 | ||||
| // public key.
 | ||||
| func rsaVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, pub *rsa.PublicKey) error { | ||||
| func rsaVerify(session *pkcs11helpers.Session, object pkcs11.ObjectHandle, pub *rsa.PublicKey) error { | ||||
| 	nonce := make([]byte, 4) | ||||
| 	_, err := newRandReader(ctx, session).Read(nonce) | ||||
| 	_, err := newRandReader(session).Read(nonce) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("Failed to retrieve nonce: %s", err) | ||||
| 	} | ||||
| 	log.Printf("\tConstructed nonce: %d (%X)\n", big.NewInt(0).SetBytes(nonce), nonce) | ||||
| 	digest := sha256.Sum256(nonce) | ||||
| 	log.Printf("\tMessage SHA-256 hash: %X\n", digest) | ||||
| 	signature, err := pkcs11helpers.Sign(ctx, session, object, pkcs11helpers.RSAKey, digest[:], crypto.SHA256) | ||||
| 	signature, err := session.Sign(object, pkcs11helpers.RSAKey, digest[:], crypto.SHA256) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("Failed to sign data: %s", err) | ||||
| 	} | ||||
|  | @ -100,27 +100,27 @@ func rsaVerify(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, object pkc | |||
| // specified by modulusLen and with the exponent specified by pubExponent.
 | ||||
| // It returns the public part of the generated key pair as a rsa.PublicKey
 | ||||
| // and the random key ID that the HSM uses to identify the key pair.
 | ||||
| func rsaGenerate(ctx pkcs11helpers.PKCtx, session pkcs11.SessionHandle, label string, modulusLen, pubExponent uint) (*rsa.PublicKey, []byte, error) { | ||||
| func rsaGenerate(session *pkcs11helpers.Session, label string, modulusLen, pubExponent uint) (*rsa.PublicKey, []byte, error) { | ||||
| 	keyID := make([]byte, 4) | ||||
| 	_, err := newRandReader(ctx, session).Read(keyID) | ||||
| 	_, err := newRandReader(session).Read(keyID) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	log.Printf("Generating RSA key with %d bit modulus and public exponent %d and ID %x\n", modulusLen, pubExponent, keyID) | ||||
| 	args := rsaArgs(label, modulusLen, pubExponent, keyID) | ||||
| 	pub, priv, err := ctx.GenerateKeyPair(session, args.mechanism, args.publicAttrs, args.privateAttrs) | ||||
| 	pub, priv, err := session.GenerateKeyPair(args.mechanism, args.publicAttrs, args.privateAttrs) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	log.Println("Key generated") | ||||
| 	log.Println("Extracting public key") | ||||
| 	pk, err := rsaPub(ctx, session, pub, modulusLen, pubExponent) | ||||
| 	pk, err := rsaPub(session, pub, modulusLen, pubExponent) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	log.Println("Extracted public key") | ||||
| 	log.Println("Verifying public key") | ||||
| 	err = rsaVerify(ctx, session, priv, pk) | ||||
| 	err = rsaVerify(session, priv, pk) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ import ( | |||
| ) | ||||
| 
 | ||||
| func TestRSAPub(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 
 | ||||
| 	// test we fail to construct key with non-matching exp
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
|  | @ -23,7 +23,7 @@ func TestRSAPub(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err := rsaPub(ctx, 0, 0, 0, 255) | ||||
| 	_, err := rsaPub(s, 0, 0, 255) | ||||
| 	test.AssertError(t, err, "rsaPub didn't fail with non-matching exp") | ||||
| 
 | ||||
| 	// test we fail to construct key with non-matching modulus
 | ||||
|  | @ -33,7 +33,7 @@ func TestRSAPub(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = rsaPub(ctx, 0, 0, 16, 65537) | ||||
| 	_, err = rsaPub(s, 0, 16, 65537) | ||||
| 	test.AssertError(t, err, "rsaPub didn't fail with non-matching modulus size") | ||||
| 
 | ||||
| 	// test we don't fail with the correct attributes
 | ||||
|  | @ -43,18 +43,18 @@ func TestRSAPub(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = rsaPub(ctx, 0, 0, 8, 65537) | ||||
| 	_, err = rsaPub(s, 0, 8, 65537) | ||||
| 	test.AssertNotError(t, err, "rsaPub failed with valid attributes") | ||||
| } | ||||
| 
 | ||||
| func TestRSAVerify(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 
 | ||||
| 	// test GenerateRandom failing
 | ||||
| 	ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	err := rsaVerify(ctx, 0, 0, nil) | ||||
| 	err := rsaVerify(s, 0, nil) | ||||
| 	test.AssertError(t, err, "rsaVerify didn't fail on GenerateRandom error") | ||||
| 
 | ||||
| 	// test SignInit failing
 | ||||
|  | @ -64,7 +64,7 @@ func TestRSAVerify(t *testing.T) { | |||
| 	ctx.SignInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, pkcs11.ObjectHandle) error { | ||||
| 		return errors.New("yup") | ||||
| 	} | ||||
| 	err = rsaVerify(ctx, 0, 0, nil) | ||||
| 	err = rsaVerify(s, 0, nil) | ||||
| 	test.AssertError(t, err, "rsaVerify didn't fail on SignInit error") | ||||
| 
 | ||||
| 	// test Sign failing
 | ||||
|  | @ -77,7 +77,7 @@ func TestRSAVerify(t *testing.T) { | |||
| 	ctx.SignFunc = func(pkcs11.SessionHandle, []byte) ([]byte, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	err = rsaVerify(ctx, 0, 0, nil) | ||||
| 	err = rsaVerify(s, 0, nil) | ||||
| 	test.AssertError(t, err, "rsaVerify didn't fail on Sign error") | ||||
| 
 | ||||
| 	// test signature verification failing
 | ||||
|  | @ -86,7 +86,7 @@ func TestRSAVerify(t *testing.T) { | |||
| 	} | ||||
| 	tk, err := rsa.GenerateKey(rand.Reader, 1024) | ||||
| 	test.AssertNotError(t, err, "rsa.GenerateKey failed") | ||||
| 	err = rsaVerify(ctx, 0, 0, &tk.PublicKey) | ||||
| 	err = rsaVerify(s, 0, &tk.PublicKey) | ||||
| 	test.AssertError(t, err, "rsaVerify didn't fail on signature verification error") | ||||
| 
 | ||||
| 	// test we don't fail with valid signature
 | ||||
|  | @ -94,12 +94,12 @@ func TestRSAVerify(t *testing.T) { | |||
| 		// Chop of the hash identifier and feed back into rsa.SignPKCS1v15
 | ||||
| 		return rsa.SignPKCS1v15(rand.Reader, tk, crypto.SHA256, msg[19:]) | ||||
| 	} | ||||
| 	err = rsaVerify(ctx, 0, 0, &tk.PublicKey) | ||||
| 	err = rsaVerify(s, 0, &tk.PublicKey) | ||||
| 	test.AssertNotError(t, err, "rsaVerify failed with a valid signature") | ||||
| } | ||||
| 
 | ||||
| func TestRSAGenerate(t *testing.T) { | ||||
| 	ctx := pkcs11helpers.MockCtx{} | ||||
| 	s, ctx := pkcs11helpers.NewSessionWithMock() | ||||
| 	ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { | ||||
| 		return []byte{1, 2, 3}, nil | ||||
| 	} | ||||
|  | @ -111,7 +111,7 @@ func TestRSAGenerate(t *testing.T) { | |||
| 	ctx.GenerateKeyPairFunc = func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) { | ||||
| 		return 0, 0, errors.New("bad") | ||||
| 	} | ||||
| 	_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) | ||||
| 	_, _, err = rsaGenerate(s, "", 1024, 65537) | ||||
| 	test.AssertError(t, err, "rsaGenerate didn't fail on GenerateKeyPair error") | ||||
| 
 | ||||
| 	// Test rsaGenerate fails when rsaPub fails
 | ||||
|  | @ -121,7 +121,7 @@ func TestRSAGenerate(t *testing.T) { | |||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return nil, errors.New("bad") | ||||
| 	} | ||||
| 	_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) | ||||
| 	_, _, err = rsaGenerate(s, "", 1024, 65537) | ||||
| 	test.AssertError(t, err, "rsaGenerate didn't fail on rsaPub error") | ||||
| 
 | ||||
| 	// Test rsaGenerate fails when rsaVerify fails
 | ||||
|  | @ -134,7 +134,7 @@ func TestRSAGenerate(t *testing.T) { | |||
| 	ctx.GenerateRandomFunc = func(pkcs11.SessionHandle, int) ([]byte, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) | ||||
| 	_, _, err = rsaGenerate(s, "", 1024, 65537) | ||||
| 	test.AssertError(t, err, "rsaGenerate didn't fail on rsaVerify error") | ||||
| 
 | ||||
| 	// Test rsaGenerate doesn't fail when everything works
 | ||||
|  | @ -148,6 +148,6 @@ func TestRSAGenerate(t *testing.T) { | |||
| 		// Chop of the hash identifier and feed back into rsa.SignPKCS1v15
 | ||||
| 		return rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA256, msg[19:]) | ||||
| 	} | ||||
| 	_, _, err = rsaGenerate(ctx, 0, "", 1024, 65537) | ||||
| 	_, _, err = rsaGenerate(s, "", 1024, 65537) | ||||
| 	test.AssertNotError(t, err, "rsaGenerate didn't succeed when everything worked as expected") | ||||
| } | ||||
|  |  | |||
|  | @ -24,32 +24,47 @@ type PKCtx interface { | |||
| 	FindObjectsFinal(sh pkcs11.SessionHandle) error | ||||
| } | ||||
| 
 | ||||
| func Initialize(module string, slot uint, pin string) (PKCtx, pkcs11.SessionHandle, error) { | ||||
| // Session represents a session with a given PKCS#11 module. It is not safe for
 | ||||
| // concurrent access.
 | ||||
| type Session struct { | ||||
| 	Module  PKCtx | ||||
| 	Session pkcs11.SessionHandle | ||||
| } | ||||
| 
 | ||||
| func Initialize(module string, slot uint, pin string) (*Session, error) { | ||||
| 	ctx := pkcs11.New(module) | ||||
| 	if ctx == nil { | ||||
| 		return nil, 0, errors.New("failed to load module") | ||||
| 		return nil, errors.New("failed to load module") | ||||
| 	} | ||||
| 	err := ctx.Initialize() | ||||
| 	if err != nil { | ||||
| 		return nil, 0, fmt.Errorf("couldn't initialize context: %s", err) | ||||
| 		return nil, fmt.Errorf("couldn't initialize context: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	session, err := ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) | ||||
| 	if err != nil { | ||||
| 		return nil, 0, fmt.Errorf("couldn't open session: %s", err) | ||||
| 		return nil, fmt.Errorf("couldn't open session: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	err = ctx.Login(session, pkcs11.CKU_USER, pin) | ||||
| 	if err != nil { | ||||
| 		return nil, 0, fmt.Errorf("couldn't login: %s", err) | ||||
| 		return nil, fmt.Errorf("couldn't login: %s", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return ctx, session, nil | ||||
| 	return &Session{ctx, session}, nil | ||||
| } | ||||
| 
 | ||||
| func GetRSAPublicKey(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle) (*rsa.PublicKey, error) { | ||||
| func (s *Session) GetAttributeValue(object pkcs11.ObjectHandle, attributes []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 	return s.Module.GetAttributeValue(s.Session, object, attributes) | ||||
| } | ||||
| 
 | ||||
| func (s *Session) GenerateKeyPair(m []*pkcs11.Mechanism, pubAttrs []*pkcs11.Attribute, privAttrs []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) { | ||||
| 	return s.Module.GenerateKeyPair(s.Session, m, pubAttrs, privAttrs) | ||||
| } | ||||
| 
 | ||||
| func (s *Session) GetRSAPublicKey(object pkcs11.ObjectHandle) (*rsa.PublicKey, error) { | ||||
| 	// Retrieve the public exponent and modulus for the public key
 | ||||
| 	attrs, err := ctx.GetAttributeValue(session, object, []*pkcs11.Attribute{ | ||||
| 	attrs, err := s.Module.GetAttributeValue(s.Session, object, []*pkcs11.Attribute{ | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_PUBLIC_EXPONENT, nil), | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_MODULUS, nil), | ||||
| 	}) | ||||
|  | @ -86,9 +101,9 @@ var oidDERToCurve = map[string]elliptic.Curve{ | |||
| 	"06052B81040023":       elliptic.P521(), | ||||
| } | ||||
| 
 | ||||
| func GetECDSAPublicKey(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle) (*ecdsa.PublicKey, error) { | ||||
| func (s *Session) GetECDSAPublicKey(object pkcs11.ObjectHandle) (*ecdsa.PublicKey, error) { | ||||
| 	// Retrieve the curve and public point for the generated public key
 | ||||
| 	attrs, err := ctx.GetAttributeValue(session, object, []*pkcs11.Attribute{ | ||||
| 	attrs, err := s.Module.GetAttributeValue(s.Session, object, []*pkcs11.Attribute{ | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, nil), | ||||
| 		pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, nil), | ||||
| 	}) | ||||
|  | @ -152,7 +167,7 @@ var hashIdentifiers = map[crypto.Hash][]byte{ | |||
| 	crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40}, | ||||
| } | ||||
| 
 | ||||
| func Sign(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, keyType KeyType, digest []byte, hash crypto.Hash) ([]byte, error) { | ||||
| func (s *Session) Sign(object pkcs11.ObjectHandle, keyType KeyType, digest []byte, hash crypto.Hash) ([]byte, error) { | ||||
| 	if len(digest) != hash.Size() { | ||||
| 		return nil, errors.New("digest length doesn't match hash length") | ||||
| 	} | ||||
|  | @ -170,11 +185,11 @@ func Sign(ctx PKCtx, session pkcs11.SessionHandle, object pkcs11.ObjectHandle, k | |||
| 		mech[0] = pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil) | ||||
| 	} | ||||
| 
 | ||||
| 	err := ctx.SignInit(session, mech, object) | ||||
| 	err := s.Module.SignInit(s.Session, mech, object) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to initialize signing operation: %s", err) | ||||
| 	} | ||||
| 	signature, err := ctx.Sign(session, digest) | ||||
| 	signature, err := s.Module.Sign(s.Session, digest) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to sign data: %s", err) | ||||
| 	} | ||||
|  | @ -187,15 +202,15 @@ var ErrNoObject = errors.New("no objects found matching provided template") | |||
| // FindObject looks up a PKCS#11 object handle based on the provided template.
 | ||||
| // In the case where zero or more than one objects are found to match the
 | ||||
| // template an error is returned.
 | ||||
| func FindObject(ctx PKCtx, session pkcs11.SessionHandle, tmpl []*pkcs11.Attribute) (pkcs11.ObjectHandle, error) { | ||||
| 	if err := ctx.FindObjectsInit(session, tmpl); err != nil { | ||||
| func (s *Session) FindObject(tmpl []*pkcs11.Attribute) (pkcs11.ObjectHandle, error) { | ||||
| 	if err := s.Module.FindObjectsInit(s.Session, tmpl); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	handles, _, err := ctx.FindObjects(session, 2) | ||||
| 	handles, _, err := s.Module.FindObjects(s.Session, 2) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if err := ctx.FindObjectsFinal(session); err != nil { | ||||
| 	if err := s.Module.FindObjectsFinal(s.Session); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	if len(handles) == 0 { | ||||
|  | @ -207,6 +222,15 @@ func FindObject(ctx PKCtx, session pkcs11.SessionHandle, tmpl []*pkcs11.Attribut | |||
| 	return handles[0], nil | ||||
| } | ||||
| 
 | ||||
| func NewMock() *MockCtx { | ||||
| 	return &MockCtx{} | ||||
| } | ||||
| 
 | ||||
| func NewSessionWithMock() (*Session, *MockCtx) { | ||||
| 	ctx := NewMock() | ||||
| 	return &Session{ctx, 0}, ctx | ||||
| } | ||||
| 
 | ||||
| type MockCtx struct { | ||||
| 	GenerateKeyPairFunc   func(pkcs11.SessionHandle, []*pkcs11.Mechanism, []*pkcs11.Attribute, []*pkcs11.Attribute) (pkcs11.ObjectHandle, pkcs11.ObjectHandle, error) | ||||
| 	GetAttributeValueFunc func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) | ||||
|  |  | |||
|  | @ -10,20 +10,21 @@ import ( | |||
| ) | ||||
| 
 | ||||
| func TestGetECDSAPublicKey(t *testing.T) { | ||||
| 	ctx := MockCtx{} | ||||
| 	ctx := &MockCtx{} | ||||
| 	s := &Session{ctx, 0} | ||||
| 
 | ||||
| 	// test attribute retrieval failing
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	_, err := GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err := s.GetECDSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail on GetAttributeValue error") | ||||
| 
 | ||||
| 	// test we fail to construct key with missing params and point
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return []*pkcs11.Attribute{}, nil | ||||
| 	} | ||||
| 	_, err = GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetECDSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail with empty attribute list") | ||||
| 
 | ||||
| 	// test we fail to construct key with unknown curve
 | ||||
|  | @ -32,7 +33,7 @@ func TestGetECDSAPublicKey(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_EC_PARAMS, []byte{1, 2, 3}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetECDSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail with unknown curve") | ||||
| 
 | ||||
| 	// test we fail to construct key with invalid EC point (invalid encoding)
 | ||||
|  | @ -42,7 +43,7 @@ func TestGetECDSAPublicKey(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{255}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetECDSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail with invalid EC point (invalid encoding)") | ||||
| 
 | ||||
| 	// test we fail to construct key with invalid EC point (empty octet string)
 | ||||
|  | @ -52,7 +53,7 @@ func TestGetECDSAPublicKey(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 0}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetECDSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail with invalid EC point (empty octet string)") | ||||
| 
 | ||||
| 	// test we fail to construct key with invalid EC point (octet string, invalid contents)
 | ||||
|  | @ -62,7 +63,7 @@ func TestGetECDSAPublicKey(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 4, 4, 1, 2, 3}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetECDSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "ecPub didn't fail with invalid EC point (octet string, invalid contents)") | ||||
| 
 | ||||
| 	// test we don't fail with the correct attributes (traditional encoding)
 | ||||
|  | @ -72,7 +73,7 @@ func TestGetECDSAPublicKey(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetECDSAPublicKey(0) | ||||
| 	test.AssertNotError(t, err, "ecPub failed with valid attributes (traditional encoding)") | ||||
| 
 | ||||
| 	// test we don't fail with the correct attributes (non-traditional encoding)
 | ||||
|  | @ -82,25 +83,26 @@ func TestGetECDSAPublicKey(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_EC_POINT, []byte{4, 57, 4, 217, 225, 246, 210, 153, 134, 246, 104, 95, 79, 122, 206, 135, 241, 37, 114, 199, 87, 56, 167, 83, 56, 136, 174, 6, 145, 97, 239, 221, 49, 67, 148, 13, 126, 65, 90, 208, 195, 193, 171, 105, 40, 98, 132, 124, 30, 189, 215, 197, 178, 226, 166, 238, 240, 57, 215}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = GetECDSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetECDSAPublicKey(0) | ||||
| 	test.AssertNotError(t, err, "ecPub failed with valid attributes (non-traditional encoding)") | ||||
| } | ||||
| 
 | ||||
| func TestRSAPublicKey(t *testing.T) { | ||||
| 	ctx := MockCtx{} | ||||
| 	ctx := &MockCtx{} | ||||
| 	s := &Session{ctx, 0} | ||||
| 
 | ||||
| 	// test attribute retrieval failing
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return nil, errors.New("yup") | ||||
| 	} | ||||
| 	_, err := GetRSAPublicKey(ctx, 0, 0) | ||||
| 	_, err := s.GetRSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "rsaPub didn't fail on GetAttributeValue error") | ||||
| 
 | ||||
| 	// test we fail to construct key with missing modulus and exp
 | ||||
| 	ctx.GetAttributeValueFunc = func(pkcs11.SessionHandle, pkcs11.ObjectHandle, []*pkcs11.Attribute) ([]*pkcs11.Attribute, error) { | ||||
| 		return []*pkcs11.Attribute{}, nil | ||||
| 	} | ||||
| 	_, err = GetRSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetRSAPublicKey(0) | ||||
| 	test.AssertError(t, err, "rsaPub didn't fail with empty attribute list") | ||||
| 
 | ||||
| 	// test we don't fail with the correct attributes
 | ||||
|  | @ -110,27 +112,48 @@ func TestRSAPublicKey(t *testing.T) { | |||
| 			pkcs11.NewAttribute(pkcs11.CKA_MODULUS, []byte{255}), | ||||
| 		}, nil | ||||
| 	} | ||||
| 	_, err = GetRSAPublicKey(ctx, 0, 0) | ||||
| 	_, err = s.GetRSAPublicKey(0) | ||||
| 	test.AssertNotError(t, err, "rsaPub failed with valid attributes") | ||||
| } | ||||
| 
 | ||||
| func findObjectsFinalOK(pkcs11.SessionHandle) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func findObjectsInitOK(pkcs11.SessionHandle, []*pkcs11.Attribute) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func findObjectsOK(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { | ||||
| 	return []pkcs11.ObjectHandle{1}, false, nil | ||||
| } | ||||
| 
 | ||||
| func findObjectsFinalOK(pkcs11.SessionHandle) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func newMock() *MockCtx { | ||||
| 	return &MockCtx{ | ||||
| 		FindObjectsInitFunc:  findObjectsInitOK, | ||||
| 		FindObjectsFunc:      findObjectsOK, | ||||
| 		FindObjectsFinalFunc: findObjectsFinalOK, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func newSessionWithMock() (*Session, *MockCtx) { | ||||
| 	ctx := newMock() | ||||
| 	return &Session{ctx, 0}, ctx | ||||
| } | ||||
| 
 | ||||
| func TestFindObjectFailsOnFailedInit(t *testing.T) { | ||||
| 	ctx := MockCtx{} | ||||
| 	ctx.FindObjectsFinalFunc = findObjectsFinalOK | ||||
| 	ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { | ||||
| 		return []pkcs11.ObjectHandle{1}, false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// test FindObject fails when FindObjectsInit fails
 | ||||
| 	ctx.FindObjectsInitFunc = func(pkcs11.SessionHandle, []*pkcs11.Attribute) error { | ||||
| 		return errors.New("broken") | ||||
| 	} | ||||
| 	_, err := FindObject(ctx, 0, nil) | ||||
| 	s := &Session{ctx, 0} | ||||
| 	_, err := s.FindObject(nil) | ||||
| 	test.AssertError(t, err, "FindObject didn't fail when FindObjectsInit failed") | ||||
| } | ||||
| 
 | ||||
|  | @ -143,7 +166,8 @@ func TestFindObjectFailsOnFailedFindObjects(t *testing.T) { | |||
| 	ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { | ||||
| 		return nil, false, errors.New("broken") | ||||
| 	} | ||||
| 	_, err := FindObject(ctx, 0, nil) | ||||
| 	s := &Session{ctx, 0} | ||||
| 	_, err := s.FindObject(nil) | ||||
| 	test.AssertError(t, err, "FindObject didn't fail when FindObjects failed") | ||||
| } | ||||
| 
 | ||||
|  | @ -156,7 +180,8 @@ func TestFindObjectFailsOnNoHandles(t *testing.T) { | |||
| 	ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { | ||||
| 		return []pkcs11.ObjectHandle{}, false, nil | ||||
| 	} | ||||
| 	_, err := FindObject(ctx, 0, nil) | ||||
| 	s := &Session{ctx, 0} | ||||
| 	_, err := s.FindObject(nil) | ||||
| 	test.AssertEquals(t, err, ErrNoObject) | ||||
| } | ||||
| 
 | ||||
|  | @ -169,7 +194,8 @@ func TestFindObjectFailsOnMultipleHandles(t *testing.T) { | |||
| 	ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { | ||||
| 		return []pkcs11.ObjectHandle{1, 2, 3}, false, nil | ||||
| 	} | ||||
| 	_, err := FindObject(ctx, 0, nil) | ||||
| 	s := &Session{ctx, 0} | ||||
| 	_, err := s.FindObject(nil) | ||||
| 	test.AssertError(t, err, "FindObject didn't fail when FindObjects returns multiple handles") | ||||
| 	test.Assert(t, strings.HasPrefix(err.Error(), "too many objects"), "FindObject failed with wrong error") | ||||
| } | ||||
|  | @ -185,20 +211,22 @@ func TestFindObjectFailsOnFinalizeFailure(t *testing.T) { | |||
| 	ctx.FindObjectsFinalFunc = func(pkcs11.SessionHandle) error { | ||||
| 		return errors.New("broken") | ||||
| 	} | ||||
| 	_, err := FindObject(ctx, 0, nil) | ||||
| 	s := &Session{ctx, 0} | ||||
| 	_, err := s.FindObject(nil) | ||||
| 	test.AssertError(t, err, "FindObject didn't fail when FindObjectsFinal fails") | ||||
| } | ||||
| 
 | ||||
| func TestFindObjectSucceeds(t *testing.T) { | ||||
| 	ctx := MockCtx{} | ||||
| 
 | ||||
| 	ctx.FindObjectsInitFunc = findObjectsInitOK | ||||
| 	ctx.FindObjectsFinalFunc = findObjectsFinalOK | ||||
| 	ctx.FindObjectsFunc = func(pkcs11.SessionHandle, int) ([]pkcs11.ObjectHandle, bool, error) { | ||||
| 		return []pkcs11.ObjectHandle{1}, false, nil | ||||
| 	} | ||||
| 	s := &Session{ctx, 0} | ||||
| 
 | ||||
| 	// test FindObject works
 | ||||
| 	handle, err := FindObject(ctx, 0, nil) | ||||
| 	handle, err := s.FindObject(nil) | ||||
| 	test.AssertNotError(t, err, "FindObject failed when everything worked as expected") | ||||
| 	test.AssertEquals(t, handle, pkcs11.ObjectHandle(1)) | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue