revert to original transactional semantics

This commit is contained in:
Roland Shoemaker 2015-05-07 01:55:38 -07:00
parent 3041423361
commit df1ff86acd
1 changed files with 132 additions and 64 deletions

View File

@ -152,15 +152,21 @@ func (ssa *SQLStorageAuthority) InitTables() (err error) {
return return
} }
func (ssa *SQLStorageAuthority) DumpTables() { func (ssa *SQLStorageAuthority) DumpTables() error {
tx, err := ssa.dbMap.Begin()
if err != nil {
tx.Rollback()
return err
}
fmt.Printf("===== TABLE DUMP =====\n") fmt.Printf("===== TABLE DUMP =====\n")
fmt.Printf("\n----- registrations -----\n") fmt.Printf("\n----- registrations -----\n")
var registrations []core.Registration var registrations []core.Registration
_, err := ssa.dbMap.Select(&registrations, "SELECT * FROM registrations ") _, err = tx.Select(&registrations, "SELECT * FROM registrations ")
if err != nil { if err != nil {
fmt.Println(err) tx.Rollback()
return return err
} }
for _, r := range registrations { for _, r := range registrations {
fmt.Printf("%+v\n", r) fmt.Printf("%+v\n", r)
@ -168,10 +174,10 @@ func (ssa *SQLStorageAuthority) DumpTables() {
fmt.Printf("\n----- pending_authz -----\n") fmt.Printf("\n----- pending_authz -----\n")
var pending_authz []pendingauthzModel var pending_authz []pendingauthzModel
_, err = ssa.dbMap.Select(&pending_authz, "SELECT * FROM pending_authz") _, err = tx.Select(&pending_authz, "SELECT * FROM pending_authz")
if err != nil { if err != nil {
fmt.Println(err) tx.Rollback()
return return err
} }
for _, pa := range pending_authz { for _, pa := range pending_authz {
fmt.Printf("%+v\n", pa) fmt.Printf("%+v\n", pa)
@ -179,10 +185,10 @@ func (ssa *SQLStorageAuthority) DumpTables() {
fmt.Printf("\n----- authz -----\n") fmt.Printf("\n----- authz -----\n")
var authz []authzModel var authz []authzModel
_, err = ssa.dbMap.Select(&authz, "SELECT * FROM authz") _, err = tx.Select(&authz, "SELECT * FROM authz")
if err != nil { if err != nil {
fmt.Println(err) tx.Rollback()
return return err
} }
for _, a := range authz { for _, a := range authz {
fmt.Printf("%+v\n", a) fmt.Printf("%+v\n", a)
@ -190,10 +196,10 @@ func (ssa *SQLStorageAuthority) DumpTables() {
fmt.Printf("\n----- certificates -----\n") fmt.Printf("\n----- certificates -----\n")
var certificates []core.Certificate var certificates []core.Certificate
_, err = ssa.dbMap.Select(&certificates, "SELECT * FROM certificates") _, err = tx.Select(&certificates, "SELECT * FROM certificates")
if err != nil { if err != nil {
fmt.Println(err) tx.Rollback()
return return err
} }
for _, c := range certificates { for _, c := range certificates {
fmt.Printf("%+v\n", c) fmt.Printf("%+v\n", c)
@ -201,10 +207,10 @@ func (ssa *SQLStorageAuthority) DumpTables() {
fmt.Printf("\n----- certificateStatus -----\n") fmt.Printf("\n----- certificateStatus -----\n")
var certificateStatuses []core.CertificateStatus var certificateStatuses []core.CertificateStatus
_, err = ssa.dbMap.Select(&certificateStatuses, "SELECT * FROM certificateStatus") _, err = tx.Select(&certificateStatuses, "SELECT * FROM certificateStatus")
if err != nil { if err != nil {
fmt.Println(err) tx.Rollback()
return return err
} }
for _, cS := range certificateStatuses { for _, cS := range certificateStatuses {
fmt.Printf("%+v\n", cS) fmt.Printf("%+v\n", cS)
@ -212,10 +218,10 @@ func (ssa *SQLStorageAuthority) DumpTables() {
fmt.Printf("\n----- ocspResponses -----\n") fmt.Printf("\n----- ocspResponses -----\n")
var ocspResponses []core.OcspResponse var ocspResponses []core.OcspResponse
_, err = ssa.dbMap.Select(&ocspResponses, "SELECT * FROM ocspResponses") _, err = tx.Select(&ocspResponses, "SELECT * FROM ocspResponses")
if err != nil { if err != nil {
fmt.Println(err) tx.Rollback()
return return err
} }
for _, oR := range ocspResponses { for _, oR := range ocspResponses {
fmt.Printf("%+v\n", oR) fmt.Printf("%+v\n", oR)
@ -223,35 +229,38 @@ func (ssa *SQLStorageAuthority) DumpTables() {
fmt.Printf("\n----- crls -----\n") fmt.Printf("\n----- crls -----\n")
var crls []core.Crl var crls []core.Crl
_, err = ssa.dbMap.Select(&crls, "SELECT * FROM crls") _, err = tx.Select(&crls, "SELECT * FROM crls")
if err != nil { if err != nil {
fmt.Println(err) tx.Rollback()
return return err
} }
for _, c := range crls { for _, c := range crls {
fmt.Printf("%+v\n", c) fmt.Printf("%+v\n", c)
} }
err = tx.Commit()
return err
} }
func statusIsPending(status core.AcmeStatus) bool { func statusIsPending(status core.AcmeStatus) bool {
return status == core.StatusPending || status == core.StatusProcessing || status == core.StatusUnknown return status == core.StatusPending || status == core.StatusProcessing || status == core.StatusUnknown
} }
func (ssa *SQLStorageAuthority) existingPending(id string) (bool) { func existingPending(tx *gorp.Transaction, id string) (bool) {
var count int64 var count int64
_ = ssa.dbMap.SelectOne(&count, "SELECT count(*) FROM pending_authz WHERE id = :id", map[string]interface{} {"id": id}) _ = tx.SelectOne(&count, "SELECT count(*) FROM pending_authz WHERE id = :id", map[string]interface{} {"id": id})
return count > 0 return count > 0
} }
func (ssa *SQLStorageAuthority) existingFinal(id string) (bool) { func existingFinal(tx *gorp.Transaction, id string) (bool) {
var count int64 var count int64
_ = ssa.dbMap.SelectOne(&count, "SELECT count(*) FROM authz WHERE id = :id", map[string]interface{} {"id": id}) _ = tx.SelectOne(&count, "SELECT count(*) FROM authz WHERE id = :id", map[string]interface{} {"id": id})
return count > 0 return count > 0
} }
func (ssa *SQLStorageAuthority) existingRegistration(id string) (bool) { func existingRegistration(tx *gorp.Transaction, id string) (bool) {
var count int64 var count int64
_ = ssa.dbMap.SelectOne(&count, "SELECT count(*) FROM registrations WHERE id = :id", map[string]interface{} {"id": id}) _ = tx.SelectOne(&count, "SELECT count(*) FROM registrations WHERE id = :id", map[string]interface{} {"id": id})
return count > 0 return count > 0
} }
@ -269,25 +278,37 @@ func (ssa *SQLStorageAuthority) GetRegistration(id string) (reg core.Registratio
} }
func (ssa *SQLStorageAuthority) GetAuthorization(id string) (authz core.Authorization, err error) { func (ssa *SQLStorageAuthority) GetAuthorization(id string) (authz core.Authorization, err error) {
authObj, err := ssa.dbMap.Get(pendingauthzModel{}, id) tx, err := ssa.dbMap.Begin()
if err != nil { if err != nil {
return return
} }
authObj, err := tx.Get(pendingauthzModel{}, id)
if err != nil {
tx.Rollback()
return
}
if authObj == nil { if authObj == nil {
authObj, err = ssa.dbMap.Get(authzModel{}, id) authObj, err = tx.Get(authzModel{}, id)
if err != nil { if err != nil {
tx.Rollback()
return return
} }
if authObj == nil { if authObj == nil {
err = fmt.Errorf("No pending_authz or authz with ID %s", id) err = fmt.Errorf("No pending_authz or authz with ID %s", id)
tx.Rollback()
return return
} }
authD := authObj.(*authzModel) authD := authObj.(*authzModel)
authz = authD.Authorization authz = authD.Authorization
err = tx.Commit()
return return
} }
authD := *authObj.(*pendingauthzModel) authD := *authObj.(*pendingauthzModel)
authz = authD.Authorization authz = authD.Authorization
err = tx.Commit()
return return
} }
@ -350,20 +371,27 @@ func (ssa *SQLStorageAuthority) GetCertificateStatus(serial string) (status core
} }
func (ssa *SQLStorageAuthority) NewRegistration() (id string, err error) { func (ssa *SQLStorageAuthority) NewRegistration() (id string, err error) {
tx, err := ssa.dbMap.Begin()
if err != nil {
return
}
// Check that it doesn't exist already // Check that it doesn't exist already
id = core.NewToken() id = core.NewToken()
for ssa.existingRegistration(id) { for existingRegistration(tx, id) {
id = core.NewToken() id = core.NewToken()
} }
reg := &core.Registration{} reg := &core.Registration{}
reg.ID = id reg.ID = id
err = ssa.dbMap.Insert(reg) err = tx.Insert(reg)
if err != nil { if err != nil {
tx.Rollback()
return return
} }
err = tx.Commit()
return return
} }
@ -414,74 +442,118 @@ func (ssa *SQLStorageAuthority) MarkCertificateRevoked(serial string, ocspRespon
return return
} }
tx.Commit() err = tx.Commit()
return return
} }
func (ssa *SQLStorageAuthority) UpdateRegistration(reg core.Registration) (err error) { func (ssa *SQLStorageAuthority) UpdateRegistration(reg core.Registration) (err error) {
tx, err := ssa.dbMap.Begin()
if !ssa.existingRegistration(reg.ID) { if err != nil {
err = errors.New("Requested registration not found " + reg.ID)
return return
} }
_, err = ssa.dbMap.Update(&reg) if !existingRegistration(tx, reg.ID) {
err = errors.New("Requested registration not found " + reg.ID)
tx.Rollback()
return
}
_, err = tx.Update(&reg)
if err != nil {
tx.Rollback()
return
}
err = tx.Commit()
return return
} }
func (ssa *SQLStorageAuthority) NewPendingAuthorization() (id string, err error) { func (ssa *SQLStorageAuthority) NewPendingAuthorization() (id string, err error) {
tx, err := ssa.dbMap.Begin()
if err != nil {
return
}
// Check that it doesn't exist already // Check that it doesn't exist already
id = core.NewToken() id = core.NewToken()
for ssa.existingPending(id) || ssa.existingFinal(id) { for existingPending(tx, id) || existingFinal(tx, id) {
id = core.NewToken() id = core.NewToken()
} }
// Insert a stub row in pending // Insert a stub row in pending
pending_authz := &pendingauthzModel{Authorization: core.Authorization{ID: id}} pending_authz := &pendingauthzModel{Authorization: core.Authorization{ID: id}}
err = ssa.dbMap.Insert(pending_authz) err = tx.Insert(pending_authz)
if err != nil {
tx.Rollback()
return
}
err = tx.Commit()
return return
} }
func (ssa *SQLStorageAuthority) UpdatePendingAuthorization(authz core.Authorization) (err error) { func (ssa *SQLStorageAuthority) UpdatePendingAuthorization(authz core.Authorization) (err error) {
if !statusIsPending(authz.Status) { tx, err := ssa.dbMap.Begin()
err = errors.New("Use Finalize() to update to a final status")
return
}
if ssa.existingFinal(authz.ID) {
err = errors.New("Cannot update a final authorization")
return
}
if !ssa.existingPending(authz.ID) {
err = errors.New("Requested authorization not found " + authz.ID)
return
}
authObj, err := ssa.dbMap.Get(pendingauthzModel{}, authz.ID)
if err != nil { if err != nil {
return return
} }
if !statusIsPending(authz.Status) {
err = errors.New("Use FinalizeAuthorization() to update to a final status")
tx.Rollback()
return
}
if existingFinal(tx, authz.ID) {
err = errors.New("Cannot update a final authorization")
tx.Rollback()
return
}
if !existingPending(tx, authz.ID) {
err = errors.New("Requested authorization not found " + authz.ID)
tx.Rollback()
return
}
authObj, err := tx.Get(pendingauthzModel{}, authz.ID)
if err != nil {
tx.Rollback()
return
}
auth := authObj.(*pendingauthzModel) auth := authObj.(*pendingauthzModel)
auth.Authorization = authz auth.Authorization = authz
_, err = ssa.dbMap.Update(auth) _, err = tx.Update(auth)
if err != nil {
tx.Rollback()
return
}
err = tx.Commit()
return return
} }
func (ssa *SQLStorageAuthority) FinalizeAuthorization(authz core.Authorization) (err error) { func (ssa *SQLStorageAuthority) FinalizeAuthorization(authz core.Authorization) (err error) {
tx, err := ssa.dbMap.Begin()
if err != nil {
return
}
// Check that a pending authz exists // Check that a pending authz exists
if !ssa.existingPending(authz.ID) { if !existingPending(tx, authz.ID) {
err = errors.New("Cannot finalize a authorization that is not pending") err = errors.New("Cannot finalize a authorization that is not pending")
tx.Rollback()
return return
} }
if statusIsPending(authz.Status) { if statusIsPending(authz.Status) {
err = errors.New("Cannot finalize to a non-final status") err = errors.New("Cannot finalize to a non-final status")
tx.Rollback()
return return
} }
// Manually set the index, to avoid AUTOINCREMENT issues // Manually set the index, to avoid AUTOINCREMENT issues
var sequence int64 var sequence int64
sequenceObj, err := ssa.dbMap.SelectNullInt("SELECT max(sequence) FROM authz") sequenceObj, err := tx.SelectNullInt("SELECT max(sequence) FROM authz")
switch { switch {
case !sequenceObj.Valid: case !sequenceObj.Valid:
sequence = 0 sequence = 0
@ -492,17 +564,13 @@ func (ssa *SQLStorageAuthority) FinalizeAuthorization(authz core.Authorization)
} }
auth := &authzModel{authz, sequence} auth := &authzModel{authz, sequence}
authObj, err := ssa.dbMap.Get(pendingauthzModel{}, authz.ID) authObj, err := tx.Get(pendingauthzModel{}, authz.ID)
if err != nil { if err != nil {
tx.Rollback()
return return
} }
oldAuth := authObj.(*pendingauthzModel) oldAuth := authObj.(*pendingauthzModel)
tx, err := ssa.dbMap.Begin()
if err != nil {
return
}
err = tx.Insert(auth) err = tx.Insert(auth)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()