diff --git a/core/interfaces.go b/core/interfaces.go index b3e26241d..0a6aa5cee 100644 --- a/core/interfaces.go +++ b/core/interfaces.go @@ -97,6 +97,7 @@ type StorageGetter interface { GetRegistrationByKey(jose.JsonWebKey) (Registration, error) GetAuthorization(string) (Authorization, error) GetLatestValidAuthorization(int64, AcmeIdentifier) (Authorization, error) + GetValidAuthorizations(int64, []string, time.Time) (map[string]*Authorization, error) GetCertificate(string) (Certificate, error) GetCertificateStatus(string) (CertificateStatus, error) AlreadyDeniedCSR([]string) (bool, error) diff --git a/mocks/mocks.go b/mocks/mocks.go index 16b89052e..7d3c51bc6 100644 --- a/mocks/mocks.go +++ b/mocks/mocks.go @@ -268,6 +268,29 @@ func (sa *StorageAuthority) GetLatestValidAuthorization(registrationID int64, id return core.Authorization{}, errors.New("no authz") } +// GetValidAuthorizations is a mock +func (sa *StorageAuthority) GetValidAuthorizations(regID int64, names []string, now time.Time) (map[string]*core.Authorization, error) { + if regID == 1 { + auths := make(map[string]*core.Authorization) + for _, name := range names { + if sa.authorizedDomains[name] || name == "not-an-example.com" { + exp := now.AddDate(100, 0, 0) + auths[name] = &core.Authorization{ + Status: core.StatusValid, + RegistrationID: 1, + Expires: &exp, + Identifier: core.AcmeIdentifier{ + Type: "dns", + Value: name, + }, + } + } + } + return auths, nil + } + return nil, errors.New("no authz") +} + // CountCertificatesRange is a mock func (sa *StorageAuthority) CountCertificatesRange(_, _ time.Time) (int64, error) { return 0, nil diff --git a/ra/registration-authority.go b/ra/registration-authority.go index 9e977d29d..541fabf46 100644 --- a/ra/registration-authority.go +++ b/ra/registration-authority.go @@ -452,9 +452,16 @@ func (ra *RegistrationAuthorityImpl) MatchesCSR(cert core.Certificate, csr *x509 func (ra *RegistrationAuthorityImpl) checkAuthorizations(names []string, registration *core.Registration) error { now := ra.clk.Now() var badNames []string + for i := range names { + names[i] = strings.ToLower(names[i]) + } + auths, err := ra.SA.GetValidAuthorizations(registration.ID, names, now) + if err != nil { + return err + } for _, name := range names { - authz, err := ra.SA.GetLatestValidAuthorization(registration.ID, core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name}) - if err != nil || authz.Expires.Before(now) { + authz := auths[name] + if authz == nil || authz.Expires.Before(now) { badNames = append(badNames, name) } } diff --git a/rpc/rpc-wrappers.go b/rpc/rpc-wrappers.go index 048ec3a0b..320d2272c 100644 --- a/rpc/rpc-wrappers.go +++ b/rpc/rpc-wrappers.go @@ -53,6 +53,7 @@ const ( MethodGetRegistrationByKey = "GetRegistrationByKey" // RA, SA MethodGetAuthorization = "GetAuthorization" // SA MethodGetLatestValidAuthorization = "GetLatestValidAuthorization" // SA + MethodGetValidAuthorizations = "GetValidAuthorizations" // SA MethodGetCertificate = "GetCertificate" // SA MethodGetCertificateStatus = "GetCertificateStatus" // SA MethodMarkCertificateRevoked = "MarkCertificateRevoked" // SA @@ -103,6 +104,12 @@ type latestValidAuthorizationRequest struct { Identifier core.AcmeIdentifier } +type getValidAuthorizationsRequest struct { + RegID int64 + Names []string + Now time.Time +} + type certificateRequest struct { Req core.CertificateRequest RegID int64 @@ -846,6 +853,28 @@ func NewStorageAuthorityServer(rpc Server, impl core.StorageAuthority) error { return }) + rpc.Handle(MethodGetValidAuthorizations, func(req []byte) (response []byte, err error) { + var mreq getValidAuthorizationsRequest + if err = json.Unmarshal(req, &mreq); err != nil { + // AUDIT[ Improper Messages ] 0786b6f2-91ca-4f48-9883-842a19084c64 + improperMessage(MethodGetValidAuthorizations, err, req) + return + } + + auths, err := impl.GetValidAuthorizations(mreq.RegID, mreq.Names, mreq.Now) + if err != nil { + return + } + + response, err = json.Marshal(auths) + if err != nil { + // AUDIT[ Error Conditions ] 9cc4d537-8534-4970-8665-4b382abe82f3 + errorCondition(MethodGetValidAuthorizations, err, req) + return + } + return + }) + rpc.Handle(MethodAddCertificate, func(req []byte) (response []byte, err error) { var acReq addCertificateRequest err = json.Unmarshal(req, &acReq) @@ -1256,6 +1285,28 @@ func (cac StorageAuthorityClient) GetLatestValidAuthorization(registrationID int return } +// GetValidAuthorizations sends a request to get a batch of Authorizations by +// RegID and dnsName. The current time is also included in the request to +// assist filtering. +func (cac StorageAuthorityClient) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (auths map[string]*core.Authorization, err error) { + data, err := json.Marshal(getValidAuthorizationsRequest{ + RegID: registrationID, + Names: names, + Now: now, + }) + if err != nil { + return + } + + jsonAuths, err := cac.rpc.DispatchSync(MethodGetValidAuthorizations, data) + if err != nil { + return + } + + err = json.Unmarshal(jsonAuths, &auths) + return +} + // GetCertificate sends a request to get a Certificate by ID func (cac StorageAuthorityClient) GetCertificate(id string) (cert core.Certificate, err error) { jsonCert, err := cac.rpc.DispatchSync(MethodGetCertificate, []byte(id)) diff --git a/sa/storage-authority.go b/sa/storage-authority.go index baa3dccf6..d11cb4ee7 100644 --- a/sa/storage-authority.go +++ b/sa/storage-authority.go @@ -232,6 +232,51 @@ func (ssa *SQLStorageAuthority) GetLatestValidAuthorization(registrationID int64 return ssa.GetAuthorization(auth.ID) } +// GetValidAuthorizations returns the latest authorization object for all +// domain names from the parameters that the account has authorizations for. +func (ssa *SQLStorageAuthority) GetValidAuthorizations(registrationID int64, names []string, now time.Time) (latest map[string]*core.Authorization, err error) { + if len(names) == 0 { + return nil, errors.New("GetValidAuthorizations: no names received") + } + + params := make([]interface{}, len(names)) + qmarks := make([]string, len(names)) + for i, name := range names { + id := core.AcmeIdentifier{Type: core.IdentifierDNS, Value: name} + idJSON, err := json.Marshal(id) + if err != nil { + return nil, err + } + params[i] = string(idJSON) + qmarks[i] = "?" + } + + var auths []*core.Authorization + _, err = ssa.dbMap.Select(&auths, ` + SELECT * FROM authz + WHERE registrationID = ? + AND expires > ? + AND identifier IN (`+strings.Join(qmarks, ",")+`) + AND status = 'valid' + ORDER BY expires ASC + `, append([]interface{}{registrationID, now}, params...)...) + if err != nil { + return nil, err + } + + byName := make(map[string]*core.Authorization) + for _, auth := range auths { + if auth.Identifier.Type != core.IdentifierDNS { + return nil, fmt.Errorf("unknown identifier type: %q on authz id %q", auth.Identifier.Type, auth.ID) + } + // Due to ORDER BY expires, this results in the latest value + // for each name being used. + byName[auth.Identifier.Value] = auth + } + + return byName, nil +} + // incrementIP returns a copy of `ip` incremented at a bit index `index`, // or in other words the first IP of the next highest subnet given a mask of // length `index`.