Add GetValidAuthorizations to batch authz checks
By performing only one query to MySQL, we should be able to avoid blowing the timeouts. Fixes #1567
This commit is contained in:
parent
a220ee0ad8
commit
91bc75b0e3
|
|
@ -97,6 +97,7 @@ type StorageGetter interface {
|
||||||
GetRegistrationByKey(jose.JsonWebKey) (Registration, error)
|
GetRegistrationByKey(jose.JsonWebKey) (Registration, error)
|
||||||
GetAuthorization(string) (Authorization, error)
|
GetAuthorization(string) (Authorization, error)
|
||||||
GetLatestValidAuthorization(int64, AcmeIdentifier) (Authorization, error)
|
GetLatestValidAuthorization(int64, AcmeIdentifier) (Authorization, error)
|
||||||
|
GetValidAuthorizations(int64, []string, time.Time) (map[string]*Authorization, error)
|
||||||
GetCertificate(string) (Certificate, error)
|
GetCertificate(string) (Certificate, error)
|
||||||
GetCertificateStatus(string) (CertificateStatus, error)
|
GetCertificateStatus(string) (CertificateStatus, error)
|
||||||
AlreadyDeniedCSR([]string) (bool, error)
|
AlreadyDeniedCSR([]string) (bool, error)
|
||||||
|
|
|
||||||
|
|
@ -268,6 +268,29 @@ func (sa *StorageAuthority) GetLatestValidAuthorization(registrationID int64, id
|
||||||
return core.Authorization{}, errors.New("no authz")
|
return core.Authorization{}, errors.New("no authz")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidAuthorizations is a mock
|
||||||
|
func (sa *StorageAuthority) GetValidAuthorizations(regID int64, names []string, now time.Time) (map[string]*core.Authorization, error) {
|
||||||
|
if regID == 1 {
|
||||||
|
auths := make(map[string]*core.Authorization)
|
||||||
|
for _, name := range names {
|
||||||
|
if sa.authorizedDomains[name] || name == "not-an-example.com" {
|
||||||
|
exp := now.AddDate(100, 0, 0)
|
||||||
|
auths[name] = &core.Authorization{
|
||||||
|
Status: core.StatusValid,
|
||||||
|
RegistrationID: 1,
|
||||||
|
Expires: &exp,
|
||||||
|
Identifier: core.AcmeIdentifier{
|
||||||
|
Type: "dns",
|
||||||
|
Value: name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return auths, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("no authz")
|
||||||
|
}
|
||||||
|
|
||||||
// CountCertificatesRange is a mock
|
// CountCertificatesRange is a mock
|
||||||
func (sa *StorageAuthority) CountCertificatesRange(_, _ time.Time) (int64, error) {
|
func (sa *StorageAuthority) CountCertificatesRange(_, _ time.Time) (int64, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|
|
||||||
|
|
@ -452,9 +452,16 @@ func (ra *RegistrationAuthorityImpl) MatchesCSR(cert core.Certificate, csr *x509
|
||||||
func (ra *RegistrationAuthorityImpl) checkAuthorizations(names []string, registration *core.Registration) error {
|
func (ra *RegistrationAuthorityImpl) checkAuthorizations(names []string, registration *core.Registration) error {
|
||||||
now := ra.clk.Now()
|
now := ra.clk.Now()
|
||||||
var badNames []string
|
var badNames []string
|
||||||
|
for i := range names {
|
||||||
|
names[i] = strings.ToLower(names[i])
|
||||||
|
}
|
||||||
|
auths, err := ra.SA.GetValidAuthorizations(registration.ID, names, now)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
authz, err := ra.SA.GetLatestValidAuthorization(registration.ID, core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name})
|
authz := auths[name]
|
||||||
if err != nil || authz.Expires.Before(now) {
|
if authz == nil || authz.Expires.Before(now) {
|
||||||
badNames = append(badNames, name)
|
badNames = append(badNames, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ const (
|
||||||
MethodGetRegistrationByKey = "GetRegistrationByKey" // RA, SA
|
MethodGetRegistrationByKey = "GetRegistrationByKey" // RA, SA
|
||||||
MethodGetAuthorization = "GetAuthorization" // SA
|
MethodGetAuthorization = "GetAuthorization" // SA
|
||||||
MethodGetLatestValidAuthorization = "GetLatestValidAuthorization" // SA
|
MethodGetLatestValidAuthorization = "GetLatestValidAuthorization" // SA
|
||||||
|
MethodGetValidAuthorizations = "GetValidAuthorizations" // SA
|
||||||
MethodGetCertificate = "GetCertificate" // SA
|
MethodGetCertificate = "GetCertificate" // SA
|
||||||
MethodGetCertificateStatus = "GetCertificateStatus" // SA
|
MethodGetCertificateStatus = "GetCertificateStatus" // SA
|
||||||
MethodMarkCertificateRevoked = "MarkCertificateRevoked" // SA
|
MethodMarkCertificateRevoked = "MarkCertificateRevoked" // SA
|
||||||
|
|
@ -103,6 +104,12 @@ type latestValidAuthorizationRequest struct {
|
||||||
Identifier core.AcmeIdentifier
|
Identifier core.AcmeIdentifier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type getValidAuthorizationsRequest struct {
|
||||||
|
RegID int64
|
||||||
|
Names []string
|
||||||
|
Now time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type certificateRequest struct {
|
type certificateRequest struct {
|
||||||
Req core.CertificateRequest
|
Req core.CertificateRequest
|
||||||
RegID int64
|
RegID int64
|
||||||
|
|
@ -846,6 +853,28 @@ func NewStorageAuthorityServer(rpc Server, impl core.StorageAuthority) error {
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
|
|
||||||
|
rpc.Handle(MethodGetValidAuthorizations, func(req []byte) (response []byte, err error) {
|
||||||
|
var mreq getValidAuthorizationsRequest
|
||||||
|
if err = json.Unmarshal(req, &mreq); err != nil {
|
||||||
|
// AUDIT[ Improper Messages ] 0786b6f2-91ca-4f48-9883-842a19084c64
|
||||||
|
improperMessage(MethodGetValidAuthorizations, err, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
auths, err := impl.GetValidAuthorizations(mreq.RegID, mreq.Names, mreq.Now)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err = json.Marshal(auths)
|
||||||
|
if err != nil {
|
||||||
|
// AUDIT[ Error Conditions ] 9cc4d537-8534-4970-8665-4b382abe82f3
|
||||||
|
errorCondition(MethodGetValidAuthorizations, err, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
})
|
||||||
|
|
||||||
rpc.Handle(MethodAddCertificate, func(req []byte) (response []byte, err error) {
|
rpc.Handle(MethodAddCertificate, func(req []byte) (response []byte, err error) {
|
||||||
var acReq addCertificateRequest
|
var acReq addCertificateRequest
|
||||||
err = json.Unmarshal(req, &acReq)
|
err = json.Unmarshal(req, &acReq)
|
||||||
|
|
@ -1256,6 +1285,28 @@ func (cac StorageAuthorityClient) GetLatestValidAuthorization(registrationID int
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidAuthorizations sends a request to get a batch of Authorizations by
|
||||||
|
// RegID and dnsName. The current time is also included in the request to
|
||||||
|
// assist filtering.
|
||||||
|
func (cac StorageAuthorityClient) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (auths map[string]*core.Authorization, err error) {
|
||||||
|
data, err := json.Marshal(getValidAuthorizationsRequest{
|
||||||
|
RegID: registrationID,
|
||||||
|
Names: names,
|
||||||
|
Now: now,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonAuths, err := cac.rpc.DispatchSync(MethodGetValidAuthorizations, data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(jsonAuths, &auths)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// GetCertificate sends a request to get a Certificate by ID
|
// GetCertificate sends a request to get a Certificate by ID
|
||||||
func (cac StorageAuthorityClient) GetCertificate(id string) (cert core.Certificate, err error) {
|
func (cac StorageAuthorityClient) GetCertificate(id string) (cert core.Certificate, err error) {
|
||||||
jsonCert, err := cac.rpc.DispatchSync(MethodGetCertificate, []byte(id))
|
jsonCert, err := cac.rpc.DispatchSync(MethodGetCertificate, []byte(id))
|
||||||
|
|
|
||||||
|
|
@ -232,6 +232,51 @@ func (ssa *SQLStorageAuthority) GetLatestValidAuthorization(registrationID int64
|
||||||
return ssa.GetAuthorization(auth.ID)
|
return ssa.GetAuthorization(auth.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetValidAuthorizations returns the latest authorization object for all
|
||||||
|
// domain names from the parameters that the account has authorizations for.
|
||||||
|
func (ssa *SQLStorageAuthority) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (latest map[string]*core.Authorization, err error) {
|
||||||
|
if len(names) == 0 {
|
||||||
|
return nil, errors.New("GetValidAuthorizations: no names received")
|
||||||
|
}
|
||||||
|
|
||||||
|
params := make([]interface{}, len(names))
|
||||||
|
qmarks := make([]string, len(names))
|
||||||
|
for i, name := range names {
|
||||||
|
id := core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name}
|
||||||
|
idJSON, err := json.Marshal(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
params[i] = string(idJSON)
|
||||||
|
qmarks[i] = "?"
|
||||||
|
}
|
||||||
|
|
||||||
|
var auths []*core.Authorization
|
||||||
|
_, err = ssa.dbMap.Select(&auths, `
|
||||||
|
SELECT * FROM authz
|
||||||
|
WHERE registrationID = ?
|
||||||
|
AND expires > ?
|
||||||
|
AND identifier IN (`+strings.Join(qmarks, ",")+`)
|
||||||
|
AND status = 'valid'
|
||||||
|
ORDER BY expires ASC
|
||||||
|
`, append([]interface{}{registrationID, now}, params...)...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
byName := make(map[string]*core.Authorization)
|
||||||
|
for _, auth := range auths {
|
||||||
|
if auth.Identifier.Type != core.IdentifierDNS {
|
||||||
|
return nil, fmt.Errorf("unknown identifier type: %q on authz id %q", auth.Identifier.Type, auth.ID)
|
||||||
|
}
|
||||||
|
// Due to ORDER BY expires, this results in the latest value
|
||||||
|
// for each name being used.
|
||||||
|
byName[auth.Identifier.Value] = auth
|
||||||
|
}
|
||||||
|
|
||||||
|
return byName, nil
|
||||||
|
}
|
||||||
|
|
||||||
// incrementIP returns a copy of `ip` incremented at a bit index `index`,
|
// incrementIP returns a copy of `ip` incremented at a bit index `index`,
|
||||||
// or in other words the first IP of the next highest subnet given a mask of
|
// or in other words the first IP of the next highest subnet given a mask of
|
||||||
// length `index`.
|
// length `index`.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue