mysql: return ttlExpiryTime in GetResponse (#2871)

Signed-off-by: joshvanl <me@joshvanl.dev>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
Josh van Leeuwen 2023-06-07 21:45:24 +01:00 committed by GitHub
parent 06defdcc22
commit 69a1c01801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 118 additions and 17 deletions

View File

@ -531,10 +531,10 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge
ctx, cancel := context.WithTimeout(parentCtx, m.timeout)
defer cancel()
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
query := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + ` WHERE id = ?
query := `SELECT id, value, eTag, isbinary, IFNULL(expiredate, "") FROM ` + m.tableName + ` WHERE id = ?
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
row := m.db.QueryRowContext(ctx, query, req.Key)
_, value, etag, err := readRow(row)
_, value, etag, expireTime, err := readRow(row)
if err != nil {
// If no rows exist, return an empty response, otherwise return an error.
if errors.Is(err, sql.ErrNoRows) {
@ -542,10 +542,18 @@ func (m *MySQL) Get(parentCtx context.Context, req *state.GetRequest) (*state.Ge
}
return nil, err
}
var metadata map[string]string
if expireTime != nil {
metadata = map[string]string{
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
}
}
return &state.GetResponse{
Data: value,
ETag: etag,
Metadata: req.Metadata,
Metadata: metadata,
}, nil
}
@ -708,7 +716,7 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
}
// Concatenation is required for table name because sql.DB does not substitute parameters for table names
stmt := `SELECT id, value, eTag, isbinary FROM ` + m.tableName + `
stmt := `SELECT id, value, eTag, isbinary, IFNULL(expiredate, "") FROM ` + m.tableName + `
WHERE
id IN (` + inClause + `)
AND (expiredate IS NULL OR expiredate > CURRENT_TIMESTAMP)`
@ -720,6 +728,7 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
}
var n int
var expireTime *time.Time
res := make([]state.BulkGetResponse, len(req))
foundKeys := make(map[string]struct{}, len(req))
for ; rows.Next(); n++ {
@ -729,10 +738,15 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
}
r := state.BulkGetResponse{}
r.Key, r.Data, r.ETag, err = readRow(rows)
r.Key, r.Data, r.ETag, expireTime, err = readRow(rows)
if err != nil {
r.Error = err.Error()
}
if expireTime != nil {
r.Metadata = map[string]string{
state.GetRespMetaKeyTTLExpireTime: expireTime.UTC().Format(time.RFC3339),
}
}
res[n] = r
foundKeys[r.Key] = struct{}{}
}
@ -759,14 +773,24 @@ func (m *MySQL) BulkGet(parentCtx context.Context, req []state.GetRequest, _ sta
return res[:n], nil
}
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, err error) {
func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte, etagP *string, expireTime *time.Time, err error) {
var (
etag string
isBinary bool
expire string
)
err = row.Scan(&key, &value, &etag, &isBinary)
err = row.Scan(&key, &value, &etag, &isBinary, &expire)
if err != nil {
return key, nil, nil, err
return key, nil, nil, nil, err
}
if len(expire) > 0 {
var expireT time.Time
expireT, err = time.Parse(time.DateTime, expire)
if err != nil {
return key, nil, nil, nil, fmt.Errorf("failed to parse expiration time: %w", err)
}
expireTime = &expireT
}
if isBinary {
@ -777,17 +801,17 @@ func readRow(row interface{ Scan(dest ...any) error }) (key string, value []byte
err = json.Unmarshal(value, &s)
if err != nil {
return key, nil, nil, fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
return key, nil, nil, nil, fmt.Errorf("failed to unmarshal JSON binary data: %w", err)
}
data, err = base64.StdEncoding.DecodeString(s)
if err != nil {
return key, nil, nil, fmt.Errorf("failed to decode binary data: %w", err)
return key, nil, nil, nil, fmt.Errorf("failed to decode binary data: %w", err)
}
return key, data, &etag, nil
return key, data, &etag, expireTime, nil
}
return key, value, &etag, nil
return key, value, &etag, expireTime, nil
}
// Multi handles multiple transactions.

View File

@ -22,6 +22,7 @@ import (
"encoding/json"
"fmt"
"os"
"sort"
"strings"
"testing"
"time"
@ -29,6 +30,7 @@ import (
"github.com/go-sql-driver/mysql"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
@ -253,6 +255,12 @@ func TestMySQLIntegration(t *testing.T) {
testBulkSetAndBulkDelete(t, mys)
})
t.Run("Get and BulkGet with ttl", func(t *testing.T) {
t.Parallel()
testGetExpireTime(t, mys)
testGetBulkExpireTime(t, mys)
})
t.Run("Update and delete with eTag succeeds", func(t *testing.T) {
t.Parallel()
@ -572,6 +580,65 @@ func testBulkSetAndBulkDelete(t *testing.T, mys *MySQL) {
assert.False(t, storeItemExists(t, setReq[1].Key))
}
func testGetExpireTime(t *testing.T, mys *MySQL) {
key1 := randomKey()
assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{
Key: key1,
Value: "123",
Metadata: map[string]string{
"ttlInSeconds": "1000",
},
}))
resp, err := mys.Get(context.Background(), &state.GetRequest{Key: key1})
assert.NoError(t, err)
assert.Equal(t, `"123"`, string(resp.Data))
require.Len(t, resp.Metadata, 1)
expireTime, err := time.Parse(time.RFC3339, resp.Metadata["ttlExpireTime"])
require.NoError(t, err)
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 5)
}
func testGetBulkExpireTime(t *testing.T, mys *MySQL) {
key1 := randomKey()
key2 := randomKey()
assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{
Key: key1,
Value: "123",
Metadata: map[string]string{
"ttlInSeconds": "1000",
},
}))
assert.NoError(t, mys.Set(context.Background(), &state.SetRequest{
Key: key2,
Value: "456",
Metadata: map[string]string{
"ttlInSeconds": "2001",
},
}))
resp, err := mys.BulkGet(context.Background(), []state.GetRequest{
{Key: key1}, {Key: key2},
}, state.BulkGetOpts{})
require.NoError(t, err)
assert.Len(t, resp, 2)
sort.Slice(resp, func(i, j int) bool {
return string(resp[i].Data) < string(resp[j].Data)
})
assert.Equal(t, `"123"`, string(resp[0].Data))
assert.Equal(t, `"456"`, string(resp[1].Data))
require.Len(t, resp[0].Metadata, 1)
require.Len(t, resp[1].Metadata, 1)
expireTime, err := time.Parse(time.RFC3339, resp[0].Metadata["ttlExpireTime"])
require.NoError(t, err)
assert.InDelta(t, time.Now().Add(time.Second*1000).Unix(), expireTime.Unix(), 5)
expireTime, err = time.Parse(time.RFC3339, resp[1].Metadata["ttlExpireTime"])
require.NoError(t, err)
assert.InDelta(t, time.Now().Add(time.Second*2001).Unix(), expireTime.Unix(), 5)
}
func dropTable(t *testing.T, db *sql.DB, tableName string) {
_, err := db.Exec(fmt.Sprintf(
`DROP TABLE %s;`,

View File

@ -452,8 +452,8 @@ func TestGetSucceeds(t *testing.T) {
defer m.mySQL.Close()
t.Run("has json type", func(t *testing.T) {
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", "{}", "946af56e", false)
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary", "expiredate"}).AddRow("UnitTest", "{}", "946af56e", false, "")
m.mock1.ExpectQuery(`SELECT id, value, eTag, isbinary, IFNULL\(expiredate, ""\) FROM state WHERE id = ?`).WillReturnRows(rows)
request := &state.GetRequest{
Key: "UnitTest",
@ -466,12 +466,15 @@ func TestGetSucceeds(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, "{}", string(response.Data))
assert.NotContains(t, response.Metadata, state.GetRespMetaKeyTTLExpireTime)
})
t.Run("has binary type", func(t *testing.T) {
t.Run("has binary type and expiredate", func(t *testing.T) {
now := time.UnixMilli(20001).UTC()
value, _ := utils.Marshal(base64.StdEncoding.EncodeToString([]byte("abcdefg")), json.Marshal)
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary"}).AddRow("UnitTest", value, "946af56e", true)
m.mock1.ExpectQuery("SELECT id, value, eTag, isbinary FROM state WHERE id = ?").WillReturnRows(rows)
rows := sqlmock.NewRows([]string{"id", "value", "eTag", "isbinary", "expiredate"}).AddRow("UnitTest", value, "946af56e", true, now.Format(time.DateTime))
m.mock1.ExpectQuery(`SELECT id, value, eTag, isbinary, IFNULL\(expiredate, ""\) FROM state WHERE id = ?`).WillReturnRows(rows)
request := &state.GetRequest{
Key: "UnitTest",
@ -484,6 +487,8 @@ func TestGetSucceeds(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, "abcdefg", string(response.Data))
assert.Contains(t, response.Metadata, state.GetRespMetaKeyTTLExpireTime)
assert.Equal(t, "1970-01-01T00:00:20Z", response.Metadata[state.GetRespMetaKeyTTLExpireTime])
})
}

View File

@ -207,6 +207,11 @@ func TestMySQL(t *testing.T) {
resp3, err := client.GetState(ctx, stateStoreName, "reqKey3", nil)
require.NoError(t, err)
assert.Equal(t, "reqVal103", string(resp3.Value))
require.Contains(t, resp3.Metadata, "ttlExpireTime")
expireTime, err := time.Parse(time.RFC3339, resp3.Metadata["ttlExpireTime"])
assert.InDelta(t, time.Now().Add(50*time.Second).Unix(), expireTime.Unix(), 5)
return nil
}
}