Add GetValidAuthorizations to batch authz checks
By performing only one query to MySQL, we should be able to avoid blowing the timeouts. Fixes #1567
This commit is contained in:
parent
a220ee0ad8
commit
91bc75b0e3
|
|
@ -97,6 +97,7 @@ type StorageGetter interface {
|
|||
GetRegistrationByKey(jose.JsonWebKey) (Registration, error)
|
||||
GetAuthorization(string) (Authorization, error)
|
||||
GetLatestValidAuthorization(int64, AcmeIdentifier) (Authorization, error)
|
||||
GetValidAuthorizations(int64, []string, time.Time) (map[string]*Authorization, error)
|
||||
GetCertificate(string) (Certificate, error)
|
||||
GetCertificateStatus(string) (CertificateStatus, error)
|
||||
AlreadyDeniedCSR([]string) (bool, error)
|
||||
|
|
|
|||
|
|
@ -268,6 +268,29 @@ func (sa *StorageAuthority) GetLatestValidAuthorization(registrationID int64, id
|
|||
return core.Authorization{}, errors.New("no authz")
|
||||
}
|
||||
|
||||
// GetValidAuthorizations is a mock
|
||||
func (sa *StorageAuthority) GetValidAuthorizations(regID int64, names []string, now time.Time) (map[string]*core.Authorization, error) {
|
||||
if regID == 1 {
|
||||
auths := make(map[string]*core.Authorization)
|
||||
for _, name := range names {
|
||||
if sa.authorizedDomains[name] || name == "not-an-example.com" {
|
||||
exp := now.AddDate(100, 0, 0)
|
||||
auths[name] = &core.Authorization{
|
||||
Status: core.StatusValid,
|
||||
RegistrationID: 1,
|
||||
Expires: &exp,
|
||||
Identifier: core.AcmeIdentifier{
|
||||
Type: "dns",
|
||||
Value: name,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
return auths, nil
|
||||
}
|
||||
return nil, errors.New("no authz")
|
||||
}
|
||||
|
||||
// CountCertificatesRange is a mock
|
||||
func (sa *StorageAuthority) CountCertificatesRange(_, _ time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
|
|
|
|||
|
|
@ -452,9 +452,16 @@ func (ra *RegistrationAuthorityImpl) MatchesCSR(cert core.Certificate, csr *x509
|
|||
func (ra *RegistrationAuthorityImpl) checkAuthorizations(names []string, registration *core.Registration) error {
|
||||
now := ra.clk.Now()
|
||||
var badNames []string
|
||||
for i := range names {
|
||||
names[i] = strings.ToLower(names[i])
|
||||
}
|
||||
auths, err := ra.SA.GetValidAuthorizations(registration.ID, names, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range names {
|
||||
authz, err := ra.SA.GetLatestValidAuthorization(registration.ID, core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name})
|
||||
if err != nil || authz.Expires.Before(now) {
|
||||
authz := auths[name]
|
||||
if authz == nil || authz.Expires.Before(now) {
|
||||
badNames = append(badNames, name)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ const (
|
|||
MethodGetRegistrationByKey = "GetRegistrationByKey" // RA, SA
|
||||
MethodGetAuthorization = "GetAuthorization" // SA
|
||||
MethodGetLatestValidAuthorization = "GetLatestValidAuthorization" // SA
|
||||
MethodGetValidAuthorizations = "GetValidAuthorizations" // SA
|
||||
MethodGetCertificate = "GetCertificate" // SA
|
||||
MethodGetCertificateStatus = "GetCertificateStatus" // SA
|
||||
MethodMarkCertificateRevoked = "MarkCertificateRevoked" // SA
|
||||
|
|
@ -103,6 +104,12 @@ type latestValidAuthorizationRequest struct {
|
|||
Identifier core.AcmeIdentifier
|
||||
}
|
||||
|
||||
type getValidAuthorizationsRequest struct {
|
||||
RegID int64
|
||||
Names []string
|
||||
Now time.Time
|
||||
}
|
||||
|
||||
type certificateRequest struct {
|
||||
Req core.CertificateRequest
|
||||
RegID int64
|
||||
|
|
@ -846,6 +853,28 @@ func NewStorageAuthorityServer(rpc Server, impl core.StorageAuthority) error {
|
|||
return
|
||||
})
|
||||
|
||||
rpc.Handle(MethodGetValidAuthorizations, func(req []byte) (response []byte, err error) {
|
||||
var mreq getValidAuthorizationsRequest
|
||||
if err = json.Unmarshal(req, &mreq); err != nil {
|
||||
// AUDIT[ Improper Messages ] 0786b6f2-91ca-4f48-9883-842a19084c64
|
||||
improperMessage(MethodGetValidAuthorizations, err, req)
|
||||
return
|
||||
}
|
||||
|
||||
auths, err := impl.GetValidAuthorizations(mreq.RegID, mreq.Names, mreq.Now)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err = json.Marshal(auths)
|
||||
if err != nil {
|
||||
// AUDIT[ Error Conditions ] 9cc4d537-8534-4970-8665-4b382abe82f3
|
||||
errorCondition(MethodGetValidAuthorizations, err, req)
|
||||
return
|
||||
}
|
||||
return
|
||||
})
|
||||
|
||||
rpc.Handle(MethodAddCertificate, func(req []byte) (response []byte, err error) {
|
||||
var acReq addCertificateRequest
|
||||
err = json.Unmarshal(req, &acReq)
|
||||
|
|
@ -1256,6 +1285,28 @@ func (cac StorageAuthorityClient) GetLatestValidAuthorization(registrationID int
|
|||
return
|
||||
}
|
||||
|
||||
// GetValidAuthorizations sends a request to get a batch of Authorizations by
|
||||
// RegID and dnsName. The current time is also included in the request to
|
||||
// assist filtering.
|
||||
func (cac StorageAuthorityClient) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (auths map[string]*core.Authorization, err error) {
|
||||
data, err := json.Marshal(getValidAuthorizationsRequest{
|
||||
RegID: registrationID,
|
||||
Names: names,
|
||||
Now: now,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
jsonAuths, err := cac.rpc.DispatchSync(MethodGetValidAuthorizations, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = json.Unmarshal(jsonAuths, &auths)
|
||||
return
|
||||
}
|
||||
|
||||
// GetCertificate sends a request to get a Certificate by ID
|
||||
func (cac StorageAuthorityClient) GetCertificate(id string) (cert core.Certificate, err error) {
|
||||
jsonCert, err := cac.rpc.DispatchSync(MethodGetCertificate, []byte(id))
|
||||
|
|
|
|||
|
|
@ -232,6 +232,51 @@ func (ssa *SQLStorageAuthority) GetLatestValidAuthorization(registrationID int64
|
|||
return ssa.GetAuthorization(auth.ID)
|
||||
}
|
||||
|
||||
// GetValidAuthorizations returns the latest authorization object for all
|
||||
// domain names from the parameters that the account has authorizations for.
|
||||
func (ssa *SQLStorageAuthority) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (latest map[string]*core.Authorization, err error) {
|
||||
if len(names) == 0 {
|
||||
return nil, errors.New("GetValidAuthorizations: no names received")
|
||||
}
|
||||
|
||||
params := make([]interface{}, len(names))
|
||||
qmarks := make([]string, len(names))
|
||||
for i, name := range names {
|
||||
id := core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name}
|
||||
idJSON, err := json.Marshal(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
params[i] = string(idJSON)
|
||||
qmarks[i] = "?"
|
||||
}
|
||||
|
||||
var auths []*core.Authorization
|
||||
_, err = ssa.dbMap.Select(&auths, `
|
||||
SELECT * FROM authz
|
||||
WHERE registrationID = ?
|
||||
AND expires > ?
|
||||
AND identifier IN (`+strings.Join(qmarks, ",")+`)
|
||||
AND status = 'valid'
|
||||
ORDER BY expires ASC
|
||||
`, append([]interface{}{registrationID, now}, params...)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
byName := make(map[string]*core.Authorization)
|
||||
for _, auth := range auths {
|
||||
if auth.Identifier.Type != core.IdentifierDNS {
|
||||
return nil, fmt.Errorf("unknown identifier type: %q on authz id %q", auth.Identifier.Type, auth.ID)
|
||||
}
|
||||
// Due to ORDER BY expires, this results in the latest value
|
||||
// for each name being used.
|
||||
byName[auth.Identifier.Value] = auth
|
||||
}
|
||||
|
||||
return byName, nil
|
||||
}
|
||||
|
||||
// incrementIP returns a copy of `ip` incremented at a bit index `index`,
|
||||
// or in other words the first IP of the next highest subnet given a mask of
|
||||
// length `index`.
|
||||
|
|
|
|||
Loading…
Reference in New Issue