333 lines
11 KiB
Go
333 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"errors"
|
|
"os"
|
|
"path"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jmhodges/clock"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
|
|
"github.com/letsencrypt/boulder/core"
|
|
corepb "github.com/letsencrypt/boulder/core/proto"
|
|
berrors "github.com/letsencrypt/boulder/errors"
|
|
blog "github.com/letsencrypt/boulder/log"
|
|
"github.com/letsencrypt/boulder/mocks"
|
|
rapb "github.com/letsencrypt/boulder/ra/proto"
|
|
"github.com/letsencrypt/boulder/revocation"
|
|
sapb "github.com/letsencrypt/boulder/sa/proto"
|
|
"github.com/letsencrypt/boulder/test"
|
|
)
|
|
|
|
// mockSAWithIncident is a mock which only implements the SerialsForIncident
|
|
// gRPC method. It can be initialized with a set of serials for that method
|
|
// to return.
|
|
type mockSAWithIncident struct {
|
|
sapb.StorageAuthorityReadOnlyClient
|
|
incidentSerials []string
|
|
}
|
|
|
|
// SerialsForIncident returns a fake gRPC stream client object which itself
|
|
// will return the mockSAWithIncident's serials in order.
|
|
func (msa *mockSAWithIncident) SerialsForIncident(_ context.Context, _ *sapb.SerialsForIncidentRequest, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.IncidentSerial], error) {
|
|
fakeResults := make([]*sapb.IncidentSerial, len(msa.incidentSerials))
|
|
for i, serial := range msa.incidentSerials {
|
|
fakeResults[i] = &sapb.IncidentSerial{Serial: serial}
|
|
}
|
|
return &mocks.ServerStreamClient[sapb.IncidentSerial]{Results: fakeResults}, nil
|
|
}
|
|
|
|
func TestSerialsFromIncidentTable(t *testing.T) {
|
|
t.Parallel()
|
|
serials := []string{"foo", "bar", "baz"}
|
|
|
|
a := admin{
|
|
saroc: &mockSAWithIncident{incidentSerials: serials},
|
|
}
|
|
|
|
res, err := a.serialsFromIncidentTable(context.Background(), "tablename")
|
|
test.AssertNotError(t, err, "getting serials from mock SA")
|
|
test.AssertDeepEquals(t, res, serials)
|
|
}
|
|
|
|
func TestSerialsFromFile(t *testing.T) {
|
|
t.Parallel()
|
|
serials := []string{"foo", "bar", "baz"}
|
|
|
|
serialsFile := path.Join(t.TempDir(), "serials.txt")
|
|
err := os.WriteFile(serialsFile, []byte(strings.Join(serials, "\n")), os.ModeAppend)
|
|
test.AssertNotError(t, err, "writing temp serials file")
|
|
|
|
a := admin{}
|
|
|
|
res, err := a.serialsFromFile(context.Background(), serialsFile)
|
|
test.AssertNotError(t, err, "getting serials from file")
|
|
test.AssertDeepEquals(t, res, serials)
|
|
}
|
|
|
|
// mockSAWithKey is a mock which only implements the GetSerialsByKey
|
|
// gRPC method. It can be initialized with a set of serials for that method
|
|
// to return.
|
|
type mockSAWithKey struct {
|
|
sapb.StorageAuthorityReadOnlyClient
|
|
keyHash []byte
|
|
serials []string
|
|
}
|
|
|
|
// GetSerialsByKey returns a fake gRPC stream client object which itself
|
|
// will return the mockSAWithKey's serials in order.
|
|
func (msa *mockSAWithKey) GetSerialsByKey(_ context.Context, req *sapb.SPKIHash, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
|
|
if !slices.Equal(req.KeyHash, msa.keyHash) {
|
|
return &mocks.ServerStreamClient[sapb.Serial]{}, nil
|
|
}
|
|
fakeResults := make([]*sapb.Serial, len(msa.serials))
|
|
for i, serial := range msa.serials {
|
|
fakeResults[i] = &sapb.Serial{Serial: serial}
|
|
}
|
|
return &mocks.ServerStreamClient[sapb.Serial]{Results: fakeResults}, nil
|
|
}
|
|
|
|
func TestSerialsFromPrivateKey(t *testing.T) {
|
|
serials := []string{"foo", "bar", "baz"}
|
|
fc := clock.NewFake()
|
|
fc.Set(time.Now())
|
|
|
|
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
test.AssertNotError(t, err, "creating test private key")
|
|
keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
|
|
test.AssertNotError(t, err, "marshalling test private key bytes")
|
|
|
|
keyFile := path.Join(t.TempDir(), "key.pem")
|
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes})
|
|
err = os.WriteFile(keyFile, keyPEM, os.ModeAppend)
|
|
test.AssertNotError(t, err, "writing test private key file")
|
|
|
|
keyHash, err := core.KeyDigest(privKey.Public())
|
|
test.AssertNotError(t, err, "computing test SPKI hash")
|
|
|
|
a := admin{saroc: &mockSAWithKey{keyHash: keyHash[:], serials: serials}}
|
|
|
|
res, err := a.serialsFromPrivateKey(context.Background(), keyFile)
|
|
test.AssertNotError(t, err, "getting serials from keyHashToSerial table")
|
|
test.AssertDeepEquals(t, res, serials)
|
|
}
|
|
|
|
// mockSAWithAccount is a mock which only implements the GetSerialsByAccount
|
|
// gRPC method. It can be initialized with a set of serials for that method
|
|
// to return.
|
|
type mockSAWithAccount struct {
|
|
sapb.StorageAuthorityReadOnlyClient
|
|
regID int64
|
|
serials []string
|
|
}
|
|
|
|
func (msa *mockSAWithAccount) GetRegistration(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (*corepb.Registration, error) {
|
|
if req.Id != msa.regID {
|
|
return nil, errors.New("no such reg")
|
|
}
|
|
return &corepb.Registration{}, nil
|
|
}
|
|
|
|
// GetSerialsByAccount returns a fake gRPC stream client object which itself
|
|
// will return the mockSAWithAccount's serials in order.
|
|
func (msa *mockSAWithAccount) GetSerialsByAccount(_ context.Context, req *sapb.RegistrationID, _ ...grpc.CallOption) (grpc.ServerStreamingClient[sapb.Serial], error) {
|
|
if req.Id != msa.regID {
|
|
return &mocks.ServerStreamClient[sapb.Serial]{}, nil
|
|
}
|
|
fakeResults := make([]*sapb.Serial, len(msa.serials))
|
|
for i, serial := range msa.serials {
|
|
fakeResults[i] = &sapb.Serial{Serial: serial}
|
|
}
|
|
return &mocks.ServerStreamClient[sapb.Serial]{Results: fakeResults}, nil
|
|
}
|
|
|
|
func TestSerialsFromRegID(t *testing.T) {
|
|
serials := []string{"foo", "bar", "baz"}
|
|
a := admin{saroc: &mockSAWithAccount{regID: 123, serials: serials}}
|
|
|
|
res, err := a.serialsFromRegID(context.Background(), 123)
|
|
test.AssertNotError(t, err, "getting serials from serials table")
|
|
test.AssertDeepEquals(t, res, serials)
|
|
}
|
|
|
|
// mockRARecordingRevocations is a mock which only implements the
|
|
// AdministrativelyRevokeCertificate gRPC method. It can be initialized with
|
|
// serials to recognize as already revoked, or to fail.
|
|
type mockRARecordingRevocations struct {
|
|
rapb.RegistrationAuthorityClient
|
|
doomedToFail []string
|
|
alreadyRevoked []string
|
|
revocationRequests []*rapb.AdministrativelyRevokeCertificateRequest
|
|
sync.Mutex
|
|
}
|
|
|
|
// AdministrativelyRevokeCertificate records the request it received on the mock
|
|
// RA struct, and succeeds if it doesn't recognize the serial as one it should
|
|
// fail for.
|
|
func (mra *mockRARecordingRevocations) AdministrativelyRevokeCertificate(_ context.Context, req *rapb.AdministrativelyRevokeCertificateRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) {
|
|
mra.Lock()
|
|
defer mra.Unlock()
|
|
mra.revocationRequests = append(mra.revocationRequests, req)
|
|
if slices.Contains(mra.doomedToFail, req.Serial) {
|
|
return nil, errors.New("oops")
|
|
}
|
|
if slices.Contains(mra.alreadyRevoked, req.Serial) {
|
|
return nil, berrors.AlreadyRevokedError("too slow")
|
|
}
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
func (mra *mockRARecordingRevocations) reset() {
|
|
mra.doomedToFail = nil
|
|
mra.alreadyRevoked = nil
|
|
mra.revocationRequests = nil
|
|
}
|
|
|
|
func TestRevokeSerials(t *testing.T) {
|
|
t.Parallel()
|
|
serials := []string{
|
|
"2a18592b7f4bf596fb1a1df135567acd825a",
|
|
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
|
|
"048c3f6388afb7695dd4d6bbe3d264f1e5e5",
|
|
}
|
|
mra := mockRARecordingRevocations{}
|
|
log := blog.NewMock()
|
|
a := admin{rac: &mra, log: log}
|
|
|
|
assertRequestsContain := func(reqs []*rapb.AdministrativelyRevokeCertificateRequest, code revocation.Reason, skipBlockKey bool) {
|
|
t.Helper()
|
|
for _, req := range reqs {
|
|
test.AssertEquals(t, len(req.Cert), 0)
|
|
test.AssertEquals(t, req.Code, int64(code))
|
|
test.AssertEquals(t, req.SkipBlockKey, skipBlockKey)
|
|
}
|
|
}
|
|
|
|
// Revoking should result in 3 gRPC requests and quiet execution.
|
|
mra.reset()
|
|
log.Clear()
|
|
a.dryRun = false
|
|
err := a.revokeSerials(context.Background(), serials, 0, false, 1)
|
|
test.AssertEquals(t, len(log.GetAllMatching("invalid serial format")), 0)
|
|
test.AssertNotError(t, err, "")
|
|
test.AssertEquals(t, len(log.GetAll()), 0)
|
|
test.AssertEquals(t, len(mra.revocationRequests), 3)
|
|
assertRequestsContain(mra.revocationRequests, 0, false)
|
|
|
|
// Revoking an already-revoked serial should result in one log line.
|
|
mra.reset()
|
|
log.Clear()
|
|
mra.alreadyRevoked = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"}
|
|
err = a.revokeSerials(context.Background(), serials, 0, false, 1)
|
|
t.Logf("error: %s", err)
|
|
t.Logf("logs: %s", strings.Join(log.GetAll(), ""))
|
|
test.AssertError(t, err, "already-revoked should result in error")
|
|
test.AssertEquals(t, len(log.GetAllMatching("not revoking")), 1)
|
|
test.AssertEquals(t, len(mra.revocationRequests), 3)
|
|
assertRequestsContain(mra.revocationRequests, 0, false)
|
|
|
|
// Revoking a doomed-to-fail serial should also result in one log line.
|
|
mra.reset()
|
|
log.Clear()
|
|
mra.doomedToFail = []string{"048c3f6388afb7695dd4d6bbe3d264f1e5e5"}
|
|
err = a.revokeSerials(context.Background(), serials, 0, false, 1)
|
|
test.AssertError(t, err, "gRPC error should result in error")
|
|
test.AssertEquals(t, len(log.GetAllMatching("failed to revoke")), 1)
|
|
test.AssertEquals(t, len(mra.revocationRequests), 3)
|
|
assertRequestsContain(mra.revocationRequests, 0, false)
|
|
|
|
// Revoking with other parameters should get carried through.
|
|
mra.reset()
|
|
log.Clear()
|
|
err = a.revokeSerials(context.Background(), serials, 1, true, 3)
|
|
test.AssertNotError(t, err, "")
|
|
test.AssertEquals(t, len(mra.revocationRequests), 3)
|
|
assertRequestsContain(mra.revocationRequests, 1, true)
|
|
|
|
// Revoking in dry-run mode should result in no gRPC requests and three logs.
|
|
mra.reset()
|
|
log.Clear()
|
|
a.dryRun = true
|
|
a.rac = dryRunRAC{log: log}
|
|
err = a.revokeSerials(context.Background(), serials, 0, false, 1)
|
|
test.AssertNotError(t, err, "")
|
|
test.AssertEquals(t, len(log.GetAllMatching("dry-run:")), 3)
|
|
test.AssertEquals(t, len(mra.revocationRequests), 0)
|
|
assertRequestsContain(mra.revocationRequests, 0, false)
|
|
}
|
|
|
|
func TestRevokeMalformed(t *testing.T) {
|
|
t.Parallel()
|
|
mra := mockRARecordingRevocations{}
|
|
log := blog.NewMock()
|
|
a := &admin{
|
|
rac: &mra,
|
|
log: log,
|
|
dryRun: false,
|
|
}
|
|
|
|
s := subcommandRevokeCert{
|
|
crlShard: 623,
|
|
}
|
|
serial := "0379c3dfdd518be45948f2dbfa6ea3e9b209"
|
|
err := s.revokeMalformed(context.Background(), a, []string{serial}, 1)
|
|
if err != nil {
|
|
t.Errorf("revokedMalformed with crlShard 623: want success, got %s", err)
|
|
}
|
|
if len(mra.revocationRequests) != 1 {
|
|
t.Errorf("revokeMalformed: want 1 revocation request to SA, got %v", mra.revocationRequests)
|
|
}
|
|
if mra.revocationRequests[0].Serial != serial {
|
|
t.Errorf("revokeMalformed: want %s to be revoked, got %s", serial, mra.revocationRequests[0])
|
|
}
|
|
|
|
s = subcommandRevokeCert{
|
|
crlShard: 0,
|
|
}
|
|
err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2"}, 1)
|
|
if err == nil {
|
|
t.Errorf("revokedMalformed with crlShard 0: want error, got none")
|
|
}
|
|
|
|
s = subcommandRevokeCert{
|
|
crlShard: 623,
|
|
}
|
|
err = s.revokeMalformed(context.Background(), a, []string{"038c3f6388afb7695dd4d6bbe3d264f1e4e2", "28a94f966eae14e525777188512ddf5a0a3b"}, 1)
|
|
if err == nil {
|
|
t.Errorf("revokedMalformed with multiple serials: want error, got none")
|
|
}
|
|
}
|
|
|
|
func TestCleanSerials(t *testing.T) {
|
|
input := []string{
|
|
"2a:18:59:2b:7f:4b:f5:96:fb:1a:1d:f1:35:56:7a:cd:82:5a",
|
|
"03:8c:3f:63:88:af:b7:69:5d:d4:d6:bb:e3:d2:64:f1:e4:e2",
|
|
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
|
|
}
|
|
expected := []string{
|
|
"2a18592b7f4bf596fb1a1df135567acd825a",
|
|
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
|
|
"038c3f6388afb7695dd4d6bbe3d264f1e4e2",
|
|
}
|
|
output, err := cleanSerials(input)
|
|
if err != nil {
|
|
t.Errorf("cleanSerials(%s): %s, want %s", input, err, expected)
|
|
}
|
|
if !reflect.DeepEqual(output, expected) {
|
|
t.Errorf("cleanSerials(%s)=%s, want %s", input, output, expected)
|
|
}
|
|
}
|