Misc refactorings

- Removed `state.DeleteWithOptions` and `state.SetWithOptions` which were useless at this point (likely a leftover from when there were retries)
- Some improvements in etag handling in postgres
- Other minor refactorings

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
ItalyPaleAle 2022-11-17 16:29:29 +00:00
parent d366211810
commit 1e728597b4
13 changed files with 121 additions and 215 deletions

View File

@ -130,6 +130,7 @@ func TestPublishingWithTTL(t *testing.T) {
const maxGetDuration = ttlInSeconds * time.Second const maxGetDuration = ttlInSeconds * time.Second
metadata := bindings.Metadata{ metadata := bindings.Metadata{
Base: contribMetadata.Base{
Name: "testQueue", Name: "testQueue",
Properties: map[string]string{ Properties: map[string]string{
"queueName": queueName, "queueName": queueName,
@ -137,6 +138,7 @@ func TestPublishingWithTTL(t *testing.T) {
"deleteWhenUnused": strconv.FormatBool(exclusive), "deleteWhenUnused": strconv.FormatBool(exclusive),
"durable": strconv.FormatBool(durable), "durable": strconv.FormatBool(durable),
}, },
},
} }
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
@ -162,7 +164,7 @@ func TestPublishingWithTTL(t *testing.T) {
}, },
} }
_, err = rabbitMQBinding1.Invoke(context.Backgound(), &writeRequest) _, err = rabbitMQBinding1.Invoke(context.Background(), &writeRequest)
assert.Nil(t, err) assert.Nil(t, err)
time.Sleep(time.Second + (ttlInSeconds * time.Second)) time.Sleep(time.Second + (ttlInSeconds * time.Second))
@ -183,7 +185,7 @@ func TestPublishingWithTTL(t *testing.T) {
contribMetadata.TTLMetadataKey: strconv.Itoa(ttlInSeconds * 1000), contribMetadata.TTLMetadataKey: strconv.Itoa(ttlInSeconds * 1000),
}, },
} }
_, err = rabbitMQBinding2.Invoke(context.Backgound(), &writeRequest) _, err = rabbitMQBinding2.Invoke(context.Background(), &writeRequest)
assert.Nil(t, err) assert.Nil(t, err)
msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration) msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration)
@ -204,6 +206,7 @@ func TestExclusiveQueue(t *testing.T) {
const maxGetDuration = ttlInSeconds * time.Second const maxGetDuration = ttlInSeconds * time.Second
metadata := bindings.Metadata{ metadata := bindings.Metadata{
Base: contribMetadata.Base{
Name: "testQueue", Name: "testQueue",
Properties: map[string]string{ Properties: map[string]string{
"queueName": queueName, "queueName": queueName,
@ -213,6 +216,7 @@ func TestExclusiveQueue(t *testing.T) {
"exclusive": strconv.FormatBool(exclusive), "exclusive": strconv.FormatBool(exclusive),
contribMetadata.TTLMetadataKey: strconv.FormatInt(ttlInSeconds, 10), contribMetadata.TTLMetadataKey: strconv.FormatInt(ttlInSeconds, 10),
}, },
},
} }
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
@ -257,6 +261,7 @@ func TestPublishWithPriority(t *testing.T) {
const maxPriority = 10 const maxPriority = 10
metadata := bindings.Metadata{ metadata := bindings.Metadata{
Base: contribMetadata.Base{
Name: "testQueue", Name: "testQueue",
Properties: map[string]string{ Properties: map[string]string{
"queueName": queueName, "queueName": queueName,
@ -265,6 +270,7 @@ func TestPublishWithPriority(t *testing.T) {
"durable": strconv.FormatBool(durable), "durable": strconv.FormatBool(durable),
"maxPriority": strconv.FormatInt(maxPriority, 10), "maxPriority": strconv.FormatInt(maxPriority, 10),
}, },
},
} }
logger := logger.NewLogger("test") logger := logger.NewLogger("test")
@ -283,7 +289,7 @@ func TestPublishWithPriority(t *testing.T) {
defer ch.Close() defer ch.Close()
const middlePriorityMsgContent = "middle" const middlePriorityMsgContent = "middle"
_, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{ _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{
Metadata: map[string]string{ Metadata: map[string]string{
contribMetadata.PriorityMetadataKey: "5", contribMetadata.PriorityMetadataKey: "5",
}, },
@ -292,7 +298,7 @@ func TestPublishWithPriority(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
const lowPriorityMsgContent = "low" const lowPriorityMsgContent = "low"
_, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{ _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{
Metadata: map[string]string{ Metadata: map[string]string{
contribMetadata.PriorityMetadataKey: "1", contribMetadata.PriorityMetadataKey: "1",
}, },
@ -301,7 +307,7 @@ func TestPublishWithPriority(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
const highPriorityMsgContent = "high" const highPriorityMsgContent = "high"
_, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{ _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{
Metadata: map[string]string{ Metadata: map[string]string{
contribMetadata.PriorityMetadataKey: "10", contribMetadata.PriorityMetadataKey: "10",
}, },

View File

@ -52,11 +52,11 @@ func createIotHubPubsubMetadata() pubsub.Metadata {
metadata := pubsub.Metadata{ metadata := pubsub.Metadata{
Base: metadata.Base{ Base: metadata.Base{
Properties: map[string]string{ Properties: map[string]string{
connectionString: os.Getenv(iotHubConnectionStringEnvKey), "connectionString": os.Getenv(iotHubConnectionStringEnvKey),
consumerID: os.Getenv(iotHubConsumerGroupEnvKey), "consumerID": os.Getenv(iotHubConsumerGroupEnvKey),
storageAccountName: os.Getenv(storageAccountNameEnvKey), "storageAccountName": os.Getenv(storageAccountNameEnvKey),
storageAccountKey: os.Getenv(storageAccountKeyEnvKey), "storageAccountKey": os.Getenv(storageAccountKeyEnvKey),
storageContainerName: testStorageContainerName, "storageContainerName": testStorageContainerName,
}, },
}, },
} }

View File

@ -114,11 +114,6 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error {
// Set makes an insert or update to the database. // Set makes an insert or update to the database.
func (p *cockroachDBAccess) Set(req *state.SetRequest) error { func (p *cockroachDBAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *cockroachDBAccess) setValue(req *state.SetRequest) error {
p.logger.Debug("Setting state value in CockroachDB") p.logger.Debug("Setting state value in CockroachDB")
value, isBinary, err := validateAndReturnValue(req) value, isBinary, err := validateAndReturnValue(req)
@ -240,11 +235,6 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro
// Delete removes an item from the state store. // Delete removes an item from the state store.
func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error { func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(p.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error {
p.logger.Debug("Deleting state value from CockroachDB") p.logger.Debug("Deleting state value from CockroachDB")
if req.Key == "" { if req.Key == "" {
return fmt.Errorf("missing key in delete operation") return fmt.Errorf("missing key in delete operation")

View File

@ -113,7 +113,8 @@ func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil }, nil
} }
func (f *Firestore) setValue(req *state.SetRequest) error { // Set saves state into Firestore.
func (f *Firestore) Set(req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -142,12 +143,8 @@ func (f *Firestore) setValue(req *state.SetRequest) error {
return nil return nil
} }
// Set saves state into Firestore with retry. // Delete performs a delete operation.
func (f *Firestore) Set(req *state.SetRequest) error { func (f *Firestore) Delete(req *state.DeleteRequest) error {
return state.SetWithOptions(f.setValue, req)
}
func (f *Firestore) deleteValue(req *state.DeleteRequest) error {
ctx := context.Background() ctx := context.Background()
key := datastore.NameKey(f.entityKind, req.Key, nil) key := datastore.NameKey(f.entityKind, req.Key, nil)
@ -159,11 +156,6 @@ func (f *Firestore) deleteValue(req *state.DeleteRequest) error {
return nil return nil
} }
// Delete performs a delete operation.
func (f *Firestore) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(f.deleteValue, req)
}
func getFirestoreMetadata(meta state.Metadata) (*firestoreMetadata, error) { func getFirestoreMetadata(meta state.Metadata) (*firestoreMetadata, error) {
m := firestoreMetadata{ m := firestoreMetadata{
EntityKind: defaultEntityKind, EntityKind: defaultEntityKind,

View File

@ -146,7 +146,7 @@ func (m *Memcached) parseTTL(req *state.SetRequest) (*int32, error) {
return nil, nil return nil, nil
} }
func (m *Memcached) setValue(req *state.SetRequest) error { func (m *Memcached) Set(req *state.SetRequest) error {
var bt []byte var bt []byte
ttl, err := m.parseTTL(req) ttl, err := m.parseTTL(req)
if err != nil { if err != nil {
@ -194,10 +194,6 @@ func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) {
}, nil }, nil
} }
func (m *Memcached) Set(req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req)
}
func (m *Memcached) GetComponentMetadata() map[string]string { func (m *Memcached) GetComponentMetadata() map[string]string {
metadataStruct := memcachedMetadata{} metadataStruct := memcachedMetadata{}
metadataInfo := map[string]string{} metadataInfo := map[string]string{}

View File

@ -30,7 +30,6 @@ import (
"github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state"
"github.com/dapr/kit/logger" "github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
) )
// Optimistic Concurrency is implemented using a string column that stores a UUID. // Optimistic Concurrency is implemented using a string column that stores a UUID.
@ -337,8 +336,6 @@ func (m *MySQL) Delete(req *state.DeleteRequest) error {
return m.deleteValue(m.db, req) return m.deleteValue(m.db, req)
} }
// deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func.
func (m *MySQL) deleteValue(querier querier, req *state.DeleteRequest) error { func (m *MySQL) deleteValue(querier querier, req *state.DeleteRequest) error {
m.logger.Debug("Deleting state value from MySql") m.logger.Debug("Deleting state value from MySql")
@ -458,14 +455,14 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
return &state.GetResponse{ return &state.GetResponse{
Data: data, Data: data,
ETag: ptr.Of(eTag), ETag: &eTag,
Metadata: req.Metadata, Metadata: req.Metadata,
}, nil }, nil
} }
return &state.GetResponse{ return &state.GetResponse{
Data: value, Data: value,
ETag: ptr.Of(eTag), ETag: &eTag,
Metadata: req.Metadata, Metadata: req.Metadata,
}, nil }, nil
} }
@ -476,8 +473,6 @@ func (m *MySQL) Set(req *state.SetRequest) error {
return m.setValue(m.db, req) return m.setValue(m.db, req)
} }
// setValue is an internal implementation of set to enable passing the logic
// to state.SetWithRetries as a func.
func (m *MySQL) setValue(querier querier, req *state.SetRequest) error { func (m *MySQL) setValue(querier querier, req *state.SetRequest) error {
m.logger.Debug("Setting state value in MySql") m.logger.Debug("Setting state value in MySql")

View File

@ -114,11 +114,6 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
return nil return nil
} }
// Set makes an insert or update to the database.
func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(o.setValue, req)
}
func parseTTL(requestMetadata map[string]string) (*int, error) { func parseTTL(requestMetadata map[string]string) (*int, error) {
if val, found := requestMetadata[metadataTTLKey]; found && val != "" { if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
parsedVal, err := strconv.ParseInt(val, 10, 0) parsedVal, err := strconv.ParseInt(val, 10, 0)
@ -133,8 +128,8 @@ func parseTTL(requestMetadata map[string]string) (*int, error) {
return nil, nil return nil, nil
} }
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. // Set makes an insert or update to the database.
func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error {
o.logger.Debug("Setting state value in OracleDatabase") o.logger.Debug("Setting state value in OracleDatabase")
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
@ -204,6 +199,7 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error {
result, err = tx.Exec(mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds) result, err = tx.Exec(mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
} else { } else {
// when first write policy is indicated, an existing record has to be updated - one that has the etag provided. // when first write policy is indicated, an existing record has to be updated - one that has the etag provided.
// TODO: Needs to update ttl_in_seconds
updateStatement := fmt.Sprintf( updateStatement := fmt.Sprintf(
`UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag `UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag
WHERE key = :key AND etag = :etag`, WHERE key = :key AND etag = :etag`,
@ -273,11 +269,6 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e
// Delete removes an item from the state store. // Delete removes an item from the state store.
func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error { func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(o.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error {
o.logger.Debug("Deleting state value from OracleDatabase") o.logger.Debug("Deleting state value from OracleDatabase")
if req.Key == "" { if req.Key == "" {
return fmt.Errorf("missing key in delete operation") return fmt.Errorf("missing key in delete operation")
@ -354,7 +345,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s
return err return err
} }
// Close implements io.Close. // Close implements io.Closer.
func (o *oracleDatabaseAccess) Close() error { func (o *oracleDatabaseAccess) Close() error {
if o.db != nil { if o.db != nil {
return o.db.Close() return o.db.Close()
@ -391,10 +382,3 @@ func tableExists(db *sql.DB, tableName string) (bool, error) {
exists := tblCount > 0 exists := tblCount > 0
return exists, err return exists, err
} }
// func handleError(msg string, err error) {
// if err != nil {
// fmt.Println(msg, err)
// }
// }

View File

@ -17,6 +17,7 @@ import (
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@ -77,7 +78,7 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
if m.ConnectionString == "" { if m.ConnectionString == "" {
p.logger.Error("Missing postgreSQL connection string") p.logger.Error("Missing postgreSQL connection string")
return fmt.Errorf(errMissingConnectionString) return errors.New(errMissingConnectionString)
} }
p.connectionString = m.ConnectionString p.connectionString = m.ConnectionString
@ -111,11 +112,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
// Set makes an insert or update to the database. // Set makes an insert or update to the database.
func (p *postgresDBAccess) Set(req *state.SetRequest) error { func (p *postgresDBAccess) Set(req *state.SetRequest) error {
return state.SetWithOptions(p.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func.
func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
p.logger.Debug("Setting state value in PostgreSQL") p.logger.Debug("Setting state value in PostgreSQL")
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
@ -124,11 +120,11 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
} }
if req.Key == "" { if req.Key == "" {
return fmt.Errorf("missing key in set operation") return errors.New("missing key in set operation")
} }
if v, ok := req.Value.(string); ok && v == "" { if v, ok := req.Value.(string); ok && v == "" {
return fmt.Errorf("empty string is not allowed in set operation") return errors.New("empty string is not allowed in set operation")
} }
v := req.Value v := req.Value
@ -149,7 +145,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
result, err = p.db.Exec(fmt.Sprintf( result, err = p.db.Exec(fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3);`, `INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3);`,
p.tableName), req.Key, value, isBinary) p.tableName), req.Key, value, isBinary)
} else if req.ETag == nil { } else if req.ETag == nil || *req.ETag == "" {
result, err = p.db.Exec(fmt.Sprintf( result, err = p.db.Exec(fmt.Sprintf(
`INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3) `INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`, ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`,
@ -184,7 +180,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error {
} }
if rows != 1 { if rows != 1 {
return fmt.Errorf("no item was updated") return errors.New("no item was updated")
} }
return nil return nil
@ -218,27 +214,30 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error {
func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) {
p.logger.Debug("Getting state value from PostgreSQL") p.logger.Debug("Getting state value from PostgreSQL")
if req.Key == "" { if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation") return nil, errors.New("missing key in get operation")
} }
var value string var (
var isBinary bool value []byte
var etag int isBinary bool
etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
)
err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", p.tableName), req.Key).Scan(&value, &isBinary, &etag) err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", p.tableName), req.Key).Scan(&value, &isBinary, &etag)
if err != nil { if err != nil {
// If no rows exist, return an empty response, otherwise return the error. // If no rows exist, return an empty response, otherwise return the error.
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return &state.GetResponse{}, nil return &state.GetResponse{}, nil
} }
return nil, err return nil, err
} }
if isBinary { if isBinary {
var s string var (
var data []byte s string
data []byte
)
if err = json.Unmarshal([]byte(value), &s); err != nil { if err = json.Unmarshal(value, &s); err != nil {
return nil, err return nil, err
} }
@ -248,34 +247,28 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
return &state.GetResponse{ return &state.GetResponse{
Data: data, Data: data,
ETag: ptr.Of(strconv.Itoa(etag)), ETag: ptr.Of(strconv.FormatUint(etag, 10)),
Metadata: req.Metadata, Metadata: req.Metadata,
}, nil }, nil
} }
return &state.GetResponse{ return &state.GetResponse{
Data: []byte(value), Data: value,
ETag: ptr.Of(strconv.Itoa(etag)), ETag: ptr.Of(strconv.FormatUint(etag, 10)),
Metadata: req.Metadata, Metadata: req.Metadata,
}, nil }, nil
} }
// Delete removes an item from the state store. // Delete removes an item from the state store.
func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error { func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) {
return state.DeleteWithOptions(p.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func.
func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error {
p.logger.Debug("Deleting state value from PostgreSQL") p.logger.Debug("Deleting state value from PostgreSQL")
if req.Key == "" { if req.Key == "" {
return fmt.Errorf("missing key in delete operation") return errors.New("missing key in delete operation")
} }
var result sql.Result var result sql.Result
var err error
if req.ETag == nil { if req.ETag == nil || *req.ETag == "" {
result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key) result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key)
} else { } else {
// Convert req.ETag to uint32 for postgres XID compatibility // Convert req.ETag to uint32 for postgres XID compatibility
@ -313,12 +306,10 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error {
} }
if len(req) > 0 { if len(req) > 0 {
for _, d := range req { for i := range req {
da := d // Fix for gosec G601: Implicit memory aliasing in for loop. err = p.Delete(&req[i])
err = p.Delete(&da)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }
} }
@ -446,11 +437,11 @@ func tableExists(db *sql.DB, tableName string) (bool, error) {
func getSet(req state.TransactionalStateOperation) (state.SetRequest, error) { func getSet(req state.TransactionalStateOperation) (state.SetRequest, error) {
setReq, ok := req.Request.(state.SetRequest) setReq, ok := req.Request.(state.SetRequest)
if !ok { if !ok {
return setReq, fmt.Errorf("expecting set request") return setReq, errors.New("expecting set request")
} }
if setReq.Key == "" { if setReq.Key == "" {
return setReq, fmt.Errorf("missing key in upsert operation") return setReq, errors.New("missing key in upsert operation")
} }
return setReq, nil return setReq, nil
@ -460,11 +451,11 @@ func getSet(req state.TransactionalStateOperation) (state.SetRequest, error) {
func getDelete(req state.TransactionalStateOperation) (state.DeleteRequest, error) { func getDelete(req state.TransactionalStateOperation) (state.DeleteRequest, error) {
delReq, ok := req.Request.(state.DeleteRequest) delReq, ok := req.Request.(state.DeleteRequest)
if !ok { if !ok {
return delReq, fmt.Errorf("expecting delete request") return delReq, errors.New("expecting delete request")
} }
if delReq.Key == "" { if delReq.Key == "" {
return delReq, fmt.Errorf("missing key in upsert operation") return delReq, errors.New("missing key in upsert operation")
} }
return delReq, nil return delReq, nil

View File

@ -88,9 +88,9 @@ func (q *Query) visitFilters(op string, filters []query.Filter) (string, error)
} }
} }
sep := fmt.Sprintf(" %s ", op) sep := " " + op + " "
return fmt.Sprintf("(%s)", strings.Join(arr, sep)), nil return "(" + strings.Join(arr, sep) + ")", nil
} }
func (q *Query) VisitAND(f *query.AND) (string, error) { func (q *Query) VisitAND(f *query.AND) (string, error) {
@ -102,10 +102,10 @@ func (q *Query) VisitOR(f *query.OR) (string, error) {
} }
func (q *Query) Finalize(filters string, qq *query.Query) error { func (q *Query) Finalize(filters string, qq *query.Query) error {
q.query = fmt.Sprintf("SELECT key, value, xmin as etag FROM %s", q.tableName) q.query = "SELECT key, value, xmin as etag FROM " + q.tableName
if filters != "" { if filters != "" {
q.query += fmt.Sprintf(" WHERE %s", filters) q.query += " WHERE " + filters
} }
if len(qq.Sort) > 0 { if len(qq.Sort) > 0 {
@ -117,13 +117,13 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
} }
q.query += translateFieldToFilter(sortItem.Key) q.query += translateFieldToFilter(sortItem.Key)
if sortItem.Order != "" { if sortItem.Order != "" {
q.query += fmt.Sprintf(" %s", sortItem.Order) q.query += " " + sortItem.Order
} }
} }
} }
if qq.Page.Limit > 0 { if qq.Page.Limit > 0 {
q.query += fmt.Sprintf(" LIMIT %d", qq.Page.Limit) q.query += " LIMIT " + strconv.Itoa(qq.Page.Limit)
q.limit = qq.Page.Limit q.limit = qq.Page.Limit
} }
@ -132,7 +132,7 @@ func (q *Query) Finalize(filters string, qq *query.Query) error {
if err != nil { if err != nil {
return err return err
} }
q.query += fmt.Sprintf(" OFFSET %d", skip) q.query += " OFFSET " + strconv.FormatInt(skip, 10)
q.skip = &skip q.skip = &skip
} }
@ -151,7 +151,7 @@ func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, st
var ( var (
key string key string
data []byte data []byte
etag int etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations
) )
if err = rows.Scan(&key, &data, &etag); err != nil { if err = rows.Scan(&key, &data, &etag); err != nil {
return nil, "", err return nil, "", err
@ -159,7 +159,7 @@ func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, st
result := state.QueryItem{ result := state.QueryItem{
Key: key, Key: key,
Data: data, Data: data,
ETag: ptr.Of(strconv.Itoa(etag)), ETag: ptr.Of(strconv.FormatUint(etag, 10)),
} }
ret = append(ret, result) ret = append(ret, result)
} }
@ -200,7 +200,7 @@ func translateFieldToFilter(key string) string {
filterField += ">" filterField += ">"
} }
filterField += fmt.Sprintf("'%s'", fieldPart) filterField += "'" + fieldPart + "'"
} }
return filterField return filterField
@ -209,6 +209,6 @@ func translateFieldToFilter(key string) string {
func (q *Query) whereFieldEqual(key string, value interface{}) string { func (q *Query) whereFieldEqual(key string, value interface{}) string {
position := q.addParamValueAndReturnPosition(value) position := q.addParamValueAndReturnPosition(value)
filterField := translateFieldToFilter(key) filterField := translateFieldToFilter(key)
query := fmt.Sprintf("%s=$%v", filterField, position) query := filterField + "=$" + strconv.Itoa(position)
return query return query
} }

View File

@ -196,7 +196,13 @@ func (r *StateStore) parseConnectedSlaves(res string) int {
return 0 return 0
} }
func (r *StateStore) deleteValue(req *state.DeleteRequest) error { // Delete performs a delete operation.
func (r *StateStore) Delete(req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
}
if req.ETag == nil { if req.ETag == nil {
etag := "0" etag := "0"
req.ETag = &etag req.ETag = &etag
@ -208,7 +214,7 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
} else { } else {
delQuery = delDefaultQuery delQuery = delDefaultQuery
} }
_, err := r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() _, err = r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result()
if err != nil { if err != nil {
return state.NewETagError(state.ETagMismatch, err) return state.NewETagError(state.ETagMismatch, err)
} }
@ -216,16 +222,6 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error {
return nil return nil
} }
// Delete performs a delete operation.
func (r *StateStore) Delete(req *state.DeleteRequest) error {
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
}
return state.DeleteWithOptions(r.deleteValue, req)
}
func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) { func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) {
res, err := r.client.Do(r.ctx, "GET", req.Key).Result() res, err := r.client.Do(r.ctx, "GET", req.Key).Result()
if err != nil { if err != nil {
@ -318,7 +314,8 @@ type jsonEntry struct {
Version *int `json:"version,omitempty"` Version *int `json:"version,omitempty"`
} }
func (r *StateStore) setValue(req *state.SetRequest) error { // Set saves state into redis.
func (r *StateStore) Set(req *state.SetRequest) error {
err := state.CheckRequestOptions(req.Options) err := state.CheckRequestOptions(req.Options)
if err != nil { if err != nil {
return err return err
@ -384,11 +381,6 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
return nil return nil
} }
// Set saves state into redis.
func (r *StateStore) Set(req *state.SetRequest) error {
return state.SetWithOptions(r.setValue, req)
}
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
func (r *StateStore) Multi(request *state.TransactionalStateRequest) error { func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
var setQuery, delQuery string var setQuery, delQuery string

View File

@ -66,13 +66,3 @@ func validateConsistencyOption(c string) error {
return nil return nil
} }
// SetWithOptions handles SetRequest with request options.
func SetWithOptions(method func(req *SetRequest) error, req *SetRequest) error {
return method(req)
}
// DeleteWithOptions handles DeleteRequest with options.
func DeleteWithOptions(method func(req *DeleteRequest) error, req *DeleteRequest) error {
return method(req)
}

View File

@ -19,31 +19,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
// TestSetRequestWithOptions is used to test request options.
func TestSetRequestWithOptions(t *testing.T) {
t.Run("set with default options", func(t *testing.T) {
counter := 0
SetWithOptions(func(req *SetRequest) error {
counter++
return nil
}, &SetRequest{})
assert.Equal(t, 1, counter, "should execute only once")
})
t.Run("set with no explicit options", func(t *testing.T) {
counter := 0
SetWithOptions(func(req *SetRequest) error {
counter++
return nil
}, &SetRequest{
Options: SetStateOption{},
})
assert.Equal(t, 1, counter, "should execute only once")
})
}
// TestCheckRequestOptions is used to validate request options. // TestCheckRequestOptions is used to validate request options.
func TestCheckRequestOptions(t *testing.T) { func TestCheckRequestOptions(t *testing.T) {
t.Run("set state options", func(t *testing.T) { t.Run("set state options", func(t *testing.T) {

View File

@ -187,8 +187,7 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error {
return err return err
} }
return state.DeleteWithOptions(func(req *state.DeleteRequest) error { err = s.conn.Delete(r.Path, r.Version)
err := s.conn.Delete(r.Path, r.Version)
if errors.Is(err, zk.ErrNoNode) { if errors.Is(err, zk.ErrNoNode) {
return nil return nil
} }
@ -202,7 +201,6 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error {
} }
return nil return nil
}, req)
} }
// BulkDelete performs a bulk delete operation. // BulkDelete performs a bulk delete operation.
@ -239,9 +237,7 @@ func (s *StateStore) Set(req *state.SetRequest) error {
return err return err
} }
return state.SetWithOptions(func(req *state.SetRequest) error {
_, err = s.conn.Set(r.Path, r.Data, r.Version) _, err = s.conn.Set(r.Path, r.Data, r.Version)
if errors.Is(err, zk.ErrNoNode) { if errors.Is(err, zk.ErrNoNode) {
_, err = s.conn.Create(r.Path, r.Data, 0, nil) _, err = s.conn.Create(r.Path, r.Data, 0, nil)
} }
@ -255,7 +251,6 @@ func (s *StateStore) Set(req *state.SetRequest) error {
} }
return nil return nil
}, req)
} }
// BulkSet performs a bulks save operation. // BulkSet performs a bulks save operation.