diff --git a/state/mysql/mysql.go b/state/mysql/mysql.go index b78f2be83..4c94a41ba 100644 --- a/state/mysql/mysql.go +++ b/state/mysql/mysql.go @@ -531,10 +531,10 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge ctx, cancel := context.WithTimeout(parentCtx, m.timeout) defer cancel() // Concatenation is required for table name because sql.DB does not substitute parameters for table names - query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ? + query := `SELECT id, value, eTag, isbinary, IFNULL(expiredate, "") FROM ` + m.tableName + ` WHERE id = ? AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)` row := m.db.QueryRowContext(ctx, query, req.Key) - _, value, etag, err := readRow(row) + _, value, etag, expireTime, err := readRow(row) if err != nil { // If no rows exist, return an empty response, otherwise return an error. if errors.Is(err, sql.ErrNoRows) { @@ -542,10 +542,18 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge } return nil, err } + + var metadata map[string]string + if expireTime != nil { + metadata = map[string]string{ + state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), + } + } + return &state.GetResponse{ Data: value, ETag: etag, - Metadata: req.Metadata, + Metadata: metadata, }, nil } @@ -708,7 +716,7 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta } // Concatenation is required for table name because sql.DB does not substitute parameters for table names - stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` + stmt := `SELECT id, value, eTag, isbinary, IFNULL(expiredate, "") FROM ` + m.tableName + ` WHERE id IN (` + inClause + `) AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)` @@ -720,6 +728,7 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta } var n int + var expireTime *time.Time res := make([]state.BulkGetResponse, len(req)) foundKeys := make(map[string]struct{}, len(req)) for ; rows.Next(); n++ { @@ -729,10 +738,15 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta } r := state.BulkGetResponse{} - r.Key, r.Data, r.ETag, err = readRow(rows) + r.Key, r.Data, r.ETag, expireTime, err = readRow(rows) if err != nil { r.Error = err.Error() } + if expireTime != nil { + r.Metadata = map[string]string{ + state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), + } + } res[n] = r foundKeys[r.Key] = struct{}{} } @@ -759,14 +773,24 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta return res[:n], nil } -func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, err error) { +func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, expireTime *time.Time, err error) { var ( etag string isBinary bool + expire string ) - err = row.Scan(&key, &value, &etag, &isBinary) + err = row.Scan(&key, &value, &etag, &isBinary, &expire) if err != nil { - return key, nil, nil, err + return key, nil, nil, nil, err + } + + if len(expire) > 0 { + var expireT time.Time + expireT, err = time.Parse(time.DateTime, expire) + if err != nil { + return key, nil, nil, nil, fmt.Errorf("failed to parse expiration time: %w", err) + } + expireTime = &expireT } if isBinary { @@ -777,17 +801,17 @@ func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte err = json.Unmarshal(value, &s) if err != nil { - return key, nil, nil, fmt.Errorf("failed to unmarshal JSON binary data: %w", err) + return key, nil, nil, nil, fmt.Errorf("failed to unmarshal JSON binary data: %w", err) } data, err = base64.StdEncoding.DecodeString(s) if err != nil { - return key, nil, nil, fmt.Errorf("failed to decode binary data: %w", err) + return key, nil, nil, nil, fmt.Errorf("failed to decode binary data: %w", err) } - return key, data, &etag, nil + return key, data, &etag, expireTime, nil } - return key, value, &etag, nil + return key, value, &etag, expireTime, nil } // Multi handles multiple transactions. diff --git a/state/mysql/mysql_integration_test.go b/state/mysql/mysql_integration_test.go index 34ec04327..31beb61fd 100644 --- a/state/mysql/mysql_integration_test.go +++ b/state/mysql/mysql_integration_test.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "os" + "sort" "strings" "testing" "time" @@ -29,6 +30,7 @@ import ( "github.com/go-sql-driver/mysql" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" @@ -253,6 +255,12 @@ func TestMySQLIntegration(t *testing.T) { testBulkSetAndBulkDelete(t, mys) }) + t.Run("Get and BulkGet with ttl", func(t *testing.T) { + t.Parallel() + testGetExpireTime(t, mys) + testGetBulkExpireTime(t, mys) + }) + t.Run("Update and delete with eTag succeeds", func(t *testing.T) { t.Parallel() @@ -572,6 +580,65 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) { assert.False(t, storeItemExists(t, setReq[1].Key)) } +func testGetExpireTime(t *testing.T, mys *MySQL) { + key1 := randomKey() + assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{ + Key: key1, + Value: "123", + Metadata: map[string]string{ + "ttlInSeconds": "1000", + }, + })) + + resp, err := mys.Get(context.Background(), &state.GetRequest{Key: key1}) + assert.NoError(t, err) + assert.Equal(t, `"123"`, string(resp.Data)) + require.Len(t, resp.Metadata, 1) + expireTime, err := time.Parse(time.RFC3339, resp.Metadata["ttlExpireTime"]) + require.NoError(t, err) + assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 5) +} + +func testGetBulkExpireTime(t *testing.T, mys *MySQL) { + key1 := randomKey() + key2 := randomKey() + + assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{ + Key: key1, + Value: "123", + Metadata: map[string]string{ + "ttlInSeconds": "1000", + }, + })) + assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{ + Key: key2, + Value: "456", + Metadata: map[string]string{ + "ttlInSeconds": "2001", + }, + })) + + resp, err := mys.BulkGet(context.Background(), []state.GetRequest{ + {Key: key1}, {Key: key2}, + }, state.BulkGetOpts{}) + require.NoError(t, err) + assert.Len(t, resp, 2) + sort.Slice(resp, func(i, j int) bool { + return string(resp[i].Data) < string(resp[j].Data) + }) + + assert.Equal(t, `"123"`, string(resp[0].Data)) + assert.Equal(t, `"456"`, string(resp[1].Data)) + require.Len(t, resp[0].Metadata, 1) + require.Len(t, resp[1].Metadata, 1) + expireTime, err := time.Parse(time.RFC3339, resp[0].Metadata["ttlExpireTime"]) + require.NoError(t, err) + assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 5) + expireTime, err = time.Parse(time.RFC3339, resp[1].Metadata["ttlExpireTime"]) + require.NoError(t, err) + assert.InDelta(t, time.Now().Add(time.Second*2001).Unix(), expireTime.Unix(), 5) +} + func dropTable(t *testing.T, db *sql.DB, tableName string) { _, err := db.Exec(fmt.Sprintf( `DROP TABLE %s;`, diff --git a/state/mysql/mysql_test.go b/state/mysql/mysql_test.go index b88328dbc..af5cbd31e 100644 --- a/state/mysql/mysql_test.go +++ b/state/mysql/mysql_test.go @@ -452,8 +452,8 @@ func TestGetSucceeds(t *testing.T) { defer m.mySQL.Close() t.Run("has json type", func(t *testing.T) { - rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false) - m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) + rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary", "expiredate"}).AddRow("UnitTest", "{}", "946af56e", false, "") + m.mock1.ExpectQuery(`SELECT id, value, eTag, isbinary, IFNULL\(expiredate, ""\) FROM state WHERE id = ?`).WillReturnRows(rows) request := &state.GetRequest{ Key: "UnitTest", @@ -466,12 +466,15 @@ func TestGetSucceeds(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, response) assert.Equal(t, "{}", string(response.Data)) + assert.NotContains(t, response.Metadata, state.GetRespMetaKeyTTLExpireTime) }) - t.Run("has binary type", func(t *testing.T) { + t.Run("has binary type and expiredate", func(t *testing.T) { + now := time.UnixMilli(20001).UTC() + value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal) - rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true) - m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows) + rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary", "expiredate"}).AddRow("UnitTest", value, "946af56e", true, now.Format(time.DateTime)) + m.mock1.ExpectQuery(`SELECT id, value, eTag, isbinary, IFNULL\(expiredate, ""\) FROM state WHERE id = ?`).WillReturnRows(rows) request := &state.GetRequest{ Key: "UnitTest", @@ -484,6 +487,8 @@ func TestGetSucceeds(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, response) assert.Equal(t, "abcdefg", string(response.Data)) + assert.Contains(t, response.Metadata, state.GetRespMetaKeyTTLExpireTime) + assert.Equal(t, "1970-01-01T00:00:20Z", response.Metadata[state.GetRespMetaKeyTTLExpireTime]) }) } diff --git a/tests/certification/state/mysql/mysql_test.go b/tests/certification/state/mysql/mysql_test.go index d34215280..184088a2d 100644 --- a/tests/certification/state/mysql/mysql_test.go +++ b/tests/certification/state/mysql/mysql_test.go @@ -207,6 +207,11 @@ func TestMySQL(t *testing.T) { resp3, err := client.GetState(ctx, stateStoreName, "reqKey3", nil) require.NoError(t, err) assert.Equal(t, "reqVal103", string(resp3.Value)) + + require.Contains(t, resp3.Metadata, "ttlExpireTime") + expireTime, err := time.Parse(time.RFC3339, resp3.Metadata["ttlExpireTime"]) + assert.InDelta(t, time.Now().Add(50*time.Second).Unix(), expireTime.Unix(), 5) + return nil } }