sqlite return ttlExpiryTime in GetResponse (#2869)

Signed-off-by: joshvanl <me@joshvanl.dev>
This commit is contained in:
Josh van Leeuwen 2023-06-06 16:03:44 +01:00 committed by GitHub
parent 1ac92a886a
commit f689322570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 13 deletions

View File

@ -13,6 +13,12 @@ limitations under the License.
package state package state
const (
// GetRespMetaKeyTTLExpireTime is the key for the metadata value of the TTL
// expire time. Value is RFC3339 formatted string.
GetRespMetaKeyTTLExpireTime string = "ttlExpireTime"
)
// GetResponse is the response object for getting state. // GetResponse is the response object for getting state.
type GetResponse struct { type GetResponse struct {
Data []byte `json:"data"` Data []byte `json:"data"`

View File

@ -250,14 +250,14 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (
} }
// Concatenation is required for table name because sql.DB does not substitute parameters for table names // Concatenation is required for table name because sql.DB does not substitute parameters for table names
stmt := `SELECT key, value, is_binary, etag FROM ` + a.metadata.TableName + ` stmt := `SELECT key, value, is_binary, etag, expiration_time FROM ` + a.metadata.TableName + `
WHERE WHERE
key = ? key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)` AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout) ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
defer cancel() defer cancel()
row := a.db.QueryRowContext(ctx, stmt, req.Key) row := a.db.QueryRowContext(ctx, stmt, req.Key)
_, value, etag, err := readRow(row) _, value, etag, expireTime, err := readRow(row)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return &state.GetResponse{}, nil return &state.GetResponse{}, nil
@ -265,10 +265,17 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (
return nil, err return nil, err
} }
var metadata map[string]string
if expireTime != nil {
metadata = map[string]string{
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
}
}
return &state.GetResponse{ return &state.GetResponse{
Data: value, Data: value,
ETag: etag, ETag: etag,
Metadata: req.Metadata, Metadata: metadata,
}, nil }, nil
} }
@ -286,7 +293,7 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque
} }
// Concatenation is required for table name because sql.DB does not substitute parameters for table names // Concatenation is required for table name because sql.DB does not substitute parameters for table names
stmt := `SELECT key, value, is_binary, etag FROM ` + a.metadata.TableName + ` stmt := `SELECT key, value, is_binary, etag, expiration_time FROM ` + a.metadata.TableName + `
WHERE WHERE
key IN (` + inClause + `) key IN (` + inClause + `)
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)` AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
@ -307,10 +314,16 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque
} }
r := state.BulkGetResponse{} r := state.BulkGetResponse{}
r.Key, r.Data, r.ETag, err = readRow(rows) var expireTime *time.Time
r.Key, r.Data, r.ETag, expireTime, err = readRow(rows)
if err != nil { if err != nil {
r.Error = err.Error() r.Error = err.Error()
} }
if expireTime != nil {
r.Metadata = map[string]string{
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
}
}
res[n] = r res[n] = r
foundKeys[r.Key] = struct{}{} foundKeys[r.Key] = struct{}{}
} }
@ -337,14 +350,22 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque
return res[:n], nil return res[:n], nil
} }
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, err error) { func readRow(row interface{ Scan(dest ...any) error }) (string, []byte, *string, *time.Time, error) {
var ( var (
key string
value []byte
isBinary bool isBinary bool
etag string etag string
expire sql.NullTime
expireTime *time.Time
) )
err = row.Scan(&key, &value, &isBinary, &etag) err := row.Scan(&key, &value, &isBinary, &etag, &expire)
if err != nil { if err != nil {
return key, nil, nil, err return key, nil, nil, nil, err
}
if expire.Valid {
expireTime = &expire.Time
} }
if isBinary { if isBinary {
@ -352,12 +373,12 @@ func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte
data := make([]byte, len(value)) data := make([]byte, len(value))
n, err = base64.StdEncoding.Decode(data, value) n, err = base64.StdEncoding.Decode(data, value)
if err != nil { if err != nil {
return key, nil, nil, fmt.Errorf("failed to decode binary data: %w", err) return key, nil, nil, nil, fmt.Errorf("failed to decode binary data: %w", err)
} }
return key, data[:n], &etag, nil return key, data[:n], &etag, expireTime, nil
} }
return key, value, &etag, nil return key, value, &etag, expireTime, nil
} }
func (a *sqliteDBAccess) Set(ctx context.Context, req *state.SetRequest) error { func (a *sqliteDBAccess) Set(ctx context.Context, req *state.SetRequest) error {

View File

@ -20,6 +20,7 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"sort"
"testing" "testing"
"time" "time"
@ -139,6 +140,11 @@ func TestSqliteIntegration(t *testing.T) {
multiWithSetOnly(t, s) multiWithSetOnly(t, s)
}) })
t.Run("ttlExpireTime", func(t *testing.T) {
getExpireTime(t, s)
getBulkExpireTime(t, s)
})
t.Run("Binary data", func(t *testing.T) { t.Run("Binary data", func(t *testing.T) {
key := randomKey() key := randomKey()
@ -488,6 +494,65 @@ func setNoTTLUpdatesExpiry(t *testing.T, s state.Store) {
deleteItem(t, s, key, nil) deleteItem(t, s, key, nil)
} }
func getExpireTime(t *testing.T, s state.Store) {
key1 := randomKey()
assert.NoError(t, s.Set(context.Background(), &state.SetRequest{
Key: key1,
Value: "123",
Metadata: map[string]string{
"ttlInSeconds": "1000",
},
}))
resp, err := s.Get(context.Background(), &state.GetRequest{Key: key1})
assert.NoError(t, err)
assert.Equal(t, `"123"`, string(resp.Data))
require.Len(t, resp.Metadata, 1)
expireTime, err := time.Parse(time.RFC3339, resp.Metadata["ttlExpireTime"])
require.NoError(t, err)
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 10)
}
func getBulkExpireTime(t *testing.T, s state.Store) {
key1 := randomKey()
key2 := randomKey()
assert.NoError(t, s.Set(context.Background(), &state.SetRequest{
Key: key1,
Value: "123",
Metadata: map[string]string{
"ttlInSeconds": "1000",
},
}))
assert.NoError(t, s.Set(context.Background(), &state.SetRequest{
Key: key2,
Value: "456",
Metadata: map[string]string{
"ttlInSeconds": "2001",
},
}))
resp, err := s.BulkGet(context.Background(), []state.GetRequest{
{Key: key1}, {Key: key2},
}, state.BulkGetOpts{})
require.NoError(t, err)
assert.Len(t, resp, 2)
sort.Slice(resp, func(i, j int) bool {
return string(resp[i].Data) < string(resp[j].Data)
})
assert.Equal(t, `"123"`, string(resp[0].Data))
assert.Equal(t, `"456"`, string(resp[1].Data))
require.Len(t, resp[0].Metadata, 1)
require.Len(t, resp[1].Metadata, 1)
expireTime, err := time.Parse(time.RFC3339, resp[0].Metadata["ttlExpireTime"])
require.NoError(t, err)
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 10)
expireTime, err = time.Parse(time.RFC3339, resp[1].Metadata["ttlExpireTime"])
require.NoError(t, err)
assert.InDelta(t, time.Now().Add(time.Second*2001).Unix(), expireTime.Unix(), 10)
}
// expiredStateCannotBeRead proves that an expired state element can not be read. // expiredStateCannotBeRead proves that an expired state element can not be read.
func expiredStateCannotBeRead(t *testing.T, s state.Store) { func expiredStateCannotBeRead(t *testing.T, s state.Store) {
key := randomKey() key := randomKey()