mysql: return ttlExpiryTime in GetResponse (#2871)
Signed-off-by: joshvanl <me@joshvanl.dev> Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
06defdcc22
commit
69a1c01801
|
@ -531,10 +531,10 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge
|
|||
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
|
||||
defer cancel()
|
||||
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||
query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ?
|
||||
query := `SELECT id, value, eTag, isbinary, IFNULL(expiredate, "") FROM ` + m.tableName + ` WHERE id = ?
|
||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
|
||||
row := m.db.QueryRowContext(ctx, query, req.Key)
|
||||
_, value, etag, err := readRow(row)
|
||||
_, value, etag, expireTime, err := readRow(row)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return an error.
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -542,10 +542,18 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge
|
|||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var metadata map[string]string
|
||||
if expireTime != nil {
|
||||
metadata = map[string]string{
|
||||
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
return &state.GetResponse{
|
||||
Data: value,
|
||||
ETag: etag,
|
||||
Metadata: req.Metadata,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -708,7 +716,7 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
|
|||
}
|
||||
|
||||
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
|
||||
stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + `
|
||||
stmt := `SELECT id, value, eTag, isbinary, IFNULL(expiredate, "") FROM ` + m.tableName + `
|
||||
WHERE
|
||||
id IN (` + inClause + `)
|
||||
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
|
||||
|
@ -720,6 +728,7 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
|
|||
}
|
||||
|
||||
var n int
|
||||
var expireTime *time.Time
|
||||
res := make([]state.BulkGetResponse, len(req))
|
||||
foundKeys := make(map[string]struct{}, len(req))
|
||||
for ; rows.Next(); n++ {
|
||||
|
@ -729,10 +738,15 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
|
|||
}
|
||||
|
||||
r := state.BulkGetResponse{}
|
||||
r.Key, r.Data, r.ETag, err = readRow(rows)
|
||||
r.Key, r.Data, r.ETag, expireTime, err = readRow(rows)
|
||||
if err != nil {
|
||||
r.Error = err.Error()
|
||||
}
|
||||
if expireTime != nil {
|
||||
r.Metadata = map[string]string{
|
||||
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
res[n] = r
|
||||
foundKeys[r.Key] = struct{}{}
|
||||
}
|
||||
|
@ -759,14 +773,24 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
|
|||
return res[:n], nil
|
||||
}
|
||||
|
||||
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, err error) {
|
||||
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, expireTime *time.Time, err error) {
|
||||
var (
|
||||
etag string
|
||||
isBinary bool
|
||||
expire string
|
||||
)
|
||||
err = row.Scan(&key, &value, &etag, &isBinary)
|
||||
err = row.Scan(&key, &value, &etag, &isBinary, &expire)
|
||||
if err != nil {
|
||||
return key, nil, nil, err
|
||||
return key, nil, nil, nil, err
|
||||
}
|
||||
|
||||
if len(expire) > 0 {
|
||||
var expireT time.Time
|
||||
expireT, err = time.Parse(time.DateTime, expire)
|
||||
if err != nil {
|
||||
return key, nil, nil, nil, fmt.Errorf("failed to parse expiration time: %w", err)
|
||||
}
|
||||
expireTime = &expireT
|
||||
}
|
||||
|
||||
if isBinary {
|
||||
|
@ -777,17 +801,17 @@ func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte
|
|||
|
||||
err = json.Unmarshal(value, &s)
|
||||
if err != nil {
|
||||
return key, nil, nil, fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
|
||||
return key, nil, nil, nil, fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
|
||||
}
|
||||
|
||||
data, err = base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return key, nil, nil, fmt.Errorf("failed to decode binary data: %w", err)
|
||||
return key, nil, nil, nil, fmt.Errorf("failed to decode binary data: %w", err)
|
||||
}
|
||||
return key, data, &etag, nil
|
||||
return key, data, &etag, expireTime, nil
|
||||
}
|
||||
|
||||
return key, value, &etag, nil
|
||||
return key, value, &etag, expireTime, nil
|
||||
}
|
||||
|
||||
// Multi handles multiple transactions.
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -29,6 +30,7 @@ import (
|
|||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
|
@ -253,6 +255,12 @@ func TestMySQLIntegration(t *testing.T) {
|
|||
testBulkSetAndBulkDelete(t, mys)
|
||||
})
|
||||
|
||||
t.Run("Get and BulkGet with ttl", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testGetExpireTime(t, mys)
|
||||
testGetBulkExpireTime(t, mys)
|
||||
})
|
||||
|
||||
t.Run("Update and delete with eTag succeeds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -572,6 +580,65 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) {
|
|||
assert.False(t, storeItemExists(t, setReq[1].Key))
|
||||
}
|
||||
|
||||
func testGetExpireTime(t *testing.T, mys *MySQL) {
|
||||
key1 := randomKey()
|
||||
assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{
|
||||
Key: key1,
|
||||
Value: "123",
|
||||
Metadata: map[string]string{
|
||||
"ttlInSeconds": "1000",
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := mys.Get(context.Background(), &state.GetRequest{Key: key1})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `"123"`, string(resp.Data))
|
||||
require.Len(t, resp.Metadata, 1)
|
||||
expireTime, err := time.Parse(time.RFC3339, resp.Metadata["ttlExpireTime"])
|
||||
require.NoError(t, err)
|
||||
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 5)
|
||||
}
|
||||
|
||||
func testGetBulkExpireTime(t *testing.T, mys *MySQL) {
|
||||
key1 := randomKey()
|
||||
key2 := randomKey()
|
||||
|
||||
assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{
|
||||
Key: key1,
|
||||
Value: "123",
|
||||
Metadata: map[string]string{
|
||||
"ttlInSeconds": "1000",
|
||||
},
|
||||
}))
|
||||
assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{
|
||||
Key: key2,
|
||||
Value: "456",
|
||||
Metadata: map[string]string{
|
||||
"ttlInSeconds": "2001",
|
||||
},
|
||||
}))
|
||||
|
||||
resp, err := mys.BulkGet(context.Background(), []state.GetRequest{
|
||||
{Key: key1}, {Key: key2},
|
||||
}, state.BulkGetOpts{})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 2)
|
||||
sort.Slice(resp, func(i, j int) bool {
|
||||
return string(resp[i].Data) < string(resp[j].Data)
|
||||
})
|
||||
|
||||
assert.Equal(t, `"123"`, string(resp[0].Data))
|
||||
assert.Equal(t, `"456"`, string(resp[1].Data))
|
||||
require.Len(t, resp[0].Metadata, 1)
|
||||
require.Len(t, resp[1].Metadata, 1)
|
||||
expireTime, err := time.Parse(time.RFC3339, resp[0].Metadata["ttlExpireTime"])
|
||||
require.NoError(t, err)
|
||||
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 5)
|
||||
expireTime, err = time.Parse(time.RFC3339, resp[1].Metadata["ttlExpireTime"])
|
||||
require.NoError(t, err)
|
||||
assert.InDelta(t, time.Now().Add(time.Second*2001).Unix(), expireTime.Unix(), 5)
|
||||
}
|
||||
|
||||
func dropTable(t *testing.T, db *sql.DB, tableName string) {
|
||||
_, err := db.Exec(fmt.Sprintf(
|
||||
`DROP TABLE %s;`,
|
||||
|
|
|
@ -452,8 +452,8 @@ func TestGetSucceeds(t *testing.T) {
|
|||
defer m.mySQL.Close()
|
||||
|
||||
t.Run("has json type", func(t *testing.T) {
|
||||
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false)
|
||||
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary", "expiredate"}).AddRow("UnitTest", "{}", "946af56e", false, "")
|
||||
m.mock1.ExpectQuery(`SELECT id, value, eTag, isbinary, IFNULL\(expiredate, ""\) FROM state WHERE id = ?`).WillReturnRows(rows)
|
||||
|
||||
request := &state.GetRequest{
|
||||
Key: "UnitTest",
|
||||
|
@ -466,12 +466,15 @@ func TestGetSucceeds(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.Equal(t, "{}", string(response.Data))
|
||||
assert.NotContains(t, response.Metadata, state.GetRespMetaKeyTTLExpireTime)
|
||||
})
|
||||
|
||||
t.Run("has binary type", func(t *testing.T) {
|
||||
t.Run("has binary type and expiredate", func(t *testing.T) {
|
||||
now := time.UnixMilli(20001).UTC()
|
||||
|
||||
value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal)
|
||||
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true)
|
||||
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
|
||||
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary", "expiredate"}).AddRow("UnitTest", value, "946af56e", true, now.Format(time.DateTime))
|
||||
m.mock1.ExpectQuery(`SELECT id, value, eTag, isbinary, IFNULL\(expiredate, ""\) FROM state WHERE id = ?`).WillReturnRows(rows)
|
||||
|
||||
request := &state.GetRequest{
|
||||
Key: "UnitTest",
|
||||
|
@ -484,6 +487,8 @@ func TestGetSucceeds(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.Equal(t, "abcdefg", string(response.Data))
|
||||
assert.Contains(t, response.Metadata, state.GetRespMetaKeyTTLExpireTime)
|
||||
assert.Equal(t, "1970-01-01T00:00:20Z", response.Metadata[state.GetRespMetaKeyTTLExpireTime])
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -207,6 +207,11 @@ func TestMySQL(t *testing.T) {
|
|||
resp3, err := client.GetState(ctx, stateStoreName, "reqKey3", nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "reqVal103", string(resp3.Value))
|
||||
|
||||
require.Contains(t, resp3.Metadata, "ttlExpireTime")
|
||||
expireTime, err := time.Parse(time.RFC3339, resp3.Metadata["ttlExpireTime"])
|
||||
assert.InDelta(t, time.Now().Add(50*time.Second).Unix(), expireTime.Unix(), 5)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue