diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 893bb83c2..9e164c5b8 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -179,6 +179,8 @@ func (m *MongoDB) setInternal(ctx context.Context, req *state.SetRequest) error filter := bson.M{id: req.Key} if req.ETag != nil { filter[etag] = *req.ETag + } else if req.Options.Concurrency == state.FirstWrite { + filter[etag] = uuid.NewString() } update := bson.M{"$set": bson.M{id: req.Key, value: vStr, etag: uuid.NewString()}} diff --git a/state/redis/redis.go b/state/redis/redis.go index fcd59a75f..6d8e72c9b 100644 --- a/state/redis/redis.go +++ b/state/redis/redis.go @@ -24,7 +24,7 @@ import ( ) const ( - setQuery = "local var1 = redis.pcall(\"HGET\", KEYS[1], \"version\"); if type(var1) == \"table\" then redis.call(\"DEL\", KEYS[1]); end; if not var1 or type(var1)==\"table\" or var1 == \"\" or var1 == ARGV[1] or ARGV[1] == \"0\" then redis.call(\"HSET\", KEYS[1], \"data\", ARGV[2]) return redis.call(\"HINCRBY\", KEYS[1], \"version\", 1) else return error(\"failed to set key \" .. KEYS[1]) end" + setQuery = "local var1 = redis.pcall(\"HGET\", KEYS[1], \"version\"); if type(var1) == \"table\" then redis.call(\"DEL\", KEYS[1]); end; local var2 = redis.pcall(\"HGET\", KEYS[1], \"first-write\"); if not var1 or type(var1)==\"table\" or var1 == \"\" or var1 == ARGV[1] or (not var2 and ARGV[1] == \"0\") then redis.call(\"HSET\", KEYS[1], \"data\", ARGV[2]); if ARGV[3] == \"0\" then redis.call(\"HSET\", KEYS[1], \"first-write\", 0); end; return redis.call(\"HINCRBY\", KEYS[1], \"version\", 1) else return error(\"failed to set key \" .. KEYS[1]) end" delQuery = "local var1 = redis.pcall(\"HGET\", KEYS[1], \"version\"); if not var1 or type(var1)==\"table\" or var1 == ARGV[1] or var1 == \"\" or ARGV[1] == \"0\" then return redis.call(\"DEL\", KEYS[1]) else return error(\"failed to delete \" .. KEYS[1]) end" connectedSlavesReplicas = "connected_slaves:" infoReplicationDelimiter = "\r\n" @@ -238,7 +238,11 @@ func (r *StateStore) setValue(req *state.SetRequest) error { bt, _ := utils.Marshal(req.Value, r.json.Marshal) - _, err = r.client.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt).Result() + firstWrite := 1 + if req.Options.Concurrency == state.FirstWrite { + firstWrite = 0 + } + _, err = r.client.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt, firstWrite).Result() if err != nil { if req.ETag != nil { return state.NewETagError(state.ETagMismatch, err) @@ -337,7 +341,7 @@ func (r *StateStore) getKeyVersion(vals []interface{}) (data string, version *st } func (r *StateStore) parseETag(req *state.SetRequest) (int, error) { - if req.Options.Concurrency == state.LastWrite || req.ETag == nil || (req.ETag != nil && *req.ETag == "") { + if req.Options.Concurrency == state.LastWrite || req.ETag == nil || *req.ETag == "" { return 0, nil } ver, err := strconv.Atoi(*req.ETag) diff --git a/state/redis/redis_test.go b/state/redis/redis_test.go index 577906319..049fd797c 100644 --- a/state/redis/redis_test.go +++ b/state/redis/redis_test.go @@ -90,6 +90,32 @@ func TestParseEtag(t *testing.T) { assert.Equal(t, nil, err, "failed to parse ETag") assert.Equal(t, 0, ver, "version should be 0") }) + t.Run("Concurrency=FirstWrite", func(t *testing.T) { + ver, err := store.parseETag(&state.SetRequest{ + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + }, + }) + assert.Equal(t, nil, err, "failed to parse Concurrency") + assert.Equal(t, 0, ver, "version should be 0") + + // ETag is nil + req := &state.SetRequest{ + Options: state.SetStateOption{}, + } + ver, err = store.parseETag(req) + assert.Equal(t, nil, err, "failed to parse Concurrency") + assert.Equal(t, 0, ver, "version should be 0") + + // ETag is empty + emptyString := "" + req = &state.SetRequest{ + ETag: &emptyString, + } + ver, err = store.parseETag(req) + assert.Equal(t, nil, err, "failed to parse Concurrency") + assert.Equal(t, 0, ver, "version should be 0") + }) } func TestParseTTL(t *testing.T) { diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 60cbbb121..0adc2008e 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -474,4 +474,127 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St features := statestore.Features() assert.False(t, state.FeatureETag.IsPresent(features)) } + + if config.HasOperation("first-write") { + t.Run("first-write without etag", func(t *testing.T) { + testKey := "first-writeTest" + firstValue := []byte("testValue1") + secondValue := []byte("testValue2") + emptyString := "" + + requestSets := [][2]*state.SetRequest{ + { + { + Key: testKey, + Value: firstValue, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + Consistency: state.Strong, + }, + }, { + Key: testKey, + Value: secondValue, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + Consistency: state.Strong, + }, + }, + }, + {{ + Key: testKey, + Value: firstValue, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + Consistency: state.Strong, + }, + ETag: &emptyString, + }, { + Key: testKey, + Value: secondValue, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + Consistency: state.Strong, + }, + ETag: &emptyString, + }}, + } + + for _, requestSet := range requestSets { + // Delete any potential object, it's important to start from a clean slate. + err := statestore.Delete(&state.DeleteRequest{ + Key: testKey, + }) + assert.Nil(t, err) + + err = statestore.Set(requestSet[0]) + assert.Nil(t, err) + + // Validate the set. + res, err := statestore.Get(&state.GetRequest{ + Key: testKey, + }) + assert.Nil(t, err) + assert.Equal(t, firstValue, res.Data) + + // Second write expect fail + err = statestore.Set(requestSet[1]) + assert.NotNil(t, err) + } + }) + + t.Run("first-write with etag", func(t *testing.T) { + testKey := "first-writeTest" + firstValue := []byte("testValue1") + secondValue := []byte("testValue2") + + request := &state.SetRequest{ + Key: testKey, + Value: firstValue, + } + + // Delete any potential object, it's important to start from a clean slate. + err := statestore.Delete(&state.DeleteRequest{ + Key: testKey, + }) + assert.Nil(t, err) + + err = statestore.Set(request) + assert.Nil(t, err) + + // Validate the set. + res, err := statestore.Get(&state.GetRequest{ + Key: testKey, + }) + assert.Nil(t, err) + assert.Equal(t, firstValue, res.Data) + + etag := res.ETag + + request = &state.SetRequest{ + Key: testKey, + Value: secondValue, + ETag: etag, + Options: state.SetStateOption{ + Concurrency: state.FirstWrite, + Consistency: state.Strong, + }, + } + err = statestore.Set(request) + assert.Nil(t, err) + + // Validate the set. + res, err = statestore.Get(&state.GetRequest{ + Key: testKey, + }) + assert.Nil(t, err) + assert.NotEqual(t, etag, res.ETag) + assert.Equal(t, secondValue, res.Data) + + request.ETag = etag + + // Second write expect fail + err = statestore.Set(request) + assert.NotNil(t, err) + }) + } }