components-contrib/state/sqlite/sqlite_dbaccess.go

501 lines
12 KiB
Go

/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlite
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/google/uuid"
// Blank import for the underlying SQLite Driver.
_ "modernc.org/sqlite"
"github.com/dapr/components-contrib/state"
stateutils "github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
)
// DBAccess abstracts the database operations used by the SQLite state store.
// Keeping them behind an interface enables unit testing of SQLite without a
// real database.
type DBAccess interface {
// Init configures the implementation from the component metadata.
Init(metadata state.Metadata) error
// Ping checks that the database can be reached.
Ping(ctx context.Context) error
// Set stores the value carried by req.
Set(ctx context.Context, req *state.SetRequest) error
// Get retrieves the value for the key carried by req.
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
// Delete removes the key carried by req.
Delete(ctx context.Context, req *state.DeleteRequest) error
// ExecuteMulti applies a list of set/delete operations together.
ExecuteMulti(ctx context.Context, reqs []state.TransactionalStateOperation) error
// Close releases the resources held by the implementation.
Close() error
}
// querier is the set of query methods shared by both *sql.DB and *sql.Tx.
// Helpers such as doSet and doDelete accept a querier so the same code can run
// either directly against the database or inside a transaction.
type querier interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
// sqliteDBAccess implements DBAccess.
type sqliteDBAccess struct {
// logger receives diagnostic output from this component.
logger logger.Logger
// metadata holds the parsed component configuration (connection string,
// table name, timeout, cleanup interval); it is populated by Init.
metadata sqliteMetadataStruct
// db is the database handle opened by Init and closed by Close.
db *sql.DB
// ctx and cancel are created by Init. ctx bounds the background cleanup
// goroutine; cancel is invoked by Close.
ctx context.Context
cancel context.CancelFunc
// lock is taken by the public operations (Ping, Get, Set, Delete,
// ExecuteMulti) and by the periodic cleanup. Because the mutex is not
// re-entrant, a public API's implementation must not call another public API
// that takes the lock; use the unlocked helpers (doSet, doDelete) instead.
lock *sync.Mutex
}
// newSqliteDBAccess builds a sqliteDBAccess that logs through the supplied
// logger. No database connection is opened here; that happens in Init.
func newSqliteDBAccess(logger logger.Logger) *sqliteDBAccess {
	dba := new(sqliteDBAccess)
	dba.logger = logger
	dba.lock = new(sync.Mutex)
	return dba
}
// Init parses the component metadata, opens the SQLite database, verifies the
// connection with a ping, makes sure the state table exists and finally
// starts the periodic removal of expired rows (when a cleanup interval is
// configured).
//
// It returns the first error encountered; an error from sql.Open is also
// written to the logger.
func (a *sqliteDBAccess) Init(md state.Metadata) error {
	if err := a.metadata.InitWithMetadata(md); err != nil {
		return err
	}

	conn, err := sql.Open("sqlite", a.metadata.ConnectionString)
	if err != nil {
		a.logger.Error(err)
		return err
	}
	a.db = conn

	// This context lives until Close is called and bounds all background work.
	a.ctx, a.cancel = context.WithCancel(context.Background())

	if err = a.Ping(a.ctx); err != nil {
		return err
	}

	if err = a.ensureStateTable(a.ctx, a.metadata.TableName); err != nil {
		return err
	}

	a.scheduleCleanupExpiredData()

	return nil
}
// Ping verifies that the database is reachable. The check is bounded by the
// configured timeout, derived from parentCtx, and is serialized with the other
// operations through the access lock.
func (a *sqliteDBAccess) Ping(parentCtx context.Context) error {
	a.lock.Lock()
	defer a.lock.Unlock()

	ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
	defer cancel()

	return a.db.PingContext(ctx)
}
// Get loads the row stored under req.Key, ignoring rows whose expiration time
// has passed.
//
// It returns an error when the key is empty or the query fails. When no
// (unexpired) row exists, an empty response and a nil error are returned.
// Values flagged as binary in the table are base64-decoded before being
// returned; all other values are returned as stored. The response echoes
// req.Metadata and carries the row's ETag.
func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
	a.lock.Lock()
	defer a.lock.Unlock()

	if req.Key == "" {
		return nil, errors.New("missing key in get operation")
	}

	// The table name cannot be bound as a parameter, so it is formatted into
	// the statement; the key is still passed as a bound argument.
	//nolint:gosec
	query := fmt.Sprintf(
		`SELECT value, is_binary, etag FROM %s
WHERE
key = ?
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`,
		a.metadata.TableName)

	var (
		stored  []byte
		encoded bool
		etag    string
	)
	ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
	err := a.db.QueryRowContext(ctx, query, req.Key).Scan(&stored, &encoded, &etag)
	cancel()

	switch {
	case errors.Is(err, sql.ErrNoRows):
		// A missing key is not an error: report an empty response.
		return &state.GetResponse{}, nil
	case err != nil:
		return nil, err
	}

	payload := stored
	if encoded {
		// Binary values are persisted as base64 text.
		buf := make([]byte, len(stored))
		n, decodeErr := base64.StdEncoding.Decode(buf, stored)
		if decodeErr != nil {
			return nil, decodeErr
		}
		payload = buf[:n]
	}

	return &state.GetResponse{
		Data:     payload,
		ETag:     &etag,
		Metadata: req.Metadata,
	}, nil
}
// Set stores the value described by req. It takes the access lock and then
// delegates to doSet, running directly against the database handle (that is,
// outside of any explicit transaction).
func (a *sqliteDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
	a.lock.Lock()
	defer a.lock.Unlock()

	err := a.doSet(ctx, a.db, req)
	return err
}
// doSet writes a single key using db, which may be the database handle or an
// open transaction. It does not take the access lock; callers must hold it.
//
// Behavior:
//   - The request options are validated, and an empty key or an empty string
//     value is rejected.
//   - A "ttl" found in the request metadata (parsed by stateutils.ParseTTL)
//     sets the row's expiration time; without a positive TTL the expiration
//     is reset to NULL.
//   - []byte values are stored base64-encoded with is_binary set; any other
//     value is stored as its JSON encoding.
//   - Every successful write assigns a fresh random ETag.
//   - Without an ETag in the request the row is inserted (FirstWrite
//     concurrency) or inserted-or-replaced (otherwise). With an ETag, the row
//     is updated only if the stored ETag matches.
//
// It returns an ETag mismatch error when an ETag was supplied and no row was
// updated, and a generic error when no row was affected otherwise.
func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state.SetRequest) error {
	err := state.CheckRequestOptions(req.Options)
	if err != nil {
		return err
	}

	if req.Key == "" {
		return errors.New("missing key in set option")
	}

	if v, ok := req.Value.(string); ok && v == "" {
		return errors.New("empty string is not allowed in set operation")
	}

	// TTL
	var ttlSeconds int
	ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
	if ttlerr != nil {
		return fmt.Errorf("error parsing TTL: %w", ttlerr)
	}
	if ttl != nil {
		ttlSeconds = *ttl
	}

	// Encode the value
	var requestValue string
	byteArray, isBinary := req.Value.([]uint8)
	if isBinary {
		requestValue = base64.StdEncoding.EncodeToString(byteArray)
	} else {
		var bt []byte
		bt, err = json.Marshal(req.Value)
		if err != nil {
			return err
		}
		requestValue = string(bt)
	}

	// New ETag
	etagObj, err := uuid.NewRandom()
	if err != nil {
		return err
	}
	newEtag := etagObj.String()

	// Also resets expiration time in case of an update
	expiration := "NULL"
	if ttlSeconds > 0 {
		expiration = fmt.Sprintf("DATETIME(CURRENT_TIMESTAMP, '+%d seconds')", ttlSeconds)
	}

	hasETag := req.ETag != nil && *req.ETag != ""

	// Derive the timeout from the caller's context so that cancellation and
	// deadlines set by the caller are honored by the statement.
	ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
	defer cancel()

	// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
	// And the same is for DATETIME function's seconds parameter (which is from an integer anyways).
	var res sql.Result
	if !hasETag {
		// Only FirstWrite requires the row not to exist yet (ref oracledatabaseaccess).
		op := "INSERT OR REPLACE"
		if req.Options.Concurrency == state.FirstWrite {
			op = "INSERT"
		}
		//nolint:gosec
		stmt := fmt.Sprintf(
			`%s INTO %s
(key, value, is_binary, etag, update_time, expiration_time)
VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP, %s)`,
			op, a.metadata.TableName, expiration,
		)
		// Exactly one argument per placeholder.
		res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag)
	} else {
		//nolint:gosec
		stmt := fmt.Sprintf(
			`UPDATE %s SET
value = ?,
etag = ?,
is_binary = ?,
update_time = CURRENT_TIMESTAMP,
expiration_time = %s
WHERE
key = ?
AND eTag = ?`,
			a.metadata.TableName, expiration,
		)
		res, err = db.ExecContext(ctx, stmt, requestValue, newEtag, isBinary, req.Key, *req.ETag)
	}
	if err != nil {
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows == 0 {
		if hasETag {
			return state.NewETagError(state.ETagMismatch, nil)
		}
		return errors.New("no item was updated")
	}

	return nil
}
// Delete removes the key described by req. It takes the access lock and then
// delegates to doDelete, running directly against the database handle.
func (a *sqliteDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
	a.lock.Lock()
	defer a.lock.Unlock()

	err := a.doDelete(ctx, a.db, req)
	return err
}
// ExecuteMulti applies the given upsert and delete operations inside a single
// database transaction. The transaction is committed only if every operation
// succeeds; on the first failure the error is returned and the deferred
// rollback discards all changes. Operations of any other kind are skipped.
func (a *sqliteDBAccess) ExecuteMulti(parentCtx context.Context, reqs []state.TransactionalStateOperation) error {
	a.lock.Lock()
	defer a.lock.Unlock()

	tx, err := a.db.BeginTx(parentCtx, nil)
	if err != nil {
		return err
	}
	// Rollback is a no-op once the transaction has been committed.
	defer tx.Rollback()

	for _, op := range reqs {
		switch op.Operation {
		case state.Upsert:
			setReq, ok := op.Request.(state.SetRequest)
			if !ok {
				return errors.New("expecting set request")
			}
			if err = a.doSet(parentCtx, tx, &setReq); err != nil {
				return err
			}

		case state.Delete:
			delReq, ok := op.Request.(state.DeleteRequest)
			if !ok {
				return errors.New("expecting delete request")
			}
			if err = a.doDelete(parentCtx, tx, &delReq); err != nil {
				return err
			}

		default:
			// Do nothing
		}
	}

	return tx.Commit()
}
// Close implements io.Closer. It cancels the context created by Init, which
// ends the background cleanup goroutine, and closes the database handle.
// Both steps are skipped when Init never got far enough to set them up.
// The error from closing the database is intentionally discarded, so Close
// always returns nil.
func (a *sqliteDBAccess) Close() error {
	if stop := a.cancel; stop != nil {
		stop()
	}

	if db := a.db; db != nil {
		_ = db.Close()
	}

	return nil
}
// ensureStateTable creates the state table, and the index on its
// expiration_time column, when the table does not exist yet.
//
// The existence check is performed by tableExists, which looks up the table
// name from the component metadata; stateTableName is the name used in the
// CREATE statements (callers pass the same value for both). Table and index
// are created in one transaction, so either both exist afterwards or neither
// does. All statements are bounded by the configured timeout, derived from
// parentCtx.
func (a *sqliteDBAccess) ensureStateTable(parentCtx context.Context, stateTableName string) error {
	exists, err := a.tableExists(parentCtx)
	if err != nil || exists {
		// Either the check failed (err != nil) or there is nothing to do (err == nil).
		return err
	}

	a.logger.Infof("Creating SQLite state table '%s'", stateTableName)

	ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
	defer cancel()

	tx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is a no-op once the transaction has been committed.
	defer tx.Rollback()

	// Sprintf is required because table and index names cannot be bound as parameters.
	stmt := fmt.Sprintf(
		`CREATE TABLE %s (
key TEXT NOT NULL PRIMARY KEY,
value TEXT NOT NULL,
is_binary BOOLEAN NOT NULL,
etag TEXT NOT NULL,
expiration_time TIMESTAMP DEFAULT NULL,
update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`,
		stateTableName,
	)
	// ExecContext (rather than Exec) makes the statement itself observe the timeout.
	if _, err = tx.ExecContext(ctx, stmt); err != nil {
		return err
	}

	stmt = fmt.Sprintf(`CREATE INDEX idx_%s_expiration_time ON %s (expiration_time)`, stateTableName, stateTableName)
	if _, err = tx.ExecContext(ctx, stmt); err != nil {
		return err
	}

	return tx.Commit()
}
// tableExists reports whether a table with the configured name is listed in
// sqlite_master. The lookup is bounded by the configured timeout. When the
// query fails, the returned error is non-nil and the boolean is not
// meaningful.
func (a *sqliteDBAccess) tableExists(parentCtx context.Context) (bool, error) {
	// The query yields 1 when the table exists and 0 otherwise; the result is
	// scanned as a string.
	const query = `SELECT EXISTS (
SELECT name FROM sqlite_master WHERE type='table' AND name = ?
) AS 'exists'`

	ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
	defer cancel()

	var found string
	err := a.db.QueryRowContext(ctx, query, a.metadata.TableName).Scan(&found)

	return found == "1", err
}
// doDelete removes a single key using db, which may be the database handle or
// an open transaction. It does not take the access lock; callers must hold
// it.
//
// The request options are validated and an empty key is rejected. When the
// request carries a non-empty ETag, the row is deleted only if its stored
// ETag matches, and an ETag mismatch error is returned if nothing was
// deleted. Without an ETag, deleting a key that does not exist is not an
// error.
func (a *sqliteDBAccess) doDelete(parentCtx context.Context, db querier, req *state.DeleteRequest) error {
	if err := state.CheckRequestOptions(req.Options); err != nil {
		return err
	}

	if req.Key == "" {
		return errors.New("missing key in delete operation")
	}

	withETag := req.ETag != nil && *req.ETag != ""

	// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
	query := fmt.Sprintf("DELETE FROM %s WHERE key = ?", a.metadata.TableName)
	args := []any{req.Key}
	if withETag {
		query += " AND etag = ?"
		args = append(args, *req.ETag)
	}

	ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
	res, err := db.ExecContext(ctx, query, args...)
	cancel()
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}

	if withETag && affected == 0 {
		return state.NewETagError(state.ETagMismatch, nil)
	}

	return nil
}
// scheduleCleanupExpiredData starts a background goroutine that removes
// expired rows at the configured cleanup interval. It does nothing when no
// cleanup interval is configured.
//
// The goroutine runs until the context created by Init is cancelled (which
// Close does); at that point the ticker is stopped so its resources are
// released.
func (a *sqliteDBAccess) scheduleCleanupExpiredData() {
	if a.metadata.cleanupInterval == nil {
		return
	}

	d := *a.metadata.cleanupInterval
	a.logger.Infof("Schedule expired data clean up every %v", d)

	ticker := time.NewTicker(d)
	go func() {
		// Release the ticker once the goroutine exits.
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				a.cleanupTimeout()
			case <-a.ctx.Done():
				return
			}
		}
	}()
}
// cleanupTimeout deletes every row whose expiration time has passed. It is
// invoked periodically by the goroutine started in
// scheduleCleanupExpiredData.
//
// The deletion runs in its own transaction while holding the access lock and
// is bounded by the configured timeout, derived from the component's
// background context. Failures are logged and otherwise ignored; the next
// tick simply tries again.
func (a *sqliteDBAccess) cleanupTimeout() {
	a.lock.Lock()
	defer a.lock.Unlock()

	ctx, cancel := context.WithTimeout(a.ctx, a.metadata.timeout)
	defer cancel()

	tx, err := a.db.BeginTx(ctx, nil)
	if err != nil {
		a.logger.Errorf("Error removing expired data: failed to begin transaction: %v", err)
		return
	}
	// Rollback is a no-op once the transaction has been committed.
	defer tx.Rollback()

	// Sprintf is required for table name because sql.DB does not substitute parameters for table names
	//nolint:gosec
	stmt := fmt.Sprintf(
		`DELETE FROM %s
WHERE
expiration_time IS NOT NULL
AND expiration_time < CURRENT_TIMESTAMP`,
		a.metadata.TableName,
	)
	// ExecContext (rather than Exec) makes the statement itself observe the timeout.
	res, err := tx.ExecContext(ctx, stmt)
	if err != nil {
		a.logger.Errorf("Error removing expired data: failed to execute query: %v", err)
		return
	}

	cleaned, err := res.RowsAffected()
	if err != nil {
		a.logger.Errorf("Error removing expired data: failed to count affected rows: %v", err)
		return
	}

	err = tx.Commit()
	if err != nil {
		a.logger.Errorf("Error removing expired data: failed to commit transaction: %v", err)
		return
	}

	a.logger.Debugf("Removed %d expired rows", cleaned)
}