Add GetValidAuthorizations to batch authz checks

By performing only one query to MySQL, we should be able to avoid
blowing the timeouts.

Fixes #1567
This commit is contained in:
Kane York 2016-03-09 15:38:07 -08:00
parent a220ee0ad8
commit 91bc75b0e3
5 changed files with 129 additions and 2 deletions

View File

@ -97,6 +97,7 @@ type StorageGetter interface {
GetRegistrationByKey(jose.JsonWebKey) (Registration, error)
GetAuthorization(string) (Authorization, error)
GetLatestValidAuthorization(int64, AcmeIdentifier) (Authorization, error)
GetValidAuthorizations(int64, []string, time.Time) (map[string]*Authorization, error)
GetCertificate(string) (Certificate, error)
GetCertificateStatus(string) (CertificateStatus, error)
AlreadyDeniedCSR([]string) (bool, error)

View File

@ -268,6 +268,29 @@ func (sa *StorageAuthority) GetLatestValidAuthorization(registrationID int64, id
return core.Authorization{}, errors.New("no authz")
}
// GetValidAuthorizations is a mock
func (sa *StorageAuthority) GetValidAuthorizations(regID int64, names []string, now time.Time) (map[string]*core.Authorization, error) {
if regID == 1 {
auths := make(map[string]*core.Authorization)
for _, name := range names {
if sa.authorizedDomains[name] || name == "not-an-example.com" {
exp := now.AddDate(100, 0, 0)
auths[name] = &core.Authorization{
Status: core.StatusValid,
RegistrationID: 1,
Expires: &exp,
Identifier: core.AcmeIdentifier{
Type: "dns",
Value: name,
},
}
}
}
return auths, nil
}
return nil, errors.New("no authz")
}
// CountCertificatesRange is a mock
func (sa *StorageAuthority) CountCertificatesRange(_, _ time.Time) (int64, error) {
return 0, nil

View File

@ -452,9 +452,16 @@ func (ra *RegistrationAuthorityImpl) MatchesCSR(cert core.Certificate, csr *x509
func (ra *RegistrationAuthorityImpl) checkAuthorizations(names []string, registration *core.Registration) error {
now := ra.clk.Now()
var badNames []string
for i := range names {
names[i] = strings.ToLower(names[i])
}
auths, err := ra.SA.GetValidAuthorizations(registration.ID, names, now)
if err != nil {
return err
}
for _, name := range names {
authz, err := ra.SA.GetLatestValidAuthorization(registration.ID, core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name})
if err != nil || authz.Expires.Before(now) {
authz := auths[name]
if authz == nil || authz.Expires.Before(now) {
badNames = append(badNames, name)
}
}

View File

@ -53,6 +53,7 @@ const (
MethodGetRegistrationByKey = "GetRegistrationByKey" // RA, SA
MethodGetAuthorization = "GetAuthorization" // SA
MethodGetLatestValidAuthorization = "GetLatestValidAuthorization" // SA
MethodGetValidAuthorizations = "GetValidAuthorizations" // SA
MethodGetCertificate = "GetCertificate" // SA
MethodGetCertificateStatus = "GetCertificateStatus" // SA
MethodMarkCertificateRevoked = "MarkCertificateRevoked" // SA
@ -103,6 +104,12 @@ type latestValidAuthorizationRequest struct {
Identifier core.AcmeIdentifier
}
type getValidAuthorizationsRequest struct {
RegID int64
Names []string
Now time.Time
}
type certificateRequest struct {
Req core.CertificateRequest
RegID int64
@ -846,6 +853,28 @@ func NewStorageAuthorityServer(rpc Server, impl core.StorageAuthority) error {
return
})
rpc.Handle(MethodGetValidAuthorizations, func(req []byte) (response []byte, err error) {
var mreq getValidAuthorizationsRequest
if err = json.Unmarshal(req, &mreq); err != nil {
// AUDIT[ Improper Messages ] 0786b6f2-91ca-4f48-9883-842a19084c64
improperMessage(MethodGetValidAuthorizations, err, req)
return
}
auths, err := impl.GetValidAuthorizations(mreq.RegID, mreq.Names, mreq.Now)
if err != nil {
return
}
response, err = json.Marshal(auths)
if err != nil {
// AUDIT[ Error Conditions ] 9cc4d537-8534-4970-8665-4b382abe82f3
errorCondition(MethodGetValidAuthorizations, err, req)
return
}
return
})
rpc.Handle(MethodAddCertificate, func(req []byte) (response []byte, err error) {
var acReq addCertificateRequest
err = json.Unmarshal(req, &acReq)
@ -1256,6 +1285,28 @@ func (cac StorageAuthorityClient) GetLatestValidAuthorization(registrationID int
return
}
// GetValidAuthorizations sends a request to get a batch of Authorizations by
// RegID and dnsName. The current time is also included in the request to
// assist filtering.
func (cac StorageAuthorityClient) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (auths map[string]*core.Authorization, err error) {
data, err := json.Marshal(getValidAuthorizationsRequest{
RegID: registrationID,
Names: names,
Now: now,
})
if err != nil {
return
}
jsonAuths, err := cac.rpc.DispatchSync(MethodGetValidAuthorizations, data)
if err != nil {
return
}
err = json.Unmarshal(jsonAuths, &auths)
return
}
// GetCertificate sends a request to get a Certificate by ID
func (cac StorageAuthorityClient) GetCertificate(id string) (cert core.Certificate, err error) {
jsonCert, err := cac.rpc.DispatchSync(MethodGetCertificate, []byte(id))

View File

@ -232,6 +232,51 @@ func (ssa *SQLStorageAuthority) GetLatestValidAuthorization(registrationID int64
return ssa.GetAuthorization(auth.ID)
}
// GetValidAuthorizations returns the latest authorization object for all
// domain names from the parameters that the account has authorizations for.
func (ssa *SQLStorageAuthority) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (latest map[string]*core.Authorization, err error) {
if len(names) == 0 {
return nil, errors.New("GetValidAuthorizations: no names received")
}
params := make([]interface{}, len(names))
qmarks := make([]string, len(names))
for i, name := range names {
id := core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name}
idJSON, err := json.Marshal(id)
if err != nil {
return nil, err
}
params[i] = string(idJSON)
qmarks[i] = "?"
}
var auths []*core.Authorization
_, err = ssa.dbMap.Select(&auths, `
SELECT * FROM authz
WHERE registrationID = ?
AND expires > ?
AND identifier IN (`+strings.Join(qmarks, ",")+`)
AND status = 'valid'
ORDER BY expires ASC
`, append([]interface{}{registrationID, now}, params...)...)
if err != nil {
return nil, err
}
byName := make(map[string]*core.Authorization)
for _, auth := range auths {
if auth.Identifier.Type != core.IdentifierDNS {
return nil, fmt.Errorf("unknown identifier type: %q on authz id %q", auth.Identifier.Type, auth.ID)
}
// Due to ORDER BY expires, this results in the latest value
// for each name being used.
byName[auth.Identifier.Value] = auth
}
return byName, nil
}
// incrementIP returns a copy of `ip` incremented at a bit index `index`,
// or in other words the first IP of the next highest subnet given a mask of
// length `index`.