From f6893225702e9397bde9534b1116f093b4e4d66f Mon Sep 17 00:00:00 2001 From: Josh van Leeuwen Date: Tue, 6 Jun 2023 16:03:44 +0100 Subject: [PATCH] sqlite return ttlExpiryTime in GetResponse (#2869) Signed-off-by: joshvanl --- state/responses.go | 6 +++ state/sqlite/sqlite_dbaccess.go | 47 +++++++++++++----- state/sqlite/sqlite_integration_test.go | 65 +++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 13 deletions(-) diff --git a/state/responses.go b/state/responses.go index 135861501..7f2c97c48 100644 --- a/state/responses.go +++ b/state/responses.go @@ -13,6 +13,12 @@ limitations under the License. package state +const ( + // GetRespMetaKeyTTLExpireTime is the key for the metadata value of the TTL + // expire time. Value is RFC3339 formatted string. + GetRespMetaKeyTTLExpireTime string = "ttlExpireTime" +) + // GetResponse is the response object for getting state. type GetResponse struct { Data []byte `json:"data"` diff --git a/state/sqlite/sqlite_dbaccess.go b/state/sqlite/sqlite_dbaccess.go index 51b9f7c54..d1b0b852a 100644 --- a/state/sqlite/sqlite_dbaccess.go +++ b/state/sqlite/sqlite_dbaccess.go @@ -250,14 +250,14 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) ( } // Concatenation is required for table name because sql.DB does not substitute parameters for table names - stmt := `SELECT key, value, is_binary, etag FROM ` + a.metadata.TableName + ` + stmt := `SELECT key, value, is_binary, etag, expiration_time FROM ` + a.metadata.TableName + ` WHERE key = ? AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)` ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout) defer cancel() row := a.db.QueryRowContext(ctx, stmt, req.Key) - _, value, etag, err := readRow(row) + _, value, etag, expireTime, err := readRow(row) if err != nil { if errors.Is(err, sql.ErrNoRows) { return &state.GetResponse{}, nil @@ -265,10 +265,17 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) ( return nil, err } + var metadata map[string]string + if expireTime != nil { + metadata = map[string]string{ + state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), + } + } + return &state.GetResponse{ Data: value, ETag: etag, - Metadata: req.Metadata, + Metadata: metadata, }, nil } @@ -286,7 +293,7 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque } // Concatenation is required for table name because sql.DB does not substitute parameters for table names - stmt := `SELECT key, value, is_binary, etag FROM ` + a.metadata.TableName + ` + stmt := `SELECT key, value, is_binary, etag, expiration_time FROM ` + a.metadata.TableName + ` WHERE key IN (` + inClause + `) AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)` @@ -307,10 +314,16 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque } r := state.BulkGetResponse{} - r.Key, r.Data, r.ETag, err = readRow(rows) + var expireTime *time.Time + r.Key, r.Data, r.ETag, expireTime, err = readRow(rows) if err != nil { r.Error = err.Error() } + if expireTime != nil { + r.Metadata = map[string]string{ + state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339), + } + } res[n] = r foundKeys[r.Key] = struct{}{} } @@ -337,14 +350,22 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque return res[:n], nil } -func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, err error) { +func readRow(row interface{ Scan(dest ...any) error }) (string, []byte, *string, *time.Time, error) { var ( - isBinary bool - etag string + key string + value []byte + isBinary bool + etag string + expire sql.NullTime + expireTime *time.Time ) - err = row.Scan(&key, &value, &isBinary, &etag) + err := row.Scan(&key, &value, &isBinary, &etag, &expire) if err != nil { - return key, nil, nil, err + return key, nil, nil, nil, err + } + + if expire.Valid { + expireTime = &expire.Time } if isBinary { @@ -352,12 +373,12 @@ func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte data := make([]byte, len(value)) n, err = base64.StdEncoding.Decode(data, value) if err != nil { - return key, nil, nil, fmt.Errorf("failed to decode binary data: %w", err) + return key, nil, nil, nil, fmt.Errorf("failed to decode binary data: %w", err) } - return key, data[:n], &etag, nil + return key, data[:n], &etag, expireTime, nil } - return key, value, &etag, nil + return key, value, &etag, expireTime, nil } func (a *sqliteDBAccess) Set(ctx context.Context, req *state.SetRequest) error { diff --git a/state/sqlite/sqlite_integration_test.go b/state/sqlite/sqlite_integration_test.go index 6aea7fba9..4400ca4cf 100644 --- a/state/sqlite/sqlite_integration_test.go +++ b/state/sqlite/sqlite_integration_test.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "os" + "sort" "testing" "time" @@ -139,6 +140,11 @@ func TestSqliteIntegration(t *testing.T) { multiWithSetOnly(t, s) }) + t.Run("ttlExpireTime", func(t *testing.T) { + getExpireTime(t, s) + getBulkExpireTime(t, s) + }) + t.Run("Binary data", func(t *testing.T) { key := randomKey() @@ -488,6 +494,65 @@ func setNoTTLUpdatesExpiry(t *testing.T, s state.Store) { deleteItem(t, s, key, nil) } +func getExpireTime(t *testing.T, s state.Store) { + key1 := randomKey() + assert.NoError(t, s.Set(context.Background(), &state.SetRequest{ + Key: key1, + Value: "123", + Metadata: map[string]string{ + "ttlInSeconds": "1000", + }, + })) + + resp, err := s.Get(context.Background(), &state.GetRequest{Key: key1}) + assert.NoError(t, err) + assert.Equal(t, `"123"`, string(resp.Data)) + require.Len(t, resp.Metadata, 1) + expireTime, err := time.Parse(time.RFC3339, resp.Metadata["ttlExpireTime"]) + require.NoError(t, err) + assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 10) +} + +func getBulkExpireTime(t *testing.T, s state.Store) { + key1 := randomKey() + key2 := randomKey() + + assert.NoError(t, s.Set(context.Background(), &state.SetRequest{ + Key: key1, + Value: "123", + Metadata: map[string]string{ + "ttlInSeconds": "1000", + }, + })) + assert.NoError(t, s.Set(context.Background(), &state.SetRequest{ + Key: key2, + Value: "456", + Metadata: map[string]string{ + "ttlInSeconds": "2001", + }, + })) + + resp, err := s.BulkGet(context.Background(), []state.GetRequest{ + {Key: key1}, {Key: key2}, + }, state.BulkGetOpts{}) + require.NoError(t, err) + assert.Len(t, resp, 2) + sort.Slice(resp, func(i, j int) bool { + return string(resp[i].Data) < string(resp[j].Data) + }) + + assert.Equal(t, `"123"`, string(resp[0].Data)) + assert.Equal(t, `"456"`, string(resp[1].Data)) + require.Len(t, resp[0].Metadata, 1) + require.Len(t, resp[1].Metadata, 1) + expireTime, err := time.Parse(time.RFC3339, resp[0].Metadata["ttlExpireTime"]) + require.NoError(t, err) + assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 10) + expireTime, err = time.Parse(time.RFC3339, resp[1].Metadata["ttlExpireTime"]) + require.NoError(t, err) + assert.InDelta(t, time.Now().Add(time.Second*2001).Unix(), expireTime.Unix(), 10) +} + // expiredStateCannotBeRead proves that an expired state element can not be read. func expiredStateCannotBeRead(t *testing.T, s state.Store) { key := randomKey()