Merge pull request #44 from endophage/atomic_update

Atomic updates of metadata.
This commit is contained in:
David Lawrence 2015-07-13 20:23:01 -07:00
commit b8674162f8
29 changed files with 2670 additions and 31 deletions

5
Godeps/Godeps.json generated
View File

@ -9,6 +9,11 @@
"ImportPath": "github.com/BurntSushi/toml",
"Rev": "bd2bdf7f18f849530ef7a1c29a4290217cab32a1"
},
{
"ImportPath": "github.com/DATA-DOG/go-sqlmock",
"Comment": "0.1.0-8-ged4836e",
"Rev": "ed4836e31d3e9e77420e442ed9b864df55370ee0"
},
{
"ImportPath": "github.com/Sirupsen/logrus",
"Comment": "v0.7.3",

View File

@ -0,0 +1 @@
/*.test

View File

@ -0,0 +1,16 @@
language: go
go:
- 1.2
- 1.3
- 1.4
- release
- tip
script:
- go get github.com/kisielk/errcheck
- go get ./...
- go test -v ./...
- go test -race ./...
- errcheck github.com/DATA-DOG/go-sqlmock

View File

@ -0,0 +1,28 @@
The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses)
Copyright (c) 2013, DataDog.lt team
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* The name DataDog.lt may not be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,360 @@
[![Build Status](https://travis-ci.org/DATA-DOG/go-sqlmock.png)](https://travis-ci.org/DATA-DOG/go-sqlmock)
[![GoDoc](https://godoc.org/github.com/DATA-DOG/go-sqlmock?status.png)](https://godoc.org/github.com/DATA-DOG/go-sqlmock)
# Sql driver mock for Golang
This is a **mock** driver as **database/sql/driver** which is very flexible and pragmatic to
manage and mock expected queries. All the expectations should be met and all queries and actions
triggered should be mocked in order to pass a test.
## Install
go get github.com/DATA-DOG/go-sqlmock
## Use it with pleasure
An example of some database interaction which you may want to test:
``` go
package main
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
"github.com/kisielk/sqlstruct"
"fmt"
"log"
)
const ORDER_PENDING = 0
const ORDER_CANCELLED = 1
type User struct {
Id int `sql:"id"`
Username string `sql:"username"`
Balance float64 `sql:"balance"`
}
type Order struct {
Id int `sql:"id"`
Value float64 `sql:"value"`
ReservedFee float64 `sql:"reserved_fee"`
Status int `sql:"status"`
}
func cancelOrder(id int, db *sql.DB) (err error) {
tx, err := db.Begin()
if err != nil {
return
}
var order Order
var user User
sql := fmt.Sprintf(`
SELECT %s, %s
FROM orders AS o
INNER JOIN users AS u ON o.buyer_id = u.id
WHERE o.id = ?
FOR UPDATE`,
sqlstruct.ColumnsAliased(order, "o"),
sqlstruct.ColumnsAliased(user, "u"))
// fetch order to cancel
rows, err := tx.Query(sql, id)
if err != nil {
tx.Rollback()
return
}
defer rows.Close()
// no rows, nothing to do
if !rows.Next() {
tx.Rollback()
return
}
// read order
err = sqlstruct.ScanAliased(&order, rows, "o")
if err != nil {
tx.Rollback()
return
}
// ensure order status
if order.Status != ORDER_PENDING {
tx.Rollback()
return
}
// read user
err = sqlstruct.ScanAliased(&user, rows, "u")
if err != nil {
tx.Rollback()
return
}
rows.Close() // manually close before other prepared statements
// refund order value
sql = "UPDATE users SET balance = balance + ? WHERE id = ?"
refundStmt, err := tx.Prepare(sql)
if err != nil {
tx.Rollback()
return
}
defer refundStmt.Close()
_, err = refundStmt.Exec(order.Value + order.ReservedFee, user.Id)
if err != nil {
tx.Rollback()
return
}
// update order status
order.Status = ORDER_CANCELLED
sql = "UPDATE orders SET status = ?, updated = NOW() WHERE id = ?"
orderUpdStmt, err := tx.Prepare(sql)
if err != nil {
tx.Rollback()
return
}
defer orderUpdStmt.Close()
_, err = orderUpdStmt.Exec(order.Status, order.Id)
if err != nil {
tx.Rollback()
return
}
return tx.Commit()
}
func main() {
db, err := sql.Open("mysql", "root:nimda@/test")
if err != nil {
log.Fatal(err)
}
defer db.Close()
err = cancelOrder(1, db)
if err != nil {
log.Fatal(err)
}
}
```
And the clean nice test:
``` go
package main
import (
"database/sql"
"github.com/DATA-DOG/go-sqlmock"
"testing"
"fmt"
)
// will test that order with a different status, cannot be cancelled
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) {
// open database stub
db, err := sqlmock.New()
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
// columns are prefixed with "o" since we used sqlstruct to generate them
columns := []string{"o_id", "o_status"}
// expect transaction begin
sqlmock.ExpectBegin()
// expect query to fetch order and user, match it with regexp
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
// expect transaction rollback, since order status is "cancelled"
sqlmock.ExpectRollback()
// run the cancel order function
err = cancelOrder(1, db)
if err != nil {
t.Errorf("Expected no error, but got %s instead", err)
}
// db.Close() ensures that all expectations have been met
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
// will test order cancellation
func TestShouldRefundUserWhenOrderIsCancelled(t *testing.T) {
// open database stub
db, err := sqlmock.New()
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
// columns are prefixed with "o" since we used sqlstruct to generate them
columns := []string{"o_id", "o_status", "o_value", "o_reserved_fee", "u_id", "u_balance"}
// expect transaction begin
sqlmock.ExpectBegin()
// expect query to fetch order and user, match it with regexp
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).AddRow(1, 0, 25.75, 3.25, 2, 10.00))
// expect user balance update
sqlmock.ExpectExec("UPDATE users SET balance").
WithArgs(25.75 + 3.25, 2). // refund amount, user id
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
// expect order status update
sqlmock.ExpectExec("UPDATE orders SET status").
WithArgs(ORDER_CANCELLED, 1). // status, id
WillReturnResult(sqlmock.NewResult(0, 1)) // no insert id, 1 affected row
// expect a transaction commit
sqlmock.ExpectCommit()
// run the cancel order function
err = cancelOrder(1, db)
if err != nil {
t.Errorf("Expected no error, but got %s instead", err)
}
// db.Close() ensures that all expectations have been met
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
// will test order cancellation
func TestShouldRollbackOnError(t *testing.T) {
// open database stub
db, err := sqlmock.New()
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
// expect transaction begin
sqlmock.ExpectBegin()
// expect query to fetch order and user, match it with regexp
sqlmock.ExpectQuery("SELECT (.+) FROM orders AS o INNER JOIN users AS u (.+) FOR UPDATE").
WithArgs(1).
WillReturnError(fmt.Errorf("Some error"))
// should rollback since error was returned from query execution
sqlmock.ExpectRollback()
// run the cancel order function
err = cancelOrder(1, db)
// error should return back
if err == nil {
t.Error("Expected error, but got none")
}
// db.Close() ensures that all expectations have been met
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
```
## Expectations
All **Expect** methods return a **Mock** interface which allow you to describe
expectations in more details: return an error, expect specific arguments, return rows and so on.
**NOTE:** that if you call **WithArgs** on a non query based expectation, it will panic
A **Mock** interface:
``` go
type Mock interface {
WithArgs(...driver.Value) Mock
WillReturnError(error) Mock
WillReturnRows(driver.Rows) Mock
WillReturnResult(driver.Result) Mock
}
```
As an example we can expect a transaction commit and simulate an error for it:
``` go
sqlmock.ExpectCommit().WillReturnError(fmt.Errorf("Deadlock occured"))
```
In same fashion, we can expect queries to match arguments. If there are any, it must be matched.
Instead of result we can return error.
``` go
sqlmock.ExpectQuery("SELECT (.*) FROM orders").
WithArgs("string value").
WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("val"))
```
**NOTE:** it matches a regular expression. Some regex special characters must be escaped if you want to match them.
For example if we want to match a subselect:
``` go
sqlmock.ExpectQuery("SELECT (.*) FROM orders WHERE id IN \\(SELECT id FROM finished WHERE status = 1\\)").
WithArgs("string value").
WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("val"))
```
**WithArgs** expectation, compares values based on their type, for usual values like **string, float, int**
it matches the actual value. Types like **time** are compared only by type. Other types might require different ways
to compare them correctly, this may be improved.
You can build rows either from CSV string or from interface values:
**Rows** interface, which satisfies sql driver.Rows:
``` go
type Rows interface {
AddRow(...driver.Value) Rows
FromCSVString(s string) Rows
Next([]driver.Value) error
Columns() []string
Close() error
}
```
Example for to build rows:
``` go
rs := sqlmock.NewRows([]string{"column1", "column2"}).
FromCSVString("one,1\ntwo,2").
AddRow("three", 3)
```
**Prepare** will ignore other expectations if ExpectPrepare not set. When set, can expect normal result or simulate an error:
``` go
rs := sqlmock.ExpectPrepare().
WillReturnError(fmt.Errorf("Query prepare failed"))
```
## Run tests
go test
## Documentation
Visit [godoc](http://godoc.org/github.com/DATA-DOG/go-sqlmock)
See **.travis.yml** for supported **go** versions
Different use case, is to functionally test with a real database - [go-txdb](https://github.com/DATA-DOG/go-txdb)
all database related actions are isolated within a single transaction so the database can remain in the same state.
## Changes
- **2014-08-16** instead of **panic** during reflect type mismatch when comparing query arguments - now return error
- **2014-08-14** added **sqlmock.NewErrorResult** which gives an option to return driver.Result with errors for
interface methods, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/5)
- **2014-05-29** allow to match arguments in more sophisticated ways, by providing an **sqlmock.Argument** interface
- **2014-04-21** introduce **sqlmock.New()** to open a mock database connection for tests. This method
calls sql.DB.Ping to ensure that connection is open, see [issue](https://github.com/DATA-DOG/go-sqlmock/issues/4).
This way on Close it will surely assert if all expectations are met, even if database was not triggered at all.
The old way is still available, but it is advisable to call db.Ping manually before asserting with db.Close.
- **2014-02-14** RowsFromCSVString is now a part of Rows interface named as FromCSVString.
It has changed to allow more ways to construct rows and to easily extend this API in future.
See [issue 1](https://github.com/DATA-DOG/go-sqlmock/issues/1)
**RowsFromCSVString** is deprecated and will be removed in future
## Contributions
Feel free to open a pull request. Note, if you wish to contribute an extension to public (exported methods or types) -
please open an issue before, to discuss whether these changes can be accepted. All backward incompatible changes are
and will be treated cautiously
## License
The [three clause BSD license](http://en.wikipedia.org/wiki/BSD_licenses)

View File

@ -0,0 +1,151 @@
package sqlmock
import (
"database/sql/driver"
"fmt"
"reflect"
)
type conn struct {
expectations []expectation
active expectation
}
// Close a mock database driver connection. It should
// be always called to ensure that all expectations
// were met successfully. Returns error if there is any
func (c *conn) Close() (err error) {
for _, e := range mock.conn.expectations {
if !e.fulfilled() {
err = fmt.Errorf("there is a remaining expectation %T which was not matched yet", e)
break
}
}
mock.conn.expectations = []expectation{}
mock.conn.active = nil
return err
}
func (c *conn) Begin() (driver.Tx, error) {
e := c.next()
if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to begin transaction was not expected")
}
etb, ok := e.(*expectedBegin)
if !ok {
return nil, fmt.Errorf("call to begin transaction, was not expected, next expectation is %T as %+v", e, e)
}
etb.triggered = true
return &transaction{c}, etb.err
}
// get next unfulfilled expectation
func (c *conn) next() (e expectation) {
for _, e = range c.expectations {
if !e.fulfilled() {
return
}
}
return nil // all expectations were fulfilled
}
func (c *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
e := c.next()
query = stripQuery(query)
if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to exec '%s' query with args %+v was not expected", query, args)
}
eq, ok := e.(*expectedExec)
if !ok {
return nil, fmt.Errorf("call to exec query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
}
eq.triggered = true
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
if !eq.queryMatches(query) {
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, eq.sqlRegex.String())
}
if !eq.argsMatches(args) {
return nil, fmt.Errorf("exec query '%s', args %+v does not match expected %+v", query, args, eq.args)
}
if eq.err != nil {
return nil, eq.err // mocked to return error
}
if eq.result == nil {
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, eq, eq)
}
return eq.result, err
}
func (c *conn) Prepare(query string) (driver.Stmt, error) {
e := c.next()
// for backwards compatibility, ignore when Prepare not expected
if e == nil {
return &statement{mock.conn, stripQuery(query)}, nil
}
eq, ok := e.(*expectedPrepare)
if !ok {
return &statement{mock.conn, stripQuery(query)}, nil
}
eq.triggered = true
if eq.err != nil {
return nil, eq.err // mocked to return error
}
return &statement{mock.conn, stripQuery(query)}, nil
}
func (c *conn) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
e := c.next()
query = stripQuery(query)
if e == nil {
return nil, fmt.Errorf("all expectations were already fulfilled, call to query '%s' with args %+v was not expected", query, args)
}
eq, ok := e.(*expectedQuery)
if !ok {
return nil, fmt.Errorf("call to query '%s' with args %+v, was not expected, next expectation is %T as %+v", query, args, e, e)
}
eq.triggered = true
defer argMatcherErrorHandler(&err) // converts panic to error in case of reflect value type mismatch
if !eq.queryMatches(query) {
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, eq.sqlRegex.String())
}
if !eq.argsMatches(args) {
return nil, fmt.Errorf("query '%s', args %+v does not match expected %+v", query, args, eq.args)
}
if eq.err != nil {
return nil, eq.err // mocked to return error
}
if eq.rows == nil {
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, eq, eq)
}
return eq.rows, err
}
func argMatcherErrorHandler(errp *error) {
if e := recover(); e != nil {
if se, ok := e.(*reflect.ValueError); ok { // catch reflect error, failed type conversion
*errp = fmt.Errorf("Failed to compare query arguments: %s", se)
} else {
panic(e) // overwise panic
}
}
}

View File

@ -0,0 +1,378 @@
package sqlmock
import (
"database/sql/driver"
"errors"
"regexp"
"testing"
)
func TestExecNoExpectations(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
triggered: true,
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to exec"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecExpectationMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecQueryMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecArgsMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecWillReturnError(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
if err.Error() != "WillReturnError" {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExecMissingResult(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res != nil {
t.Error("Result should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.result, but it was not set for expectation"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestExec(t *testing.T) {
expectedResult := driver.Result(&result{})
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
result: expectedResult,
},
},
}
res, err := c.Exec("query", []driver.Value{123})
if res == nil {
t.Error("Result should not be nil")
}
if res != expectedResult {
t.Errorf("Result should match expected Result (actual %+v)", res)
}
if err != nil {
t.Errorf("error should be nil (actual %s)", err.Error())
}
}
func TestQueryNoExpectations(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
triggered: true,
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("all expectations were already fulfilled, call to query"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryExpectationMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedExec{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("was not expected, next expectation is"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryQueryMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("otherquery")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match regex"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryArgsMismatch(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{456},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("does not match expected"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryWillReturnError(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{
err: errors.New("WillReturnError"),
},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
if err.Error() != "WillReturnError" {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQueryMissingRows(t *testing.T) {
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
},
},
}
res, err := c.Query("query", []driver.Value{123})
if res != nil {
t.Error("Rows should be nil")
}
if err == nil {
t.Error("error should not be nil")
}
pattern := regexp.MustCompile(regexp.QuoteMeta("must return a database/sql/driver.rows, but it was not set for expectation"))
if !pattern.MatchString(err.Error()) {
t.Errorf("error should match expected error message (actual: %s)", err.Error())
}
}
func TestQuery(t *testing.T) {
expectedRows := driver.Rows(&rows{})
c := &conn{
expectations: []expectation{
&expectedQuery{
queryBasedExpectation: queryBasedExpectation{
commonExpectation: commonExpectation{},
sqlRegex: regexp.MustCompile(regexp.QuoteMeta("query")),
args: []driver.Value{123},
},
rows: expectedRows,
},
},
}
rows, err := c.Query("query", []driver.Value{123})
if rows == nil {
t.Error("Rows should not be nil")
}
if rows != expectedRows {
t.Errorf("Rows should match expected Rows (actual %+v)", rows)
}
if err != nil {
t.Errorf("error should be nil (actual %s)", err.Error())
}
}

View File

@ -0,0 +1,126 @@
package sqlmock
import (
"database/sql/driver"
"reflect"
"regexp"
)
// Argument interface allows to match
// any argument in specific way
type Argument interface {
Match(driver.Value) bool
}
// an expectation interface
type expectation interface {
fulfilled() bool
setError(err error)
}
// common expectation struct
// satisfies the expectation interface
type commonExpectation struct {
triggered bool
err error
}
func (e *commonExpectation) fulfilled() bool {
return e.triggered
}
func (e *commonExpectation) setError(err error) {
e.err = err
}
// query based expectation
// adds a query matching logic
type queryBasedExpectation struct {
commonExpectation
sqlRegex *regexp.Regexp
args []driver.Value
}
func (e *queryBasedExpectation) queryMatches(sql string) bool {
return e.sqlRegex.MatchString(sql)
}
func (e *queryBasedExpectation) argsMatches(args []driver.Value) bool {
if nil == e.args {
return true
}
if len(args) != len(e.args) {
return false
}
for k, v := range args {
matcher, ok := e.args[k].(Argument)
if ok {
if !matcher.Match(v) {
return false
}
continue
}
vi := reflect.ValueOf(v)
ai := reflect.ValueOf(e.args[k])
switch vi.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if vi.Int() != ai.Int() {
return false
}
case reflect.Float32, reflect.Float64:
if vi.Float() != ai.Float() {
return false
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if vi.Uint() != ai.Uint() {
return false
}
case reflect.String:
if vi.String() != ai.String() {
return false
}
default:
// compare types like time.Time based on type only
if vi.Kind() != ai.Kind() {
return false
}
}
}
return true
}
// begin transaction
type expectedBegin struct {
commonExpectation
}
// tx commit
type expectedCommit struct {
commonExpectation
}
// tx rollback
type expectedRollback struct {
commonExpectation
}
// query expectation
type expectedQuery struct {
queryBasedExpectation
rows driver.Rows
}
// exec query expectation
type expectedExec struct {
queryBasedExpectation
result driver.Result
}
// Prepare expectation
type expectedPrepare struct {
commonExpectation
statement driver.Stmt
}

View File

@ -0,0 +1,73 @@
package sqlmock
import (
"database/sql/driver"
"regexp"
"testing"
"time"
)
type matcher struct {
}
func (m matcher) Match(driver.Value) bool {
return true
}
func TestQueryExpectationArgComparison(t *testing.T) {
e := &queryBasedExpectation{}
against := []driver.Value{5}
if !e.argsMatches(against) {
t.Error("arguments should match, since the no expectation was set")
}
e.args = []driver.Value{5, "str"}
against = []driver.Value{5}
if e.argsMatches(against) {
t.Error("arguments should not match, since the size is not the same")
}
against = []driver.Value{3, "str"}
if e.argsMatches(against) {
t.Error("arguments should not match, since the first argument (int value) is different")
}
against = []driver.Value{5, "st"}
if e.argsMatches(against) {
t.Error("arguments should not match, since the second argument (string value) is different")
}
against = []driver.Value{5, "str"}
if !e.argsMatches(against) {
t.Error("arguments should match, but it did not")
}
e.args = []driver.Value{5, time.Now()}
const longForm = "Jan 2, 2006 at 3:04pm (MST)"
tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)")
against = []driver.Value{5, tm}
if !e.argsMatches(against) {
t.Error("arguments should match (time will be compared only by type), but it did not")
}
against = []driver.Value{5, matcher{}}
if !e.argsMatches(against) {
t.Error("arguments should match, but it did not")
}
}
func TestQueryExpectationSqlMatch(t *testing.T) {
e := &expectedExec{}
e.sqlRegex = regexp.MustCompile("SELECT x FROM")
if !e.queryMatches("SELECT x FROM someting") {
t.Errorf("Sql must have matched the query")
}
e.sqlRegex = regexp.MustCompile("SELECT COUNT\\(x\\) FROM")
if !e.queryMatches("SELECT COUNT(x) FROM someting") {
t.Errorf("Sql must have matched the query")
}
}

View File

@ -0,0 +1,39 @@
package sqlmock
import (
"database/sql/driver"
)
// Result satisfies sql driver Result, which
// holds last insert id and rows affected
// by Exec queries
type result struct {
insertID int64
rowsAffected int64
err error
}
// NewResult creates a new sql driver Result
// for Exec based query mocks.
func NewResult(lastInsertID int64, rowsAffected int64) driver.Result {
return &result{
insertID: lastInsertID,
rowsAffected: rowsAffected,
}
}
// NewErrorResult creates a new sql driver Result
// which returns an error given for both interface methods
func NewErrorResult(err error) driver.Result {
return &result{
err: err,
}
}
func (r *result) LastInsertId() (int64, error) {
return r.insertID, r.err
}
func (r *result) RowsAffected() (int64, error) {
return r.rowsAffected, r.err
}

View File

@ -0,0 +1,36 @@
package sqlmock
import (
"fmt"
"testing"
)
func TestShouldReturnValidSqlDriverResult(t *testing.T) {
result := NewResult(1, 2)
id, err := result.LastInsertId()
if 1 != id {
t.Errorf("Expected last insert id to be 1, but got: %d", id)
}
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
affected, err := result.RowsAffected()
if 2 != affected {
t.Errorf("Expected affected rows to be 2, but got: %d", affected)
}
if err != nil {
t.Errorf("expected no error, but got: %s", err)
}
}
func TestShouldReturnErroeSqlDriverResult(t *testing.T) {
result := NewErrorResult(fmt.Errorf("some error"))
_, err := result.LastInsertId()
if err == nil {
t.Error("expected error, but got none")
}
_, err = result.RowsAffected()
if err == nil {
t.Error("expected error, but got none")
}
}

View File

@ -0,0 +1,120 @@
package sqlmock
import (
"database/sql/driver"
"encoding/csv"
"io"
"strings"
)
// Rows interface allows to construct rows
// which also satisfies database/sql/driver.Rows interface
type Rows interface {
driver.Rows // composed interface, supports sql driver.Rows
AddRow(...driver.Value) Rows
FromCSVString(s string) Rows
}
// a struct which implements database/sql/driver.Rows
type rows struct {
cols []string
rows [][]driver.Value
pos int
}
func (r *rows) Columns() []string {
return r.cols
}
func (r *rows) Close() error {
return nil
}
func (r *rows) Err() error {
return nil
}
// advances to next row
func (r *rows) Next(dest []driver.Value) error {
r.pos++
if r.pos > len(r.rows) {
return io.EOF // per interface spec
}
for i, col := range r.rows[r.pos-1] {
dest[i] = col
}
return nil
}
// NewRows allows Rows to be created from a group of
// sql driver.Value or from the CSV string and
// to be used as sql driver.Rows
func NewRows(columns []string) Rows {
return &rows{cols: columns}
}
// AddRow adds a row which is built from arguments
// in the same column order, returns sql driver.Rows
// compatible interface
func (r *rows) AddRow(values ...driver.Value) Rows {
if len(values) != len(r.cols) {
panic("Expected number of values to match number of columns")
}
row := make([]driver.Value, len(r.cols))
for i, v := range values {
row[i] = v
}
r.rows = append(r.rows, row)
return r
}
// FromCSVString adds rows from CSV string.
// Returns sql driver.Rows compatible interface
func (r *rows) FromCSVString(s string) Rows {
res := strings.NewReader(strings.TrimSpace(s))
csvReader := csv.NewReader(res)
for {
res, err := csvReader.Read()
if err != nil || res == nil {
break
}
row := make([]driver.Value, len(r.cols))
for i, v := range res {
row[i] = []byte(strings.TrimSpace(v))
}
r.rows = append(r.rows, row)
}
return r
}
// RowsFromCSVString creates Rows from CSV string
// to be used for mocked queries. Returns sql driver Rows interface
// ** DEPRECATED ** will be removed in the future, use Rows.FromCSVString
func RowsFromCSVString(columns []string, s string) driver.Rows {
rs := &rows{}
rs.cols = columns
r := strings.NewReader(strings.TrimSpace(s))
csvReader := csv.NewReader(r)
for {
r, err := csvReader.Read()
if err != nil || r == nil {
break
}
row := make([]driver.Value, len(columns))
for i, v := range r {
v := strings.TrimSpace(v)
row[i] = []byte(v)
}
rs.rows = append(rs.rows, row)
}
return rs
}

View File

@ -0,0 +1,195 @@
/*
Package sqlmock provides sql driver mock connecection, which allows to test database,
create expectations and ensure the correct execution flow of any database operations.
It hooks into Go standard library's database/sql package.
The package provides convenient methods to mock database queries, transactions and
expect the right execution flow, compare query arguments or even return error instead
to simulate failures. See the example bellow, which illustrates how convenient it is
to work with:
package main
import (
"database/sql"
"github.com/DATA-DOG/go-sqlmock"
"testing"
"fmt"
)
// will test that order with a different status, cannot be cancelled
func TestShouldNotCancelOrderWithNonPendingStatus(t *testing.T) {
// open database stub
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("An error '%s' was not expected when opening a stub database connection", err)
}
// columns to be used for result
columns := []string{"id", "status"}
// expect transaction begin
sqlmock.ExpectBegin()
// expect query to fetch order, match it with regexp
sqlmock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE").
WithArgs(1).
WillReturnRows(sqlmock.NewRows(columns).FromCSVString("1,1"))
// expect transaction rollback, since order status is "cancelled"
sqlmock.ExpectRollback()
// run the cancel order function
someOrderId := 1
// call a function which executes expected database operations
err = cancelOrder(someOrderId, db)
if err != nil {
t.Errorf("Expected no error, but got %s instead", err)
}
// db.Close() ensures that all expectations have been met
if err = db.Close(); err != nil {
t.Errorf("Error '%s' was not expected while closing the database", err)
}
}
*/
package sqlmock
import (
"database/sql"
"database/sql/driver"
"fmt"
"regexp"
)
var mock *mockDriver
// Mock interface defines a mock which is returned
// by any expectation and can be detailed further
// with the methods this interface provides
type Mock interface {
WithArgs(...driver.Value) Mock
WillReturnError(error) Mock
WillReturnRows(driver.Rows) Mock
WillReturnResult(driver.Result) Mock
}
type mockDriver struct {
conn *conn
}
func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
return mock.conn, nil
}
func init() {
mock = &mockDriver{&conn{}}
sql.Register("mock", mock)
}
// New creates sqlmock database connection
// and pings it so that all expectations could be
// asserted on Close.
func New() (db *sql.DB, err error) {
db, err = sql.Open("mock", "")
if err != nil {
return
}
// ensure open connection, otherwise Close does not assert expectations
return db, db.Ping()
}
// ExpectBegin expects transaction to be started
func ExpectBegin() Mock {
e := &expectedBegin{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// ExpectCommit expects transaction to be commited
func ExpectCommit() Mock {
e := &expectedCommit{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// ExpectRollback expects transaction to be rolled back
func ExpectRollback() Mock {
e := &expectedRollback{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// ExpectPrepare expects Query to be prepared
func ExpectPrepare() Mock {
e := &expectedPrepare{}
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// WillReturnError the expectation will return an error
func (c *conn) WillReturnError(err error) Mock {
c.active.setError(err)
return c
}
// ExpectExec expects database Exec to be triggered, which will match
// the given query string as a regular expression
func ExpectExec(sqlRegexStr string) Mock {
e := &expectedExec{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// ExpectQuery database Query to be triggered, which will match
// the given query string as a regular expression
func ExpectQuery(sqlRegexStr string) Mock {
e := &expectedQuery{}
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
mock.conn.expectations = append(mock.conn.expectations, e)
mock.conn.active = e
return mock.conn
}
// WithArgs expectation should be called with given arguments.
// Works with Exec and Query expectations
func (c *conn) WithArgs(args ...driver.Value) Mock {
eq, ok := c.active.(*expectedQuery)
if !ok {
ee, ok := c.active.(*expectedExec)
if !ok {
panic(fmt.Sprintf("arguments may be expected only with query based expectations, current is %T", c.active))
}
ee.args = args
} else {
eq.args = args
}
return c
}
// WillReturnResult expectation will return a Result.
// Works only with Exec expectations
func (c *conn) WillReturnResult(result driver.Result) Mock {
eq, ok := c.active.(*expectedExec)
if !ok {
panic(fmt.Sprintf("driver.result may be returned only by exec expectations, current is %T", c.active))
}
eq.result = result
return c
}
// WillReturnRows expectation will return Rows.
// Works only with Query expectations
func (c *conn) WillReturnRows(rows driver.Rows) Mock {
eq, ok := c.active.(*expectedQuery)
if !ok {
panic(fmt.Sprintf("driver.rows may be returned only by query expectations, current is %T", c.active))
}
eq.rows = rows
return c
}

View File

@ -0,0 +1,532 @@
package sqlmock
import (
"database/sql"
"fmt"
"testing"
"time"
)
func TestIssue14EscapeSQL(t *testing.T) {
db, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
ExpectExec("INSERT INTO mytable\\(a, b\\)").
WithArgs("A", "B").
WillReturnResult(NewResult(1, 1))
_, err = db.Exec("INSERT INTO mytable(a, b) VALUES (?, ?)", "A", "B")
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}
err = db.Close()
if err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
// test the case when db is not triggered and expectations
// are not asserted on close
func TestIssue4(t *testing.T) {
db, err := New()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
ExpectQuery("some sql query which will not be called").
WillReturnRows(NewRows([]string{"id"}))
err = db.Close()
if err == nil {
t.Errorf("Was expecting an error, since expected query was not matched")
}
}
func TestMockQuery(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillReturnRows(rs)
rows, err := db.Query("SELECT (.+) FROM articles WHERE id = ?", 5)
if err != nil {
t.Errorf("error '%s' was not expected while retrieving mock rows", err)
}
defer func() {
if er := rows.Close(); er != nil {
t.Error("Unexpected error while trying to close rows")
}
}()
if !rows.Next() {
t.Error("it must have had one row as result, but got empty result set instead")
}
var id int
var title string
err = rows.Scan(&id, &title)
if err != nil {
t.Errorf("error '%s' was not expected while trying to scan row", err)
}
if id != 5 {
t.Errorf("expected mocked id to be 5, but got %d instead", id)
}
if title != "hello world" {
t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title)
}
if err = db.Close(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
func TestMockQueryTypes(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
columns := []string{"id", "timestamp", "sold"}
timestamp := time.Now()
rs := NewRows(columns)
rs.AddRow(5, timestamp, true)
ExpectQuery("SELECT (.+) FROM sales WHERE id = ?").
WithArgs(5).
WillReturnRows(rs)
rows, err := db.Query("SELECT (.+) FROM sales WHERE id = ?", 5)
if err != nil {
t.Errorf("error '%s' was not expected while retrieving mock rows", err)
}
defer func() {
if er := rows.Close(); er != nil {
t.Error("Unexpected error while trying to close rows")
}
}()
if !rows.Next() {
t.Error("it must have had one row as result, but got empty result set instead")
}
var id int
var time time.Time
var sold bool
err = rows.Scan(&id, &time, &sold)
if err != nil {
t.Errorf("error '%s' was not expected while trying to scan row", err)
}
if id != 5 {
t.Errorf("expected mocked id to be 5, but got %d instead", id)
}
if time != timestamp {
t.Errorf("expected mocked time to be %s, but got '%s' instead", timestamp, time)
}
if sold != true {
t.Errorf("expected mocked boolean to be true, but got %v instead", sold)
}
if err = db.Close(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
func TestTransactionExpectations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
// begin and commit
ExpectBegin()
ExpectCommit()
tx, err := db.Begin()
if err != nil {
t.Errorf("an error '%s' was not expected when beginning a transaction", err)
}
err = tx.Commit()
if err != nil {
t.Errorf("an error '%s' was not expected when commiting a transaction", err)
}
// begin and rollback
ExpectBegin()
ExpectRollback()
tx, err = db.Begin()
if err != nil {
t.Errorf("an error '%s' was not expected when beginning a transaction", err)
}
err = tx.Rollback()
if err != nil {
t.Errorf("an error '%s' was not expected when rolling back a transaction", err)
}
// begin with an error
ExpectBegin().WillReturnError(fmt.Errorf("some err"))
tx, err = db.Begin()
if err == nil {
t.Error("an error was expected when beginning a transaction, but got none")
}
if err = db.Close(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
func TestPrepareExpectations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
// no expectations, w/o ExpectPrepare()
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
if stmt == nil {
t.Errorf("stmt was expected while creating a prepared statement")
}
// expect something else, w/o ExpectPrepare()
var id int
var title string
rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillReturnRows(rs)
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
if stmt == nil {
t.Errorf("stmt was expected while creating a prepared statement")
}
err = stmt.QueryRow(5).Scan(&id, &title)
if err != nil {
t.Errorf("error '%s' was not expected while retrieving mock rows", err)
}
// expect normal result
ExpectPrepare()
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
if stmt == nil {
t.Errorf("stmt was expected while creating a prepared statement")
}
// expect error result
ExpectPrepare().WillReturnError(fmt.Errorf("Some DB error occurred"))
stmt, err = db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err == nil {
t.Error("error was expected while creating a prepared statement")
}
if stmt != nil {
t.Errorf("stmt was not expected while creating a prepared statement returning error")
}
if err = db.Close(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
func TestPreparedQueryExecutions(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillReturnRows(rs1)
rs2 := NewRows([]string{"id", "title"}).FromCSVString("2,whoop")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(2).
WillReturnRows(rs2)
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
var id int
var title string
err = stmt.QueryRow(5).Scan(&id, &title)
if err != nil {
t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
}
if id != 5 {
t.Errorf("expected mocked id to be 5, but got %d instead", id)
}
if title != "hello world" {
t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title)
}
err = stmt.QueryRow(2).Scan(&id, &title)
if err != nil {
t.Errorf("error '%s' was not expected querying row from statement and scanning", err)
}
if id != 2 {
t.Errorf("expected mocked id to be 2, but got %d instead", id)
}
if title != "whoop" {
t.Errorf("expected mocked title to be 'whoop', but got '%s' instead", title)
}
if err = db.Close(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
func TestUnexpectedOperations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ?")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
var id int
var title string
err = stmt.QueryRow(5).Scan(&id, &title)
if err == nil {
t.Error("error was expected querying row, since there was no such expectation")
}
ExpectRollback()
err = db.Close()
if err == nil {
t.Error("error was expected while closing the database, expectation was not fulfilled", err)
}
}
func TestWrongExpectations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
ExpectBegin()
rs1 := NewRows([]string{"id", "title"}).FromCSVString("5,hello world")
ExpectQuery("SELECT (.+) FROM articles WHERE id = ?").
WithArgs(5).
WillReturnRows(rs1)
ExpectCommit().WillReturnError(fmt.Errorf("deadlock occured"))
ExpectRollback() // won't be triggered
stmt, err := db.Prepare("SELECT (.+) FROM articles WHERE id = ? FOR UPDATE")
if err != nil {
t.Errorf("error '%s' was not expected while creating a prepared statement", err)
}
var id int
var title string
err = stmt.QueryRow(5).Scan(&id, &title)
if err == nil {
t.Error("error was expected while querying row, since there begin transaction expectation is not fulfilled")
}
// lets go around and start transaction
tx, err := db.Begin()
if err != nil {
t.Errorf("an error '%s' was not expected when beginning a transaction", err)
}
err = stmt.QueryRow(5).Scan(&id, &title)
if err != nil {
t.Errorf("error '%s' was not expected while querying row, since transaction was started", err)
}
err = tx.Commit()
if err == nil {
t.Error("a deadlock error was expected when commiting a transaction", err)
}
err = db.Close()
if err == nil {
t.Error("error was expected while closing the database, expectation was not fulfilled", err)
}
}
func TestExecExpectations(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
result := NewResult(1, 1)
ExpectExec("^INSERT INTO articles").
WithArgs("hello").
WillReturnResult(result)
res, err := db.Exec("INSERT INTO articles (title) VALUES (?)", "hello")
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}
id, err := res.LastInsertId()
if err != nil {
t.Errorf("error '%s' was not expected, while getting a last insert id", err)
}
affected, err := res.RowsAffected()
if err != nil {
t.Errorf("error '%s' was not expected, while getting affected rows", err)
}
if id != 1 {
t.Errorf("expected last insert id to be 1, but got %d instead", id)
}
if affected != 1 {
t.Errorf("expected affected rows to be 1, but got %d instead", affected)
}
if err = db.Close(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
func TestRowBuilderAndNilTypes(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
rs := NewRows([]string{"id", "active", "created", "status"}).
AddRow(1, true, time.Now(), 5).
AddRow(2, false, nil, nil)
ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs)
rows, err := db.Query("SELECT * FROM sales")
if err != nil {
t.Errorf("error '%s' was not expected while retrieving mock rows", err)
}
defer func() {
if er := rows.Close(); er != nil {
t.Error("Unexpected error while trying to close rows")
}
}()
// NullTime and NullInt are used from stubs_test.go
var (
id int
active bool
created NullTime
status NullInt
)
if !rows.Next() {
t.Error("it must have had row in rows, but got empty result set instead")
}
err = rows.Scan(&id, &active, &created, &status)
if err != nil {
t.Errorf("error '%s' was not expected while trying to scan row", err)
}
if id != 1 {
t.Errorf("expected mocked id to be 1, but got %d instead", id)
}
if !active {
t.Errorf("expected 'active' to be 'true', but got '%v' instead", active)
}
if !created.Valid {
t.Errorf("expected 'created' to be valid, but it %+v is not", created)
}
if !status.Valid {
t.Errorf("expected 'status' to be valid, but it %+v is not", status)
}
if status.Integer != 5 {
t.Errorf("expected 'status' to be '5', but got '%d'", status.Integer)
}
// test second row
if !rows.Next() {
t.Error("it must have had row in rows, but got empty result set instead")
}
err = rows.Scan(&id, &active, &created, &status)
if err != nil {
t.Errorf("error '%s' was not expected while trying to scan row", err)
}
if id != 2 {
t.Errorf("expected mocked id to be 2, but got %d instead", id)
}
if active {
t.Errorf("expected 'active' to be 'false', but got '%v' instead", active)
}
if created.Valid {
t.Errorf("expected 'created' to be invalid, but it %+v is not", created)
}
if status.Valid {
t.Errorf("expected 'status' to be invalid, but it %+v is not", status)
}
if err = db.Close(); err != nil {
t.Errorf("error '%s' was not expected while closing the database", err)
}
}
func TestArgumentReflectValueTypeError(t *testing.T) {
db, err := sql.Open("mock", "")
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}
rs := NewRows([]string{"id"}).AddRow(1)
ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs)
_, err = db.Query("SELECT * FROM sales WHERE x = ?", 5)
if err == nil {
t.Error("Expected error, but got none")
}
}

View File

@ -0,0 +1,26 @@
package sqlmock
import (
"database/sql/driver"
)
type statement struct {
conn *conn
query string
}
func (stmt *statement) Close() error {
return nil
}
func (stmt *statement) NumInput() int {
return -1
}
func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) {
return stmt.conn.Exec(stmt.query, args)
}
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
return stmt.conn.Query(stmt.query, args)
}

View File

@ -0,0 +1,76 @@
package sqlmock
import (
"database/sql/driver"
"fmt"
"strconv"
"time"
)
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
type NullInt struct {
Integer int
Valid bool
}
// Satisfy sql.Scanner interface
func (ni *NullInt) Scan(value interface{}) (err error) {
if value == nil {
ni.Integer, ni.Valid = 0, false
return
}
switch v := value.(type) {
case int, int8, int16, int32, int64:
ni.Integer, ni.Valid = v.(int), true
return
case []byte:
ni.Integer, err = strconv.Atoi(string(v))
ni.Valid = (err == nil)
return
case string:
ni.Integer, err = strconv.Atoi(v)
ni.Valid = (err == nil)
return
}
ni.Valid = false
return fmt.Errorf("Can't convert %T to integer", value)
}
// Satisfy sql.Valuer interface.
func (ni NullInt) Value() (driver.Value, error) {
if !ni.Valid {
return nil, nil
}
return ni.Integer, nil
}
// Satisfy sql.Scanner interface
func (nt *NullTime) Scan(value interface{}) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
}
switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
}
nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
}
// Satisfy sql.Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}

View File

@ -0,0 +1,37 @@
package sqlmock
import (
"fmt"
)
type transaction struct {
conn *conn
}
func (tx *transaction) Commit() error {
e := tx.conn.next()
if e == nil {
return fmt.Errorf("all expectations were already fulfilled, call to commit transaction was not expected")
}
etc, ok := e.(*expectedCommit)
if !ok {
return fmt.Errorf("call to commit transaction, was not expected, next expectation was %v", e)
}
etc.triggered = true
return etc.err
}
func (tx *transaction) Rollback() error {
e := tx.conn.next()
if e == nil {
return fmt.Errorf("all expectations were already fulfilled, call to rollback transaction was not expected")
}
etr, ok := e.(*expectedRollback)
if !ok {
return fmt.Errorf("call to rollback transaction, was not expected, next expectation was %v", e)
}
etr.triggered = true
return etr.err
}

View File

@ -0,0 +1,17 @@
package sqlmock
import (
"regexp"
"strings"
)
var re *regexp.Regexp
func init() {
re = regexp.MustCompile("\\s+")
}
// strip out new lines and trim spaces
func stripQuery(q string) (s string) {
return strings.TrimSpace(re.ReplaceAllString(q, " "))
}

View File

@ -0,0 +1,21 @@
package sqlmock
import (
"testing"
)
func TestQueryStringStripping(t *testing.T) {
assert := func(actual, expected string) {
if res := stripQuery(actual); res != expected {
t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res)
}
}
assert(" SELECT 1", "SELECT 1")
assert("SELECT 1 FROM d", "SELECT 1 FROM d")
assert(`
SELECT c
FROM D
`, "SELECT c FROM D")
assert("UPDATE (.+) SET ", "UPDATE (.+) SET")
}

View File

@ -1,10 +1,13 @@
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/Sirupsen/logrus"
"github.com/endophage/gotuf/data"
@ -35,6 +38,86 @@ func MainHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) *e
return nil
}
// AtomicUpdateHandler will accept multiple TUF files and ensure that the storage
// backend is atomically updated with all the new records.
func AtomicUpdateHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) *errors.HTTPError {
defer r.Body.Close()
s := ctx.Value("metaStore")
if s == nil {
return &errors.HTTPError{
HTTPStatus: http.StatusInternalServerError,
Code: 9999,
Err: fmt.Errorf("Version store is nil"),
}
}
store, ok := s.(storage.MetaStore)
if !ok {
return &errors.HTTPError{
HTTPStatus: http.StatusInternalServerError,
Code: 9999,
Err: fmt.Errorf("Version store not configured"),
}
}
vars := mux.Vars(r)
gun := vars["imageName"]
reader, err := r.MultipartReader()
if err != nil {
return &errors.HTTPError{
HTTPStatus: http.StatusBadRequest,
Code: 9999,
Err: err,
}
}
var updates []storage.MetaUpdate
for {
part, err := reader.NextPart()
if err == io.EOF {
break
}
role := strings.TrimSuffix(part.FileName(), ".json")
if role == "" {
return &errors.HTTPError{
HTTPStatus: http.StatusBadRequest,
Code: 9999,
Err: fmt.Errorf("Empty filename provided. No updates performed"),
}
} else if !data.ValidRole(role) {
return &errors.HTTPError{
HTTPStatus: http.StatusBadRequest,
Code: 9999,
Err: fmt.Errorf("Invalid role: %s. No updates performed", role),
}
}
meta := &data.SignedTargets{}
var input []byte
inBuf := bytes.NewBuffer(input)
dec := json.NewDecoder(io.TeeReader(part, inBuf))
err = dec.Decode(meta)
if err != nil {
return &errors.HTTPError{
HTTPStatus: http.StatusBadRequest,
Code: 9999,
Err: err,
}
}
version := meta.Signed.Version
updates = append(updates, storage.MetaUpdate{
Role: role,
Version: version,
Data: inBuf.Bytes(),
})
}
err = store.UpdateMany(gun, updates)
if err != nil {
return &errors.HTTPError{
HTTPStatus: http.StatusInternalServerError,
Code: 9999,
Err: err,
}
}
return nil
}
// UpdateHandler adds the provided json data for the role and GUN specified in the URL
func UpdateHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) *errors.HTTPError {
defer r.Body.Close()
@ -74,8 +157,12 @@ func UpdateHandler(ctx context.Context, w http.ResponseWriter, r *http.Request)
Err: err,
}
}
version := meta.Signed.Version
err = store.UpdateCurrent(gun, tufRole, version, input)
update := storage.MetaUpdate{
Role: tufRole,
Version: meta.Signed.Version,
Data: input,
}
err = store.UpdateCurrent(gun, update)
if err != nil {
return &errors.HTTPError{
HTTPStatus: http.StatusInternalServerError,

View File

@ -76,6 +76,7 @@ func Run(ctx context.Context, addr, tlsCertFile, tlsKeyFile string, trust signed
r := mux.NewRouter()
// TODO (endophage): use correct regexes for image and tag names
r.Methods("POST").Path("/v2/{imageName:.*}/_trust/tuf/").Handler(hand(handlers.AtomicUpdateHandler, "push", "pull"))
r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/{tufRole:(root|targets|snapshot)}.json").Handler(hand(handlers.GetHandler, "pull"))
r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/timestamp.json").Handler(hand(handlers.GetTimestampHandler, "pull"))
r.Methods("GET").Path("/v2/{imageName:.*}/_trust/tuf/timestamp.key").Handler(hand(handlers.GetTimestampKeyHandler, "push", "pull"))

View File

@ -3,6 +3,7 @@ package storage
import (
"database/sql"
"github.com/Sirupsen/logrus"
"github.com/endophage/gotuf/data"
"github.com/go-sql-driver/mysql"
)
@ -37,26 +38,19 @@ func NewMySQLStorage(db *sql.DB) *MySQLStorage {
// UpdateCurrent updates multiple TUF records in a single transaction.
// Always insert a new row. The unique constraint will ensure there is only ever
func (db *MySQLStorage) UpdateCurrent(gun, role string, version int, data []byte) error {
checkStmt := "SELECT count(*) FROM `tuf_files` WHERE `gun`=? AND `role`=? AND `version`>=?;"
insertStmt := "INSERT INTO `tuf_files` (`gun`, `role`, `version`, `data`) VALUES (?,?,?,?) ;"
// ensure immediately previous version exists
row := db.QueryRow(checkStmt, gun, role, version)
var exists int
err := row.Scan(&exists)
if err != nil {
return err
}
if exists != 0 {
return &ErrOldVersion{}
}
func (db *MySQLStorage) UpdateCurrent(gun string, update MetaUpdate) error {
insertStmt := "INSERT INTO `tuf_files` (`gun`, `role`, `version`, `data`) VALUES (?,?,?,?) WHERE (SELECT count(*) FROM `tuf_files` WHERE `gun`=? AND `role`=? AND `version`>=?) = 0"
// attempt to insert. Due to race conditions with the check this could fail.
// That's OK, we're doing first write wins. The client will be messaged it
// needs to rebase.
_, err = db.Exec(insertStmt, gun, role, version, data)
_, err := db.Exec(insertStmt, gun, update.Role, update.Version, update.Data, gun, update.Role, update.Version)
if err != nil {
if err, ok := err.(*mysql.MySQLError); ok {
if err.Number == 1022 { // duplicate key error
return &ErrOldVersion{}
}
}
// need to check error type for duplicate key exception
// and return ErrOldVersion if duplicate
return err
@ -64,6 +58,32 @@ func (db *MySQLStorage) UpdateCurrent(gun, role string, version int, data []byte
return nil
}
// UpdateMany atomically updates many TUF records in a single transaction
func (db *MySQLStorage) UpdateMany(gun string, updates []MetaUpdate) error {
insertStmt := "INSERT INTO `tuf_files` (`gun`, `role`, `version`, `data`) VALUES (?,?,?,?) WHERE (SELECT count(*) FROM `tuf_files` WHERE `gun`=? AND `role`=? AND `version`>=?) = 0;"
tx, err := db.Begin()
for _, u := range updates {
// attempt to insert. Due to race conditions with the check this could fail.
// That's OK, we're doing first write wins. The client will be messaged it
// needs to rebase.
_, err = tx.Exec(insertStmt, gun, u.Role, u.Version, u.Data, gun, u.Role, u.Version)
if err != nil {
// need to check error type for duplicate key exception
// and return ErrOldVersion if duplicate
rbErr := tx.Rollback()
if rbErr != nil {
logrus.Panic("Failed on Tx rollback with error: ", err.Error())
}
if err, ok := err.(*mysql.MySQLError); ok && err.Number == 1022 { // duplicate key error
return &ErrOldVersion{}
}
return err
}
}
return tx.Commit()
}
// GetCurrent gets a specific TUF record
func (db *MySQLStorage) GetCurrent(gun, tufRole string) (data []byte, err error) {
stmt := "SELECT `data` FROM `tuf_files` WHERE `gun`=? AND `role`=? ORDER BY `version` DESC LIMIT 1;"
@ -99,7 +119,7 @@ func (db *MySQLStorage) GetTimestampKey(gun string) (algorithm data.KeyAlgorithm
var cipher string
err = row.Scan(&cipher, &public)
if err == sql.ErrNoRows {
return "", nil, ErrNoKey{gun: gun}
return "", nil, &ErrNoKey{gun: gun}
} else if err != nil {
return "", nil, err
}
@ -111,11 +131,10 @@ func (db *MySQLStorage) GetTimestampKey(gun string) (algorithm data.KeyAlgorithm
func (db *MySQLStorage) SetTimestampKey(gun string, algorithm data.KeyAlgorithm, public []byte) error {
stmt := "INSERT INTO `timestamp_keys` (`gun`, `cipher`, `public`) VALUES (?,?,?);"
_, err := db.Exec(stmt, gun, string(algorithm), public)
if err, ok := err.(*mysql.MySQLError); ok {
if err.Number == 1022 { // duplicate key error
if err != nil {
if err, ok := err.(*mysql.MySQLError); ok && err.Number == 1022 {
return &ErrTimestampKeyExists{gun: gun}
}
} else if err != nil {
return err
}
return nil

View File

@ -0,0 +1,278 @@
package storage
import (
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
)
func TestMySQLUpdateCurrent(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
update := MetaUpdate{
Role: "root",
Version: 0,
Data: []byte("1"),
}
sqlmock.ExpectExec("INSERT INTO `tuf_files` \\(`gun`, `role`, `version`, `data`\\) VALUES \\(\\?,\\?,\\?,\\?\\) WHERE \\(SELECT count\\(\\*\\) FROM `tuf_files` WHERE `gun`=\\? AND `role`=\\? AND `version`>=\\?\\) = 0").WithArgs(
"testGUN",
update.Role,
update.Version,
update.Data,
"testGUN",
update.Role,
update.Version,
).WillReturnResult(sqlmock.NewResult(0, 1))
err = s.UpdateCurrent(
"testGUN",
update,
)
assert.Nil(t, err, "UpdateCurrent errored unexpectedly: %v", err)
err = db.Close()
assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLUpdateCurrentError(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
update := MetaUpdate{
Role: "root",
Version: 0,
Data: []byte("1"),
}
sqlmock.ExpectExec("INSERT INTO `tuf_files` \\(`gun`, `role`, `version`, `data`\\) VALUES \\(\\?,\\?,\\?,\\?\\) WHERE \\(SELECT count\\(\\*\\) FROM `tuf_files` WHERE `gun`=\\? AND `role`=\\? AND `version`>=\\?\\) = 0").WithArgs(
"testGUN",
update.Role,
update.Version,
update.Data,
"testGUN",
update.Role,
update.Version,
).WillReturnError(
&mysql.MySQLError{
Number: 1022,
Message: "Duplicate key error",
},
)
err = s.UpdateCurrent(
"testGUN",
update,
)
assert.NotNil(t, err, "Error should not be nil")
assert.IsType(t, &ErrOldVersion{}, err, "Expected ErrOldVersion error type")
err = db.Close()
assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLUpdateMany(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
update1 := MetaUpdate{
Role: "root",
Version: 0,
Data: []byte("1"),
}
update2 := MetaUpdate{
Role: "targets",
Version: 1,
Data: []byte("2"),
}
// start transation
sqlmock.ExpectBegin()
// insert first update
sqlmock.ExpectExec("INSERT INTO `tuf_files` \\(`gun`, `role`, `version`, `data`\\) VALUES \\(\\?,\\?,\\?,\\?\\) WHERE \\(SELECT count\\(\\*\\) FROM `tuf_files` WHERE `gun`=\\? AND `role`=\\? AND `version`>=\\?\\) = 0").WithArgs(
"testGUN",
update1.Role,
update1.Version,
update1.Data,
"testGUN",
update1.Role,
update1.Version,
).WillReturnResult(sqlmock.NewResult(0, 1))
// insert second update
sqlmock.ExpectExec("INSERT INTO `tuf_files` \\(`gun`, `role`, `version`, `data`\\) VALUES \\(\\?,\\?,\\?,\\?\\) WHERE \\(SELECT count\\(\\*\\) FROM `tuf_files` WHERE `gun`=\\? AND `role`=\\? AND `version`>=\\?\\) = 0").WithArgs(
"testGUN",
update2.Role,
update2.Version,
update2.Data,
"testGUN",
update2.Role,
update2.Version,
).WillReturnResult(sqlmock.NewResult(1, 1))
// expect commit
sqlmock.ExpectCommit()
err = s.UpdateMany(
"testGUN",
[]MetaUpdate{update1, update2},
)
assert.Nil(t, err, "UpdateMany errored unexpectedly: %v", err)
err = db.Close()
assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLUpdateManyRollback(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
update1 := MetaUpdate{
Role: "root",
Version: 0,
Data: []byte("1"),
}
execError := mysql.MySQLError{}
// start transation
sqlmock.ExpectBegin()
// insert first update
sqlmock.ExpectExec("INSERT INTO `tuf_files` \\(`gun`, `role`, `version`, `data`\\) VALUES \\(\\?,\\?,\\?,\\?\\) WHERE \\(SELECT count\\(\\*\\) FROM `tuf_files` WHERE `gun`=\\? AND `role`=\\? AND `version`>=\\?\\) = 0").WithArgs(
"testGUN",
update1.Role,
update1.Version,
update1.Data,
"testGUN",
update1.Role,
update1.Version,
).WillReturnError(&execError)
// expect commit
sqlmock.ExpectRollback()
err = s.UpdateMany(
"testGUN",
[]MetaUpdate{update1},
)
assert.IsType(t, &execError, err, "UpdateMany returned wrong error type")
err = db.Close()
assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLUpdateManyDuplicate(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
update1 := MetaUpdate{
Role: "root",
Version: 0,
Data: []byte("1"),
}
execError := mysql.MySQLError{Number: 1022}
// start transation
sqlmock.ExpectBegin()
// insert first update
sqlmock.ExpectExec("INSERT INTO `tuf_files` \\(`gun`, `role`, `version`, `data`\\) VALUES \\(\\?,\\?,\\?,\\?\\) WHERE \\(SELECT count\\(\\*\\) FROM `tuf_files` WHERE `gun`=\\? AND `role`=\\? AND `version`>=\\?\\) = 0").WithArgs(
"testGUN",
update1.Role,
update1.Version,
update1.Data,
"testGUN",
update1.Role,
update1.Version,
).WillReturnError(&execError)
// expect commit
sqlmock.ExpectRollback()
err = s.UpdateMany(
"testGUN",
[]MetaUpdate{update1},
)
assert.IsType(t, &ErrOldVersion{}, err, "UpdateMany returned wrong error type")
err = db.Close()
assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLGetCurrent(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
sqlmock.ExpectQuery(
"SELECT `data` FROM `tuf_files` WHERE `gun`=\\? AND `role`=\\? ORDER BY `version` DESC LIMIT 1;",
).WithArgs("testGUN", "root").WillReturnRows(
sqlmock.RowsFromCSVString(
[]string{"data"},
"1",
),
)
byt, err := s.GetCurrent("testGUN", "root")
assert.Nil(t, err, "Expected nil error from GetCurrent")
assert.Equal(t, []byte("1"), byt, "Returned data was no correct")
// TODO(endophage): these two lines are breaking because there
// seems to be some problem with go-sqlmock
//err = db.Close()
//assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLDelete(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
sqlmock.ExpectExec(
"DELETE FROM `tuf_files` WHERE `gun`=\\?;",
).WithArgs("testGUN").WillReturnResult(sqlmock.NewResult(0, 1))
err = s.Delete("testGUN")
assert.Nil(t, err, "Expected nil error from Delete")
err = db.Close()
assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLGetTimestampKeyNoKey(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
sqlmock.ExpectQuery(
"SELECT `cipher`, `public` FROM `timestamp_keys` WHERE `gun`=\\?;",
).WithArgs("testGUN").WillReturnError(sql.ErrNoRows)
_, _, err = s.GetTimestampKey("testGUN")
assert.IsType(t, &ErrNoKey{}, err, "Expected ErrNoKey from GetTimestampKey")
//err = db.Close()
//assert.Nil(t, err, "Expectation not met: %v", err)
}
func TestMySQLSetTimestampKeyExists(t *testing.T) {
db, err := sqlmock.New()
assert.Nil(t, err, "Could not initialize mock DB")
s := NewMySQLStorage(db)
sqlmock.ExpectExec(
"INSERT INTO `timestamp_keys` \\(`gun`, `cipher`, `public`\\) VALUES \\(\\?,\\?,\\?\\);",
).WithArgs(
"testGUN",
"testCipher",
[]byte("1"),
).WillReturnError(
&mysql.MySQLError{Number: 1022},
)
err = s.SetTimestampKey("testGUN", "testCipher", []byte("1"))
assert.IsType(t, &ErrTimestampKeyExists{}, err, "Expected ErrTimestampKeyExists from SetTimestampKey")
err = db.Close()
assert.Nil(t, err, "Expectation not met: %v", err)
}

View File

@ -4,7 +4,8 @@ import "github.com/endophage/gotuf/data"
// MetaStore holds the methods that are used for a Metadata Store
type MetaStore interface {
UpdateCurrent(gun, role string, version int, data []byte) error
UpdateCurrent(gun string, update MetaUpdate) error
UpdateMany(gun string, updates []MetaUpdate) error
GetCurrent(gun, tufRole string) (data []byte, err error)
Delete(gun string) error
GetTimestampKey(gun string) (algorithm data.KeyAlgorithm, public []byte, err error)

View File

@ -35,18 +35,26 @@ func NewMemStorage() *MemStorage {
}
// UpdateCurrent updates the meta data for a specific role
func (st *MemStorage) UpdateCurrent(gun, role string, version int, data []byte) error {
id := entryKey(gun, role)
func (st *MemStorage) UpdateCurrent(gun string, update MetaUpdate) error {
id := entryKey(gun, update.Role)
st.lock.Lock()
defer st.lock.Unlock()
if space, ok := st.tufMeta[id]; ok {
for _, v := range space {
if v.version >= version {
if v.version >= update.Version {
return &ErrOldVersion{}
}
}
}
st.tufMeta[id] = append(st.tufMeta[id], &ver{version: version, data: data})
st.tufMeta[id] = append(st.tufMeta[id], &ver{version: update.Version, data: update.Data})
return nil
}
// UpdateMany updates multiple TUF records
func (st *MemStorage) UpdateMany(gun string, updates []MetaUpdate) error {
for _, u := range updates {
st.UpdateCurrent(gun, u)
}
return nil
}

View File

@ -9,7 +9,7 @@ import (
func TestUpdateCurrent(t *testing.T) {
s := NewMemStorage()
s.UpdateCurrent("gun", "role", 1, []byte("test"))
s.UpdateCurrent("gun", MetaUpdate{"role", 1, []byte("test")})
k := entryKey("gun", "role")
gun, ok := s.tufMeta[k]
@ -25,7 +25,7 @@ func TestGetCurrent(t *testing.T) {
_, err := s.GetCurrent("gun", "role")
assert.IsType(t, &ErrNotFound{}, err, "Expected error to be ErrNotFound")
s.UpdateCurrent("gun", "role", 1, []byte("test"))
s.UpdateCurrent("gun", MetaUpdate{"role", 1, []byte("test")})
d, err := s.GetCurrent("gun", "role")
assert.Nil(t, err, "Expected error to be nil")
assert.Equal(t, []byte("test"), d, "Data was incorrect")
@ -33,7 +33,7 @@ func TestGetCurrent(t *testing.T) {
func TestDelete(t *testing.T) {
s := NewMemStorage()
s.UpdateCurrent("gun", "role", 1, []byte("test"))
s.UpdateCurrent("gun", MetaUpdate{"role", 1, []byte("test")})
s.Delete("gun")
k := entryKey("gun", "role")

8
server/storage/types.go Normal file
View File

@ -0,0 +1,8 @@
package storage
// MetaUpdate packages up the fields required to update a TUF record
type MetaUpdate struct {
Role string
Version int
Data []byte
}

View File

@ -79,7 +79,7 @@ func GetOrCreateTimestamp(gun string, store storage.MetaStore, cryptoService sig
logrus.Error("Failed to marshal new timestamp")
return nil, err
}
err = store.UpdateCurrent(gun, "timestamp", version, out)
err = store.UpdateCurrent(gun, storage.MetaUpdate{Role: "timestamp", Version: version, Data: out})
if err != nil {
return nil, err
}

View File

@ -52,7 +52,7 @@ func TestGetTimestamp(t *testing.T) {
snapshot := &data.SignedSnapshot{}
snapJSON, _ := json.Marshal(snapshot)
store.UpdateCurrent("gun", "snapshot", 0, snapJSON)
store.UpdateCurrent("gun", storage.MetaUpdate{Role: "snapshot", Version: 0, Data: snapJSON})
// create a key to be used by GetTimestamp
_, err := GetOrCreateTimestampKey("gun", store, crypto, data.ED25519Key)
assert.Nil(t, err, "GetTimestampKey errored")