sqlite return ttlExpiryTime in GetResponse (#2869)
Signed-off-by: joshvanl <me@joshvanl.dev>
This commit is contained in:
parent
1ac92a886a
commit
f689322570
|
|
@ -13,6 +13,12 @@ limitations under the License.
|
|||
|
||||
package state
|
||||
|
||||
const (
|
||||
// GetRespMetaKeyTTLExpireTime is the key for the metadata value of the TTL
|
||||
// expire time. Value is RFC3339 formatted string.
|
||||
GetRespMetaKeyTTLExpireTime string = "ttlExpireTime"
|
||||
)
|
||||
|
||||
// GetResponse is the response object for getting state.
|
||||
type GetResponse struct {
|
||||
Data []byte `json:"data"`
|
||||
|
|
|
|||
|
|
@ -250,14 +250,14 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (
|
|||
}
|
||||
|
||||
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||
stmt := `SELECT key, value, is_binary, etag FROM ` + a.metadata.TableName + `
|
||||
stmt := `SELECT key, value, is_binary, etag, expiration_time FROM ` + a.metadata.TableName + `
|
||||
WHERE
|
||||
key = ?
|
||||
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
||||
defer cancel()
|
||||
row := a.db.QueryRowContext(ctx, stmt, req.Key)
|
||||
_, value, etag, err := readRow(row)
|
||||
_, value, etag, expireTime, err := readRow(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return &state.GetResponse{}, nil
|
||||
|
|
@ -265,10 +265,17 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var metadata map[string]string
|
||||
if expireTime != nil {
|
||||
metadata = map[string]string{
|
||||
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
return &state.GetResponse{
|
||||
Data: value,
|
||||
ETag: etag,
|
||||
Metadata: req.Metadata,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -286,7 +293,7 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque
|
|||
}
|
||||
|
||||
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||
stmt := `SELECT key, value, is_binary, etag FROM ` + a.metadata.TableName + `
|
||||
stmt := `SELECT key, value, is_binary, etag, expiration_time FROM ` + a.metadata.TableName + `
|
||||
WHERE
|
||||
key IN (` + inClause + `)
|
||||
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
|
||||
|
|
@ -307,10 +314,16 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque
|
|||
}
|
||||
|
||||
r := state.BulkGetResponse{}
|
||||
r.Key, r.Data, r.ETag, err = readRow(rows)
|
||||
var expireTime *time.Time
|
||||
r.Key, r.Data, r.ETag, expireTime, err = readRow(rows)
|
||||
if err != nil {
|
||||
r.Error = err.Error()
|
||||
}
|
||||
if expireTime != nil {
|
||||
r.Metadata = map[string]string{
|
||||
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
res[n] = r
|
||||
foundKeys[r.Key] = struct{}{}
|
||||
}
|
||||
|
|
@ -337,14 +350,22 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque
|
|||
return res[:n], nil
|
||||
}
|
||||
|
||||
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, err error) {
|
||||
func readRow(row interface{ Scan(dest ...any) error }) (string, []byte, *string, *time.Time, error) {
|
||||
var (
|
||||
key string
|
||||
value []byte
|
||||
isBinary bool
|
||||
etag string
|
||||
expire sql.NullTime
|
||||
expireTime *time.Time
|
||||
)
|
||||
err = row.Scan(&key, &value, &isBinary, &etag)
|
||||
err := row.Scan(&key, &value, &isBinary, &etag, &expire)
|
||||
if err != nil {
|
||||
return key, nil, nil, err
|
||||
return key, nil, nil, nil, err
|
||||
}
|
||||
|
||||
if expire.Valid {
|
||||
expireTime = &expire.Time
|
||||
}
|
||||
|
||||
if isBinary {
|
||||
|
|
@ -352,12 +373,12 @@ func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte
|
|||
data := make([]byte, len(value))
|
||||
n, err = base64.StdEncoding.Decode(data, value)
|
||||
if err != nil {
|
||||
return key, nil, nil, fmt.Errorf("failed to decode binary data: %w", err)
|
||||
return key, nil, nil, nil, fmt.Errorf("failed to decode binary data: %w", err)
|
||||
}
|
||||
return key, data[:n], &etag, nil
|
||||
return key, data[:n], &etag, expireTime, nil
|
||||
}
|
||||
|
||||
return key, value, &etag, nil
|
||||
return key, value, &etag, expireTime, nil
|
||||
}
|
||||
|
||||
func (a *sqliteDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -139,6 +140,11 @@ func TestSqliteIntegration(t *testing.T) {
|
|||
multiWithSetOnly(t, s)
|
||||
})
|
||||
|
||||
t.Run("ttlExpireTime", func(t *testing.T) {
|
||||
getExpireTime(t, s)
|
||||
getBulkExpireTime(t, s)
|
||||
})
|
||||
|
||||
t.Run("Binary data", func(t *testing.T) {
|
||||
key := randomKey()
|
||||
|
||||
|
|
@ -488,6 +494,65 @@ func setNoTTLUpdatesExpiry(t *testing.T, s state.Store) {
|
|||
deleteItem(t, s, key, nil)
|
||||
}
|
||||
|
||||
func getExpireTime(t *testing.T, s state.Store) {
|
||||
key1 := randomKey()
|
||||
assert.NoError(t, s.Set(context.Background(), &state.SetRequest{
|
||||
Key: key1,
|
||||
Value: "123",
|
||||
Metadata: map[string]string{
|
||||
"ttlInSeconds": "1000",
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := s.Get(context.Background(), &state.GetRequest{Key: key1})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `"123"`, string(resp.Data))
|
||||
require.Len(t, resp.Metadata, 1)
|
||||
expireTime, err := time.Parse(time.RFC3339, resp.Metadata["ttlExpireTime"])
|
||||
require.NoError(t, err)
|
||||
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 10)
|
||||
}
|
||||
|
||||
func getBulkExpireTime(t *testing.T, s state.Store) {
|
||||
key1 := randomKey()
|
||||
key2 := randomKey()
|
||||
|
||||
assert.NoError(t, s.Set(context.Background(), &state.SetRequest{
|
||||
Key: key1,
|
||||
Value: "123",
|
||||
Metadata: map[string]string{
|
||||
"ttlInSeconds": "1000",
|
||||
},
|
||||
}))
|
||||
assert.NoError(t, s.Set(context.Background(), &state.SetRequest{
|
||||
Key: key2,
|
||||
Value: "456",
|
||||
Metadata: map[string]string{
|
||||
"ttlInSeconds": "2001",
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := s.BulkGet(context.Background(), []state.GetRequest{
|
||||
{Key: key1}, {Key: key2},
|
||||
}, state.BulkGetOpts{})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 2)
|
||||
sort.Slice(resp, func(i, j int) bool {
|
||||
return string(resp[i].Data) < string(resp[j].Data)
|
||||
})
|
||||
|
||||
assert.Equal(t, `"123"`, string(resp[0].Data))
|
||||
assert.Equal(t, `"456"`, string(resp[1].Data))
|
||||
require.Len(t, resp[0].Metadata, 1)
|
||||
require.Len(t, resp[1].Metadata, 1)
|
||||
expireTime, err := time.Parse(time.RFC3339, resp[0].Metadata["ttlExpireTime"])
|
||||
require.NoError(t, err)
|
||||
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 10)
|
||||
expireTime, err = time.Parse(time.RFC3339, resp[1].Metadata["ttlExpireTime"])
|
||||
require.NoError(t, err)
|
||||
assert.InDelta(t, time.Now().Add(time.Second*2001).Unix(), expireTime.Unix(), 10)
|
||||
}
|
||||
|
||||
// expiredStateCannotBeRead proves that an expired state element can not be read.
|
||||
func expiredStateCannotBeRead(t *testing.T, s state.Store) {
|
||||
key := randomKey()
|
||||
|
|
|
|||
Loading…
Reference in New Issue