add overriding of ARI response (#501)

Fixes #486

This moves the GetCertificateBySerial call earlier, which means that
call needs to succeed even for revoked certificates. So this also
follows up on #252 by keeping revoked certs in the primary
certificatesByID map (while still adding them to the
revokedCertificatesByID map).
This commit is contained in:
Jacob Hoffman-Andrews 2025-06-05 16:15:11 -07:00 committed by GitHub
parent 39dbb64e14
commit d52948ce25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 9 deletions

View File

@ -149,6 +149,8 @@ type Certificate struct {
DER []byte DER []byte
IssuerChains [][]*Certificate IssuerChains [][]*Certificate
AccountID string AccountID string
// When non-empty, this is the ARI response sent for this certificate.
ARIResponse string
} }
func (c Certificate) PEM() []byte { func (c Certificate) PEM() []byte {

View File

@ -413,7 +413,6 @@ func (m *MemoryStore) RevokeCertificate(cert *core.RevokedCertificate) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
m.revokedCertificatesByID[cert.Certificate.ID] = cert m.revokedCertificatesByID[cert.Certificate.ID] = cert
delete(m.certificatesByID, cert.Certificate.ID)
} }
/* /*
@ -549,3 +548,19 @@ func (m *MemoryStore) IsDomainBlocked(name string) bool {
return false return false
} }
// SetARIResponse looks up a certificate by serial number and sets its ARI response field
func (m *MemoryStore) SetARIResponse(serial *big.Int, ariResponse string) error {
m.Lock()
defer m.Unlock()
for _, cert := range m.certificatesByID {
if cert.Cert.SerialNumber.Cmp(serial) == 0 {
cert.ARIResponse = ariResponse
return nil
}
}
// Certificate not found
return fmt.Errorf("certificate with serial number %s not found", serial.String())
}

View File

@ -57,7 +57,7 @@ const (
// Draft or likely-to-change paths // Draft or likely-to-change paths
renewalInfoPath = "/draft-ietf-acme-ari-03/renewalInfo/" renewalInfoPath = "/draft-ietf-acme-ari-03/renewalInfo/"
// Theses entrypoints are not a part of the standard ACME endpoints, // These entrypoints are not a part of the standard ACME endpoints,
// and are exposed by Pebble as an integration test tool. We export // and are exposed by Pebble as an integration test tool. We export
// RootCertPath so that the pebble binary can reference it. // RootCertPath so that the pebble binary can reference it.
RootCertPath = "/roots/" RootCertPath = "/roots/"
@ -65,6 +65,10 @@ const (
intermediateCertPath = "/intermediates/" intermediateCertPath = "/intermediates/"
intermediateKeyPath = "/intermediate-keys/" intermediateKeyPath = "/intermediate-keys/"
certStatusBySerial = "/cert-status-by-serial/" certStatusBySerial = "/cert-status-by-serial/"
// Post certificate PEM and desired literal response for renewal info
// (the renewal info response is not validated so may be intentionally
// malformed).
setRenewalInfoPath = "/set-renewal-info/"
// How long do pending authorizations last before expiring? // How long do pending authorizations last before expiring?
pendingAuthzExpire = time.Hour pendingAuthzExpire = time.Hour
@ -542,6 +546,9 @@ func (wfe *WebFrontEndImpl) ManagementHandler() http.Handler {
wfe.HandleManagementFunc(m, intermediateCertPath, wfe.handleCert(wfe.ca.GetIntermediateCert, intermediateCertPath)) wfe.HandleManagementFunc(m, intermediateCertPath, wfe.handleCert(wfe.ca.GetIntermediateCert, intermediateCertPath))
wfe.HandleManagementFunc(m, intermediateKeyPath, wfe.handleKey(wfe.ca.GetIntermediateKey, intermediateKeyPath)) wfe.HandleManagementFunc(m, intermediateKeyPath, wfe.handleKey(wfe.ca.GetIntermediateKey, intermediateKeyPath))
wfe.HandleManagementFunc(m, certStatusBySerial, wfe.handleCertStatusBySerial) wfe.HandleManagementFunc(m, certStatusBySerial, wfe.handleCertStatusBySerial)
// POST only handlers
wfe.HandleFunc(m, setRenewalInfoPath, wfe.SetRenewalInfo, http.MethodPost)
return m return m
} }
@ -1903,7 +1910,18 @@ func (wfe *WebFrontEndImpl) RenewalInfo(_ context.Context, response http.Respons
return return
} }
renewalInfo, err := wfe.determineARIWindow(certID) cert := wfe.db.GetCertificateBySerial(certID.SerialNumber)
if cert == nil {
wfe.sendError(acme.NotFoundProblem("failed to retrieve existing certificate serial"), response)
return
}
if cert.ARIResponse != "" {
_, _ = response.Write([]byte(cert.ARIResponse))
return
}
renewalInfo, err := wfe.determineARIWindow(certID, cert)
if err != nil { if err != nil {
wfe.sendError(acme.InternalErrorProblem(fmt.Sprintf("Error determining renewal window: %s", err)), response) wfe.sendError(acme.InternalErrorProblem(fmt.Sprintf("Error determining renewal window: %s", err)), response)
return return
@ -1917,7 +1935,47 @@ func (wfe *WebFrontEndImpl) RenewalInfo(_ context.Context, response http.Respons
} }
} }
func (wfe *WebFrontEndImpl) determineARIWindow(id *core.CertID) (*core.RenewalInfo, error) { // SetRenewalInfo overrides the default ARI response for a certificate.
func (wfe *WebFrontEndImpl) SetRenewalInfo(_ context.Context, response http.ResponseWriter, request *http.Request) {
body, err := io.ReadAll(request.Body)
if err != nil {
wfe.sendError(acme.InternalErrorProblem("Error reading body"), response)
}
var reqJSON struct {
Certificate string // in PEM form
ARIResponse string // can be anything, even malformed JSON, so that users can test client response to malformed data
}
err = json.Unmarshal(body, &reqJSON)
if err != nil {
wfe.sendError(acme.MalformedProblem("Error unmarshaling request body"), response)
return
}
// Decode and parse the PEM certificate
block, _ := pem.Decode([]byte(reqJSON.Certificate))
if block == nil || block.Type != "CERTIFICATE" {
wfe.sendError(acme.MalformedProblem("Error decoding certificate PEM"), response)
return
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
wfe.sendError(acme.MalformedProblem("Error parsing certificate"), response)
return
}
// Look up certificate by serial number and update response
err = wfe.db.SetARIResponse(cert.SerialNumber, reqJSON.ARIResponse)
if err != nil {
wfe.sendError(acme.NotFoundProblem(err.Error()), response)
return
}
response.WriteHeader(http.StatusOK)
}
func (wfe *WebFrontEndImpl) determineARIWindow(id *core.CertID, cert *core.Certificate) (*core.RenewalInfo, error) {
if id == nil { if id == nil {
return nil, errors.New("CertID was nil") return nil, errors.New("CertID was nil")
} }
@ -1928,11 +1986,6 @@ func (wfe *WebFrontEndImpl) determineARIWindow(id *core.CertID) (*core.RenewalIn
return core.RenewalInfoImmediate(time.Now().In(time.UTC)), nil return core.RenewalInfoImmediate(time.Now().In(time.UTC)), nil
} }
cert := wfe.db.GetCertificateBySerial(id.SerialNumber)
if cert == nil {
return nil, errors.New("failed to retrieve existing certificate serial")
}
return core.RenewalInfoSimple(cert.Cert.NotBefore, cert.Cert.NotAfter), nil return core.RenewalInfoSimple(cert.Cert.NotBefore, cert.Cert.NotAfter), nil
} }