diff --git a/bindings/rabbitmq/rabbitmq_integration_test.go b/bindings/rabbitmq/rabbitmq_integration_test.go index e2714c9d3..e3939bbce 100644 --- a/bindings/rabbitmq/rabbitmq_integration_test.go +++ b/bindings/rabbitmq/rabbitmq_integration_test.go @@ -130,12 +130,14 @@ func TestPublishingWithTTL(t *testing.T) { const maxGetDuration = ttlInSeconds * time.Second metadata := bindings.Metadata{ - Name: "testQueue", - Properties: map[string]string{ - "queueName": queueName, - "host": rabbitmqHost, - "deleteWhenUnused": strconv.FormatBool(exclusive), - "durable": strconv.FormatBool(durable), + Base: contribMetadata.Base{ + Name: "testQueue", + Properties: map[string]string{ + "queueName": queueName, + "host": rabbitmqHost, + "deleteWhenUnused": strconv.FormatBool(exclusive), + "durable": strconv.FormatBool(durable), + }, }, } @@ -162,7 +164,7 @@ func TestPublishingWithTTL(t *testing.T) { }, } - _, err = rabbitMQBinding1.Invoke(context.Backgound(), &writeRequest) + _, err = rabbitMQBinding1.Invoke(context.Background(), &writeRequest) assert.Nil(t, err) time.Sleep(time.Second + (ttlInSeconds * time.Second)) @@ -183,7 +185,7 @@ func TestPublishingWithTTL(t *testing.T) { contribMetadata.TTLMetadataKey: strconv.Itoa(ttlInSeconds * 1000), }, } - _, err = rabbitMQBinding2.Invoke(context.Backgound(), &writeRequest) + _, err = rabbitMQBinding2.Invoke(context.Background(), &writeRequest) assert.Nil(t, err) msg, ok, err := getMessageWithRetries(ch, queueName, maxGetDuration) @@ -204,14 +206,16 @@ func TestExclusiveQueue(t *testing.T) { const maxGetDuration = ttlInSeconds * time.Second metadata := bindings.Metadata{ - Name: "testQueue", - Properties: map[string]string{ - "queueName": queueName, - "host": rabbitmqHost, - "deleteWhenUnused": strconv.FormatBool(exclusive), - "durable": strconv.FormatBool(durable), - "exclusive": strconv.FormatBool(exclusive), - contribMetadata.TTLMetadataKey: strconv.FormatInt(ttlInSeconds, 10), + Base: contribMetadata.Base{ + Name: "testQueue", + Properties: map[string]string{ + "queueName": queueName, + "host": rabbitmqHost, + "deleteWhenUnused": strconv.FormatBool(exclusive), + "durable": strconv.FormatBool(durable), + "exclusive": strconv.FormatBool(exclusive), + contribMetadata.TTLMetadataKey: strconv.FormatInt(ttlInSeconds, 10), + }, }, } @@ -257,13 +261,15 @@ func TestPublishWithPriority(t *testing.T) { const maxPriority = 10 metadata := bindings.Metadata{ - Name: "testQueue", - Properties: map[string]string{ - "queueName": queueName, - "host": rabbitmqHost, - "deleteWhenUnused": strconv.FormatBool(exclusive), - "durable": strconv.FormatBool(durable), - "maxPriority": strconv.FormatInt(maxPriority, 10), + Base: contribMetadata.Base{ + Name: "testQueue", + Properties: map[string]string{ + "queueName": queueName, + "host": rabbitmqHost, + "deleteWhenUnused": strconv.FormatBool(exclusive), + "durable": strconv.FormatBool(durable), + "maxPriority": strconv.FormatInt(maxPriority, 10), + }, }, } @@ -283,7 +289,7 @@ func TestPublishWithPriority(t *testing.T) { defer ch.Close() const middlePriorityMsgContent = "middle" - _, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{ + _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{ Metadata: map[string]string{ contribMetadata.PriorityMetadataKey: "5", }, @@ -292,7 +298,7 @@ func TestPublishWithPriority(t *testing.T) { assert.Nil(t, err) const lowPriorityMsgContent = "low" - _, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{ + _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{ Metadata: map[string]string{ contribMetadata.PriorityMetadataKey: "1", }, @@ -301,7 +307,7 @@ func TestPublishWithPriority(t *testing.T) { assert.Nil(t, err) const highPriorityMsgContent = "high" - _, err = r.Invoke(context.Backgound(), &bindings.InvokeRequest{ + _, err = r.Invoke(context.Background(), &bindings.InvokeRequest{ Metadata: map[string]string{ contribMetadata.PriorityMetadataKey: "10", }, diff --git a/pubsub/azure/eventhubs/eventhubs_integration_test.go b/pubsub/azure/eventhubs/eventhubs_integration_test.go index dddf66f5a..077e91ac9 100644 --- a/pubsub/azure/eventhubs/eventhubs_integration_test.go +++ b/pubsub/azure/eventhubs/eventhubs_integration_test.go @@ -52,11 +52,11 @@ func createIotHubPubsubMetadata() pubsub.Metadata { metadata := pubsub.Metadata{ Base: metadata.Base{ Properties: map[string]string{ - connectionString: os.Getenv(iotHubConnectionStringEnvKey), - consumerID: os.Getenv(iotHubConsumerGroupEnvKey), - storageAccountName: os.Getenv(storageAccountNameEnvKey), - storageAccountKey: os.Getenv(storageAccountKeyEnvKey), - storageContainerName: testStorageContainerName, + "connectionString": os.Getenv(iotHubConnectionStringEnvKey), + "consumerID": os.Getenv(iotHubConsumerGroupEnvKey), + "storageAccountName": os.Getenv(storageAccountNameEnvKey), + "storageAccountKey": os.Getenv(storageAccountKeyEnvKey), + "storageContainerName": testStorageContainerName, }, }, } diff --git a/state/cockroachdb/cockroachdb_access.go b/state/cockroachdb/cockroachdb_access.go index 84c88702c..6fdb454fc 100644 --- a/state/cockroachdb/cockroachdb_access.go +++ b/state/cockroachdb/cockroachdb_access.go @@ -114,11 +114,6 @@ func (p *cockroachDBAccess) Init(metadata state.Metadata) error { // Set makes an insert or update to the database. func (p *cockroachDBAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(p.setValue, req) -} - -// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (p *cockroachDBAccess) setValue(req *state.SetRequest) error { p.logger.Debug("Setting state value in CockroachDB") value, isBinary, err := validateAndReturnValue(req) @@ -240,11 +235,6 @@ func (p *cockroachDBAccess) Get(req *state.GetRequest) (*state.GetResponse, erro // Delete removes an item from the state store. func (p *cockroachDBAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(p.deleteValue, req) -} - -// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (p *cockroachDBAccess) deleteValue(req *state.DeleteRequest) error { p.logger.Debug("Deleting state value from CockroachDB") if req.Key == "" { return fmt.Errorf("missing key in delete operation") diff --git a/state/gcp/firestore/firestore.go b/state/gcp/firestore/firestore.go index 5a36173bd..17c3ef9bf 100644 --- a/state/gcp/firestore/firestore.go +++ b/state/gcp/firestore/firestore.go @@ -113,7 +113,8 @@ func (f *Firestore) Get(req *state.GetRequest) (*state.GetResponse, error) { }, nil } -func (f *Firestore) setValue(req *state.SetRequest) error { +// Set saves state into Firestore. +func (f *Firestore) Set(req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -142,12 +143,8 @@ func (f *Firestore) setValue(req *state.SetRequest) error { return nil } -// Set saves state into Firestore with retry. -func (f *Firestore) Set(req *state.SetRequest) error { - return state.SetWithOptions(f.setValue, req) -} - -func (f *Firestore) deleteValue(req *state.DeleteRequest) error { +// Delete performs a delete operation. +func (f *Firestore) Delete(req *state.DeleteRequest) error { ctx := context.Background() key := datastore.NameKey(f.entityKind, req.Key, nil) @@ -159,11 +156,6 @@ func (f *Firestore) deleteValue(req *state.DeleteRequest) error { return nil } -// Delete performs a delete operation. -func (f *Firestore) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(f.deleteValue, req) -} - func getFirestoreMetadata(meta state.Metadata) (*firestoreMetadata, error) { m := firestoreMetadata{ EntityKind: defaultEntityKind, diff --git a/state/memcached/memcached.go b/state/memcached/memcached.go index 486b20aa5..e04811a40 100644 --- a/state/memcached/memcached.go +++ b/state/memcached/memcached.go @@ -146,7 +146,7 @@ func (m *Memcached) parseTTL(req *state.SetRequest) (*int32, error) { return nil, nil } -func (m *Memcached) setValue(req *state.SetRequest) error { +func (m *Memcached) Set(req *state.SetRequest) error { var bt []byte ttl, err := m.parseTTL(req) if err != nil { @@ -194,10 +194,6 @@ func (m *Memcached) Get(req *state.GetRequest) (*state.GetResponse, error) { }, nil } -func (m *Memcached) Set(req *state.SetRequest) error { - return state.SetWithOptions(m.setValue, req) -} - func (m *Memcached) GetComponentMetadata() map[string]string { metadataStruct := memcachedMetadata{} metadataInfo := map[string]string{} diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index 654f42ac0..628812d19 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -30,7 +30,6 @@ import ( "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" - "github.com/dapr/kit/ptr" ) // Optimistic Concurrency is implemented using a string column that stores a UUID. @@ -337,8 +336,6 @@ func (m *MySQL) Delete(req *state.DeleteRequest) error { return m.deleteValue(m.db, req) } -// deleteValue is an internal implementation of delete to enable passing the -// logic to state.DeleteWithRetries as a func. func (m *MySQL) deleteValue(querier querier, req *state.DeleteRequest) error { m.logger.Debug("Deleting state value from MySql") @@ -458,14 +455,14 @@ func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) { return &state.GetResponse{ Data: data, - ETag: ptr.Of(eTag), + ETag: &eTag, Metadata: req.Metadata, }, nil } return &state.GetResponse{ Data: value, - ETag: ptr.Of(eTag), + ETag: &eTag, Metadata: req.Metadata, }, nil } @@ -476,8 +473,6 @@ func (m *MySQL) Set(req *state.SetRequest) error { return m.setValue(m.db, req) } -// setValue is an internal implementation of set to enable passing the logic -// to state.SetWithRetries as a func. func (m *MySQL) setValue(querier querier, req *state.SetRequest) error { m.logger.Debug("Setting state value in MySql") diff --git a/state/oracledatabase/oracledatabaseaccess.go b/state/oracledatabase/oracledatabaseaccess.go index 5446959b7..2e395cb46 100644 --- a/state/oracledatabase/oracledatabaseaccess.go +++ b/state/oracledatabase/oracledatabaseaccess.go @@ -114,11 +114,6 @@ func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error { return nil } -// Set makes an insert or update to the database. -func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(o.setValue, req) -} - func parseTTL(requestMetadata map[string]string) (*int, error) { if val, found := requestMetadata[metadataTTLKey]; found && val != "" { parsedVal, err := strconv.ParseInt(val, 10, 0) @@ -133,8 +128,8 @@ func parseTTL(requestMetadata map[string]string) (*int, error) { return nil, nil } -// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { +// Set makes an insert or update to the database. +func (o *oracleDatabaseAccess) Set(req *state.SetRequest) error { o.logger.Debug("Setting state value in OracleDatabase") err := state.CheckRequestOptions(req.Options) if err != nil { @@ -204,6 +199,7 @@ func (o *oracleDatabaseAccess) setValue(req *state.SetRequest) error { result, err = tx.Exec(mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds) } else { // when first write policy is indicated, an existing record has to be updated - one that has the etag provided. + // TODO: Needs to update ttl_in_seconds updateStatement := fmt.Sprintf( `UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag WHERE key = :key AND etag = :etag`, @@ -273,11 +269,6 @@ func (o *oracleDatabaseAccess) Get(req *state.GetRequest) (*state.GetResponse, e // Delete removes an item from the state store. func (o *oracleDatabaseAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(o.deleteValue, req) -} - -// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (o *oracleDatabaseAccess) deleteValue(req *state.DeleteRequest) error { o.logger.Debug("Deleting state value from OracleDatabase") if req.Key == "" { return fmt.Errorf("missing key in delete operation") @@ -354,7 +345,7 @@ func (o *oracleDatabaseAccess) ExecuteMulti(sets []state.SetRequest, deletes []s return err } -// Close implements io.Close. +// Close implements io.Closer. func (o *oracleDatabaseAccess) Close() error { if o.db != nil { return o.db.Close() @@ -391,10 +382,3 @@ func tableExists(db *sql.DB, tableName string) (bool, error) { exists := tblCount > 0 return exists, err } - -// func handleError(msg string, err error) { -// if err != nil { -// fmt.Println(msg, err) - -// } -// } diff --git a/state/postgresql/postgresdbaccess.go b/state/postgresql/postgresdbaccess.go index 1e0267c62..53846066a 100644 --- a/state/postgresql/postgresdbaccess.go +++ b/state/postgresql/postgresdbaccess.go @@ -17,6 +17,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "errors" "fmt" "strconv" "time" @@ -77,7 +78,7 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error { if m.ConnectionString == "" { p.logger.Error("Missing postgreSQL connection string") - return fmt.Errorf(errMissingConnectionString) + return errors.New(errMissingConnectionString) } p.connectionString = m.ConnectionString @@ -111,11 +112,6 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error { // Set makes an insert or update to the database. func (p *postgresDBAccess) Set(req *state.SetRequest) error { - return state.SetWithOptions(p.setValue, req) -} - -// setValue is an internal implementation of set to enable passing the logic to state.SetWithRetries as a func. -func (p *postgresDBAccess) setValue(req *state.SetRequest) error { p.logger.Debug("Setting state value in PostgreSQL") err := state.CheckRequestOptions(req.Options) @@ -124,11 +120,11 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error { } if req.Key == "" { - return fmt.Errorf("missing key in set operation") + return errors.New("missing key in set operation") } if v, ok := req.Value.(string); ok && v == "" { - return fmt.Errorf("empty string is not allowed in set operation") + return errors.New("empty string is not allowed in set operation") } v := req.Value @@ -149,7 +145,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error { result, err = p.db.Exec(fmt.Sprintf( `INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3);`, p.tableName), req.Key, value, isBinary) - } else if req.ETag == nil { + } else if req.ETag == nil || *req.ETag == "" { result, err = p.db.Exec(fmt.Sprintf( `INSERT INTO %s (key, value, isbinary) VALUES ($1, $2, $3) ON CONFLICT (key) DO UPDATE SET value = $2, isbinary = $3, updatedate = NOW();`, @@ -184,7 +180,7 @@ func (p *postgresDBAccess) setValue(req *state.SetRequest) error { } if rows != 1 { - return fmt.Errorf("no item was updated") + return errors.New("no item was updated") } return nil @@ -218,27 +214,30 @@ func (p *postgresDBAccess) BulkSet(req []state.SetRequest) error { func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error) { p.logger.Debug("Getting state value from PostgreSQL") if req.Key == "" { - return nil, fmt.Errorf("missing key in get operation") + return nil, errors.New("missing key in get operation") } - var value string - var isBinary bool - var etag int + var ( + value []byte + isBinary bool + etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations + ) err := p.db.QueryRow(fmt.Sprintf("SELECT value, isbinary, xmin as etag FROM %s WHERE key = $1", p.tableName), req.Key).Scan(&value, &isBinary, &etag) if err != nil { // If no rows exist, return an empty response, otherwise return the error. if err == sql.ErrNoRows { return &state.GetResponse{}, nil } - return nil, err } if isBinary { - var s string - var data []byte + var ( + s string + data []byte + ) - if err = json.Unmarshal([]byte(value), &s); err != nil { + if err = json.Unmarshal(value, &s); err != nil { return nil, err } @@ -248,34 +247,28 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error return &state.GetResponse{ Data: data, - ETag: ptr.Of(strconv.Itoa(etag)), + ETag: ptr.Of(strconv.FormatUint(etag, 10)), Metadata: req.Metadata, }, nil } return &state.GetResponse{ - Data: []byte(value), - ETag: ptr.Of(strconv.Itoa(etag)), + Data: value, + ETag: ptr.Of(strconv.FormatUint(etag, 10)), Metadata: req.Metadata, }, nil } // Delete removes an item from the state store. -func (p *postgresDBAccess) Delete(req *state.DeleteRequest) error { - return state.DeleteWithOptions(p.deleteValue, req) -} - -// deleteValue is an internal implementation of delete to enable passing the logic to state.DeleteWithRetries as a func. -func (p *postgresDBAccess) deleteValue(req *state.DeleteRequest) error { +func (p *postgresDBAccess) Delete(req *state.DeleteRequest) (err error) { p.logger.Debug("Deleting state value from PostgreSQL") if req.Key == "" { - return fmt.Errorf("missing key in delete operation") + return errors.New("missing key in delete operation") } var result sql.Result - var err error - if req.ETag == nil { + if req.ETag == nil || *req.ETag == "" { result, err = p.db.Exec("DELETE FROM state WHERE key = $1", req.Key) } else { // Convert req.ETag to uint32 for postgres XID compatibility @@ -313,12 +306,10 @@ func (p *postgresDBAccess) BulkDelete(req []state.DeleteRequest) error { } if len(req) > 0 { - for _, d := range req { - da := d // Fix for gosec G601: Implicit memory aliasing in for loop. - err = p.Delete(&da) + for i := range req { + err = p.Delete(&req[i]) if err != nil { tx.Rollback() - return err } } @@ -446,11 +437,11 @@ func tableExists(db *sql.DB, tableName string) (bool, error) { func getSet(req state.TransactionalStateOperation) (state.SetRequest, error) { setReq, ok := req.Request.(state.SetRequest) if !ok { - return setReq, fmt.Errorf("expecting set request") + return setReq, errors.New("expecting set request") } if setReq.Key == "" { - return setReq, fmt.Errorf("missing key in upsert operation") + return setReq, errors.New("missing key in upsert operation") } return setReq, nil @@ -460,11 +451,11 @@ func getSet(req state.TransactionalStateOperation) (state.SetRequest, error) { func getDelete(req state.TransactionalStateOperation) (state.DeleteRequest, error) { delReq, ok := req.Request.(state.DeleteRequest) if !ok { - return delReq, fmt.Errorf("expecting delete request") + return delReq, errors.New("expecting delete request") } if delReq.Key == "" { - return delReq, fmt.Errorf("missing key in upsert operation") + return delReq, errors.New("missing key in upsert operation") } return delReq, nil diff --git a/state/postgresql/postgresql_query.go b/state/postgresql/postgresql_query.go index ca113ff09..30e287c96 100644 --- a/state/postgresql/postgresql_query.go +++ b/state/postgresql/postgresql_query.go @@ -88,9 +88,9 @@ func (q *Query) visitFilters(op string, filters []query.Filter) (string, error) } } - sep := fmt.Sprintf(" %s ", op) + sep := " " + op + " " - return fmt.Sprintf("(%s)", strings.Join(arr, sep)), nil + return "(" + strings.Join(arr, sep) + ")", nil } func (q *Query) VisitAND(f *query.AND) (string, error) { @@ -102,10 +102,10 @@ func (q *Query) VisitOR(f *query.OR) (string, error) { } func (q *Query) Finalize(filters string, qq *query.Query) error { - q.query = fmt.Sprintf("SELECT key, value, xmin as etag FROM %s", q.tableName) + q.query = "SELECT key, value, xmin as etag FROM " + q.tableName if filters != "" { - q.query += fmt.Sprintf(" WHERE %s", filters) + q.query += " WHERE " + filters } if len(qq.Sort) > 0 { @@ -117,13 +117,13 @@ func (q *Query) Finalize(filters string, qq *query.Query) error { } q.query += translateFieldToFilter(sortItem.Key) if sortItem.Order != "" { - q.query += fmt.Sprintf(" %s", sortItem.Order) + q.query += " " + sortItem.Order } } } if qq.Page.Limit > 0 { - q.query += fmt.Sprintf(" LIMIT %d", qq.Page.Limit) + q.query += " LIMIT " + strconv.Itoa(qq.Page.Limit) q.limit = qq.Page.Limit } @@ -132,7 +132,7 @@ func (q *Query) Finalize(filters string, qq *query.Query) error { if err != nil { return err } - q.query += fmt.Sprintf(" OFFSET %d", skip) + q.query += " OFFSET " + strconv.FormatInt(skip, 10) q.skip = &skip } @@ -151,7 +151,7 @@ func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, st var ( key string data []byte - etag int + etag uint64 // Postgres uses uint32, but FormatUint requires uint64, so using uint64 directly to avoid re-allocations ) if err = rows.Scan(&key, &data, &etag); err != nil { return nil, "", err @@ -159,7 +159,7 @@ func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, st result := state.QueryItem{ Key: key, Data: data, - ETag: ptr.Of(strconv.Itoa(etag)), + ETag: ptr.Of(strconv.FormatUint(etag, 10)), } ret = append(ret, result) } @@ -200,7 +200,7 @@ func translateFieldToFilter(key string) string { filterField += ">" } - filterField += fmt.Sprintf("'%s'", fieldPart) + filterField += "'" + fieldPart + "'" } return filterField @@ -209,6 +209,6 @@ func translateFieldToFilter(key string) string { func (q *Query) whereFieldEqual(key string, value interface{}) string { position := q.addParamValueAndReturnPosition(value) filterField := translateFieldToFilter(key) - query := fmt.Sprintf("%s=$%v", filterField, position) + query := filterField + "=$" + strconv.Itoa(position) return query } diff --git a/state/redis/redis.go b/state/redis/redis.go index ae676737c..a3120a1fd 100644 --- a/state/redis/redis.go +++ b/state/redis/redis.go @@ -196,7 +196,13 @@ func (r *StateStore) parseConnectedSlaves(res string) int { return 0 } -func (r *StateStore) deleteValue(req *state.DeleteRequest) error { +// Delete performs a delete operation. +func (r *StateStore) Delete(req *state.DeleteRequest) error { + err := state.CheckRequestOptions(req.Options) + if err != nil { + return err + } + if req.ETag == nil { etag := "0" req.ETag = &etag @@ -208,7 +214,7 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error { } else { delQuery = delDefaultQuery } - _, err := r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() + _, err = r.client.Do(r.ctx, "EVAL", delQuery, 1, req.Key, *req.ETag).Result() if err != nil { return state.NewETagError(state.ETagMismatch, err) } @@ -216,16 +222,6 @@ func (r *StateStore) deleteValue(req *state.DeleteRequest) error { return nil } -// Delete performs a delete operation. -func (r *StateStore) Delete(req *state.DeleteRequest) error { - err := state.CheckRequestOptions(req.Options) - if err != nil { - return err - } - - return state.DeleteWithOptions(r.deleteValue, req) -} - func (r *StateStore) directGet(req *state.GetRequest) (*state.GetResponse, error) { res, err := r.client.Do(r.ctx, "GET", req.Key).Result() if err != nil { @@ -318,7 +314,8 @@ type jsonEntry struct { Version *int `json:"version,omitempty"` } -func (r *StateStore) setValue(req *state.SetRequest) error { +// Set saves state into redis. +func (r *StateStore) Set(req *state.SetRequest) error { err := state.CheckRequestOptions(req.Options) if err != nil { return err @@ -384,11 +381,6 @@ func (r *StateStore) setValue(req *state.SetRequest) error { return nil } -// Set saves state into redis. -func (r *StateStore) Set(req *state.SetRequest) error { - return state.SetWithOptions(r.setValue, req) -} - // Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail. func (r *StateStore) Multi(request *state.TransactionalStateRequest) error { var setQuery, delQuery string diff --git a/state/request_options.go b/state/request_options.go index 23ed85cb3..140783244 100644 --- a/state/request_options.go +++ b/state/request_options.go @@ -66,13 +66,3 @@ func validateConsistencyOption(c string) error { return nil } - -// SetWithOptions handles SetRequest with request options. -func SetWithOptions(method func(req *SetRequest) error, req *SetRequest) error { - return method(req) -} - -// DeleteWithOptions handles DeleteRequest with options. -func DeleteWithOptions(method func(req *DeleteRequest) error, req *DeleteRequest) error { - return method(req) -} diff --git a/state/request_options_test.go b/state/request_options_test.go index 2bf8e72f3..2853fea6b 100644 --- a/state/request_options_test.go +++ b/state/request_options_test.go @@ -19,31 +19,6 @@ import ( "github.com/stretchr/testify/assert" ) -// TestSetRequestWithOptions is used to test request options. -func TestSetRequestWithOptions(t *testing.T) { - t.Run("set with default options", func(t *testing.T) { - counter := 0 - SetWithOptions(func(req *SetRequest) error { - counter++ - - return nil - }, &SetRequest{}) - assert.Equal(t, 1, counter, "should execute only once") - }) - - t.Run("set with no explicit options", func(t *testing.T) { - counter := 0 - SetWithOptions(func(req *SetRequest) error { - counter++ - - return nil - }, &SetRequest{ - Options: SetStateOption{}, - }) - assert.Equal(t, 1, counter, "should execute only once") - }) -} - // TestCheckRequestOptions is used to validate request options. func TestCheckRequestOptions(t *testing.T) { t.Run("set state options", func(t *testing.T) { diff --git a/state/zookeeper/zk.go b/state/zookeeper/zk.go index 8b2935148..280dbf09b 100644 --- a/state/zookeeper/zk.go +++ b/state/zookeeper/zk.go @@ -187,22 +187,20 @@ func (s *StateStore) Delete(req *state.DeleteRequest) error { return err } - return state.DeleteWithOptions(func(req *state.DeleteRequest) error { - err := s.conn.Delete(r.Path, r.Version) - if errors.Is(err, zk.ErrNoNode) { - return nil - } - - if err != nil { - if req.ETag != nil { - return state.NewETagError(state.ETagMismatch, err) - } - - return err - } - + err = s.conn.Delete(r.Path, r.Version) + if errors.Is(err, zk.ErrNoNode) { return nil - }, req) + } + + if err != nil { + if req.ETag != nil { + return state.NewETagError(state.ETagMismatch, err) + } + + return err + } + + return nil } // BulkDelete performs a bulk delete operation. @@ -239,23 +237,20 @@ func (s *StateStore) Set(req *state.SetRequest) error { return err } - return state.SetWithOptions(func(req *state.SetRequest) error { - _, err = s.conn.Set(r.Path, r.Data, r.Version) + _, err = s.conn.Set(r.Path, r.Data, r.Version) + if errors.Is(err, zk.ErrNoNode) { + _, err = s.conn.Create(r.Path, r.Data, 0, nil) + } - if errors.Is(err, zk.ErrNoNode) { - _, err = s.conn.Create(r.Path, r.Data, 0, nil) + if err != nil { + if req.ETag != nil { + return state.NewETagError(state.ETagMismatch, err) } - if err != nil { - if req.ETag != nil { - return state.NewETagError(state.ETagMismatch, err) - } + return err + } - return err - } - - return nil - }, req) + return nil } // BulkSet performs a bulks save operation.