boulder/cmd/admin/cert_test.go

333 lines
11 KiB
Go

package main
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"os"
"path"
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/jmhodges/clock"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/letsencrypt/boulder/core"
corepb "github.com/letsencrypt/boulder/core/proto"
berrors "github.com/letsencrypt/boulder/errors"
blog "github.com/letsencrypt/boulder/log"
"github.com/letsencrypt/boulder/mocks"
rapb "github.com/letsencrypt/boulder/ra/proto"
"github.com/letsencrypt/boulder/revocation"
sapb "github.com/letsencrypt/boulder/sa/proto"
"github.com/letsencrypt/boulder/test"
)
// mockSAWithIncident is a mock which only implements the SerialsForIncident
// gRPC method. It can be initialized with a set of serials for that method
// to return.
//
// The visible tests construct it without setting the embedded client, so only
// the methods defined explicitly on the mock are safe to call.
type mockSAWithIncident struct {
// Embedded so that the mock satisfies the whole read-only SA client
// interface without having to implement every method.
sapb.StorageAuthorityReadOnlyClient
// incidentSerials holds the serials that SerialsForIncident streams back,
// in this order.
incidentSerials []string
}
// SerialsForIncident ignores its arguments and hands back a fake server-stream
// client whose results are the mock's configured incident serials, wrapped as
// IncidentSerial messages in their original order. It never returns an error.
func (msa *mockSAWithIncident) SerialsForIncident(_ context.Context, _ *sapb.SerialsForIncidentRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.IncidentSerial], error) {
	results := make([]*sapb.IncidentSerial, 0, len(msa.incidentSerials))
	for _, s := range msa.incidentSerials {
		results = append(results, &sapb.IncidentSerial{Serial: s})
	}

	stream := &mocks.ServerStreamClient[sapb.IncidentSerial]{Results: results}
	return stream, nil
}
func TestSerialsFromIncidentTable(t *testing.T) {
t.Parallel()
serials := []string{"foo", "bar", "baz"}
a := admin{
saroc: &mockSAWithIncident{incidentSerials: serials},
}
res, err := a.serialsFromIncidentTable(context.Background(), "tablename")
test.AssertNotError(t, err, "getting serials from mock SA")
test.AssertDeepEquals(t, res, serials)
}
// TestSerialsFromFile writes a newline-separated list of serials to a temporary
// file and checks that serialsFromFile reads the same list back.
func TestSerialsFromFile(t *testing.T) {
	t.Parallel()
	serials := []string{"foo", "bar", "baz"}
	serialsFile := path.Join(t.TempDir(), "serials.txt")
	// The third argument to os.WriteFile is the permission set for a newly
	// created file. os.ModeAppend carries no permission bits, so using it here
	// created a mode-0000 file that only a privileged user could read back.
	err := os.WriteFile(serialsFile, []byte(strings.Join(serials, "\n")), 0o600)
	test.AssertNotError(t, err, "writing temp serials file")
	a := admin{}
	res, err := a.serialsFromFile(context.Background(), serialsFile)
	test.AssertNotError(t, err, "getting serials from file")
	test.AssertDeepEquals(t, res, serials)
}
// mockSAWithKey is a mock which only implements the GetSerialsByKey
// gRPC method. It can be initialized with a set of serials for that method
// to return.
type mockSAWithKey struct {
// Embedded so that the mock satisfies the whole read-only SA client
// interface without having to implement every method.
sapb.StorageAuthorityReadOnlyClient
// keyHash is the SPKI hash the mock recognizes; GetSerialsByKey returns an
// empty stream for any request whose KeyHash differs from it.
keyHash []byte
// serials are the serials returned for a request matching keyHash.
serials []string
}
// GetSerialsByKey hands back a fake server-stream client. When the requested
// key hash equals the mock's keyHash the stream yields the mock's serials in
// order; for any other hash the stream is empty. It never returns an error.
func (msa *mockSAWithKey) GetSerialsByKey(_ context.Context, req *sapb.SPKIHash, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
	stream := &mocks.ServerStreamClient[sapb.Serial]{}

	hashMatches := slices.Equal(req.KeyHash, msa.keyHash)
	if !hashMatches {
		return stream, nil
	}

	stream.Results = make([]*sapb.Serial, 0, len(msa.serials))
	for _, s := range msa.serials {
		stream.Results = append(stream.Results, &sapb.Serial{Serial: s})
	}
	return stream, nil
}
// TestSerialsFromPrivateKey generates a fresh ECDSA P-256 key, writes it to a
// temporary PEM file, and checks that serialsFromPrivateKey returns the serials
// the mock SA associates with that key's SPKI hash.
func TestSerialsFromPrivateKey(t *testing.T) {
	serials := []string{"foo", "bar", "baz"}
	// NOTE(review): this fake clock is configured but not passed to anything
	// in this test.
	fc := clock.NewFake()
	fc.Set(time.Now())

	privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	test.AssertNotError(t, err, "creating test private key")
	keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
	test.AssertNotError(t, err, "marshalling test private key bytes")

	keyFile := path.Join(t.TempDir(), "key.pem")
	// NOTE(review): the block is labelled "RSA PRIVATE KEY" although it holds a
	// PKCS#8-encoded ECDSA key; the conventional PKCS#8 label is "PRIVATE KEY".
	// Left unchanged because the key loader is not visible here — confirm it
	// does not depend on the label before changing it.
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
	// The third argument to os.WriteFile is the permission set for a newly
	// created file. os.ModeAppend carries no permission bits, so using it here
	// created a mode-0000 file that only a privileged user could read back.
	// 0o600 keeps the key private to the owner while remaining readable.
	err = os.WriteFile(keyFile, keyPEM, 0o600)
	test.AssertNotError(t, err, "writing test private key file")

	keyHash, err := core.KeyDigest(privKey.Public())
	test.AssertNotError(t, err, "computing test SPKI hash")

	a := admin{saroc: &mockSAWithKey{keyHash: keyHash[:], serials: serials}}
	res, err := a.serialsFromPrivateKey(context.Background(), keyFile)
	test.AssertNotError(t, err, "getting serials from keyHashToSerial table")
	test.AssertDeepEquals(t, res, serials)
}
// mockSAWithAccount is a mock which implements the GetRegistration and
// GetSerialsByAccount gRPC methods. It can be initialized with an account ID
// to recognize and a set of serials to return for that account.
type mockSAWithAccount struct {
// Embedded so that the mock satisfies the whole read-only SA client
// interface without having to implement every method.
sapb.StorageAuthorityReadOnlyClient
// regID is the only registration ID the mock recognizes.
regID int64
// serials are the serials returned by GetSerialsByAccount for regID.
serials []string
}
// GetRegistration returns an empty Registration when the requested ID is the
// one the mock was configured with, and a "no such reg" error otherwise.
func (msa *mockSAWithAccount) GetRegistration(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (*corepb.Registration, error) {
	if known := req.Id == msa.regID; known {
		return &corepb.Registration{}, nil
	}
	return nil, errors.New("no such reg")
}
// GetSerialsByAccount hands back a fake server-stream client. When the
// requested registration ID equals the mock's regID the stream yields the
// mock's serials in order; for any other ID the stream is empty. It never
// returns an error.
func (msa *mockSAWithAccount) GetSerialsByAccount(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
	stream := &mocks.ServerStreamClient[sapb.Serial]{}

	if req.Id != msa.regID {
		return stream, nil
	}

	stream.Results = make([]*sapb.Serial, 0, len(msa.serials))
	for _, s := range msa.serials {
		stream.Results = append(stream.Results, &sapb.Serial{Serial: s})
	}
	return stream, nil
}
func TestSerialsFromRegID(t *testing.T) {
serials := []string{"foo", "bar", "baz"}
a := admin{saroc: &mockSAWithAccount{regID: 123, serials: serials}}
res, err := a.serialsFromRegID(context.Background(), 123)
test.AssertNotError(t, err, "getting serials from serials table")
test.AssertDeepEquals(t, res, serials)
}
// mockRARecordingRevocations is a mock which only implements the
// AdministrativelyRevokeCertificate gRPC method. It can be initialized with
// serials to recognize as already revoked, or to fail.
type mockRARecordingRevocations struct {
// Embedded so that the mock satisfies the whole RA client interface without
// having to implement every method.
rapb.RegistrationAuthorityClient
// doomedToFail lists serials for which revocation returns a generic error.
doomedToFail []string
// alreadyRevoked lists serials for which revocation returns an
// AlreadyRevoked error.
alreadyRevoked []string
// revocationRequests records every request received, including those that
// were answered with an error.
revocationRequests []*rapb.AdministrativelyRevokeCertificateRequest
// The embedded mutex is held by AdministrativelyRevokeCertificate while it
// records the request and consults the lists above.
sync.Mutex
}
// AdministrativelyRevokeCertificate appends the incoming request to the mock's
// record and then decides the outcome from the request's serial: a generic
// error if the serial is in doomedToFail, an AlreadyRevoked error if it is in
// alreadyRevoked, and success otherwise. The request is recorded in every
// case, and the whole call runs under the mock's mutex.
func (mra *mockRARecordingRevocations) AdministrativelyRevokeCertificate(_ context.Context, req *rapb.AdministrativelyRevokeCertificateRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) {
	mra.Lock()
	defer mra.Unlock()

	mra.revocationRequests = append(mra.revocationRequests, req)

	// doomedToFail takes precedence over alreadyRevoked.
	switch {
	case slices.Contains(mra.doomedToFail, req.Serial):
		return nil, errors.New("oops")
	case slices.Contains(mra.alreadyRevoked, req.Serial):
		return nil, berrors.AlreadyRevokedError("too slow")
	default:
		return &emptypb.Empty{}, nil
	}
}
// reset returns the mock to its pristine state: no serials configured to fail
// or to be reported as already revoked, and no recorded requests. It does not
// take the mock's mutex.
func (mra *mockRARecordingRevocations) reset() {
	mra.revocationRequests = nil
	mra.doomedToFail, mra.alreadyRevoked = nil, nil
}
// TestRevokeSerials drives revokeSerials through five scenarios against a
// recording mock RA: plain success, an already-revoked serial, a serial whose
// revocation fails, non-default reason/skip-block-key parameters, and dry-run
// mode. Each scenario starts from a cleared mock and a cleared log.
func TestRevokeSerials(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	serials := []string{
		"2a18592b7f4bf596fb1a1df135567acd825a",
		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
		"048c3f6388afb7695dd4d6bbe3d264f1e5e5",
	}
	// The serial singled out by the failure scenarios below.
	troubled := []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"}

	ra := &mockRARecordingRevocations{}
	mockLog := blog.NewMock()
	a := admin{rac: ra, log: mockLog}

	// startOver wipes the state left behind by the previous scenario.
	startOver := func() {
		ra.reset()
		mockLog.Clear()
	}

	// checkRequests verifies that every recorded request carries no
	// certificate bytes and has the expected reason code and skip-block-key
	// flag.
	checkRequests := func(recorded []*rapb.AdministrativelyRevokeCertificateRequest, wantCode revocation.Reason, wantSkip bool) {
		t.Helper()
		for i := range recorded {
			test.AssertEquals(t, len(recorded[i].Cert), 0)
			test.AssertEquals(t, recorded[i].Code, int64(wantCode))
			test.AssertEquals(t, recorded[i].SkipBlockKey, wantSkip)
		}
	}

	// Scenario 1: a plain revocation sends 3 gRPC requests and logs nothing.
	startOver()
	a.dryRun = false
	err := a.revokeSerials(ctx, serials, 0, false, 1)
	test.AssertEquals(t, len(mockLog.GetAllMatching("invalid serial format")), 0)
	test.AssertNotError(t, err, "")
	test.AssertEquals(t, len(mockLog.GetAll()), 0)
	test.AssertEquals(t, len(ra.revocationRequests), 3)
	checkRequests(ra.revocationRequests, 0, false)

	// Scenario 2: one serial is already revoked; expect an error and exactly
	// one "not revoking" log line, with all 3 requests still sent.
	startOver()
	ra.alreadyRevoked = troubled
	err = a.revokeSerials(ctx, serials, 0, false, 1)
	t.Logf("error: %s", err)
	t.Logf("logs: %s", strings.Join(mockLog.GetAll(), ""))
	test.AssertError(t, err, "already-revoked should result in error")
	test.AssertEquals(t, len(mockLog.GetAllMatching("not revoking")), 1)
	test.AssertEquals(t, len(ra.revocationRequests), 3)
	checkRequests(ra.revocationRequests, 0, false)

	// Scenario 3: one serial fails outright; expect an error and exactly one
	// "failed to revoke" log line, with all 3 requests still sent.
	startOver()
	ra.doomedToFail = troubled
	err = a.revokeSerials(ctx, serials, 0, false, 1)
	test.AssertError(t, err, "gRPC error should result in error")
	test.AssertEquals(t, len(mockLog.GetAllMatching("failed to revoke")), 1)
	test.AssertEquals(t, len(ra.revocationRequests), 3)
	checkRequests(ra.revocationRequests, 0, false)

	// Scenario 4: a different reason code and skip-block-key flag must be
	// carried through to every request.
	startOver()
	err = a.revokeSerials(ctx, serials, 1, true, 3)
	test.AssertNotError(t, err, "")
	test.AssertEquals(t, len(ra.revocationRequests), 3)
	checkRequests(ra.revocationRequests, 1, true)

	// Scenario 5: in dry-run mode the recording mock receives no requests and
	// three "dry-run:" lines are logged instead.
	startOver()
	a.dryRun = true
	a.rac = dryRunRAC{log: mockLog}
	err = a.revokeSerials(ctx, serials, 0, false, 1)
	test.AssertNotError(t, err, "")
	test.AssertEquals(t, len(mockLog.GetAllMatching("dry-run:")), 3)
	test.AssertEquals(t, len(ra.revocationRequests), 0)
	checkRequests(ra.revocationRequests, 0, false)
}
// TestRevokeMalformed exercises revokeMalformed in three configurations: a
// single serial with a CRL shard set (expected to succeed and send exactly one
// request to the RA), a single serial with crlShard 0 (expected to fail), and
// multiple serials (expected to fail).
func TestRevokeMalformed(t *testing.T) {
	t.Parallel()
	mra := mockRARecordingRevocations{}
	log := blog.NewMock()
	a := &admin{
		rac:    &mra,
		log:    log,
		dryRun: false,
	}

	s := subcommandRevokeCert{
		crlShard: 623,
	}
	serial := "0379c3dfdd518be45948f2dbfa6ea3e9b209"
	err := s.revokeMalformed(context.Background(), a, []string{serial}, 1)
	if err != nil {
		t.Errorf("revokeMalformed with crlShard 623: want success, got %s", err)
	}
	// This must be fatal: the next check indexes the first request and would
	// panic with an index-out-of-range if none was recorded.
	if len(mra.revocationRequests) != 1 {
		t.Fatalf("revokeMalformed: want 1 revocation request to RA, got %v", mra.revocationRequests)
	}
	if mra.revocationRequests[0].Serial != serial {
		t.Errorf("revokeMalformed: want %s to be revoked, got %s", serial, mra.revocationRequests[0].Serial)
	}

	s = subcommandRevokeCert{
		crlShard: 0,
	}
	err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2"}, 1)
	if err == nil {
		t.Errorf("revokeMalformed with crlShard 0: want error, got none")
	}

	s = subcommandRevokeCert{
		crlShard: 623,
	}
	err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2", "28a94f966eae14e525777188512ddf5a0a3b"}, 1)
	if err == nil {
		t.Errorf("revokeMalformed with multiple serials: want error, got none")
	}
}
// TestCleanSerials feeds cleanSerials a mix of colon-separated and plain hex
// serials in a single call and checks that the result is the plain hex form of
// each, in the same order.
func TestCleanSerials(t *testing.T) {
	raw := []string{
		"2a:18:59:2b:7f:4b:f5:96:fb:1a:1d:f1:35:56:7a:cd:82:5a",
		"03:8c:3f:63:88:af:b7:69:5d:d4:d6:bb:e3:d2:64:f1:e4:e2",
		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
	}
	want := []string{
		"2a18592b7f4bf596fb1a1df135567acd825a",
		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
		"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
	}

	got, err := cleanSerials(raw)
	if err != nil {
		t.Errorf("cleanSerials(%s): %s, want %s", raw, err, want)
	}

	if matches := reflect.DeepEqual(got, want); !matches {
		t.Errorf("cleanSerials(%s)=%s, want %s", raw, got, want)
	}
}