501 lines
12 KiB
Go
501 lines
12 KiB
Go
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
// Blank import for the underlying SQLite Driver.
|
|
_ "modernc.org/sqlite"
|
|
|
|
"github.com/dapr/components-contrib/state"
|
|
stateutils "github.com/dapr/components-contrib/state/utils"
|
|
"github.com/dapr/kit/logger"
|
|
)
|
|
|
|
// DBAccess is a private interface which enables unit testing of SQLite.
|
|
type DBAccess interface {
|
|
Init(metadata state.Metadata) error
|
|
Ping(ctx context.Context) error
|
|
Set(ctx context.Context, req *state.SetRequest) error
|
|
Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error)
|
|
Delete(ctx context.Context, req *state.DeleteRequest) error
|
|
ExecuteMulti(ctx context.Context, reqs []state.TransactionalStateOperation) error
|
|
Close() error
|
|
}
|
|
|
|
// querier is the subset of methods shared by *sql.DB and *sql.Tx, so the same
// helper can run either directly against the database or inside a
// transaction.
type querier interface {
	// ExecContext runs a statement that does not return rows.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)

	// QueryContext runs a statement that returns rows.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)

	// QueryRowContext runs a statement that returns at most one row.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
|
|
|
|
// sqliteDBAccess implements DBAccess.
|
|
type sqliteDBAccess struct {
|
|
logger logger.Logger
|
|
metadata sqliteMetadataStruct
|
|
db *sql.DB
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
|
|
// Lock only on public write API. Any public API's implementation should not call other public write APIs.
|
|
lock *sync.Mutex
|
|
}
|
|
|
|
// newSqliteDBAccess creates a new instance of sqliteDbAccess.
|
|
func newSqliteDBAccess(logger logger.Logger) *sqliteDBAccess {
|
|
return &sqliteDBAccess{
|
|
logger: logger,
|
|
lock: &sync.Mutex{},
|
|
}
|
|
}
|
|
|
|
// Init sets up SQLite Database connection and ensures that the state table
|
|
// exists.
|
|
func (a *sqliteDBAccess) Init(md state.Metadata) error {
|
|
err := a.metadata.InitWithMetadata(md)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
db, err := sql.Open("sqlite", a.metadata.ConnectionString)
|
|
if err != nil {
|
|
a.logger.Error(err)
|
|
return err
|
|
}
|
|
|
|
a.db = db
|
|
a.ctx, a.cancel = context.WithCancel(context.Background())
|
|
|
|
err = a.Ping(a.ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = a.ensureStateTable(a.ctx, a.metadata.TableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
a.scheduleCleanupExpiredData()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *sqliteDBAccess) Ping(parentCtx context.Context) error {
|
|
a.lock.Lock()
|
|
defer a.lock.Unlock()
|
|
|
|
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
|
err := a.db.PingContext(ctx)
|
|
cancel()
|
|
return err
|
|
}
|
|
|
|
func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
|
a.lock.Lock()
|
|
defer a.lock.Unlock()
|
|
|
|
if req.Key == "" {
|
|
return nil, errors.New("missing key in get operation")
|
|
}
|
|
var (
|
|
value []byte
|
|
isBinary bool
|
|
etag string
|
|
)
|
|
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names
|
|
//nolint:gosec
|
|
stmt := fmt.Sprintf(
|
|
`SELECT value, is_binary, etag FROM %s
|
|
WHERE
|
|
key = ?
|
|
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`,
|
|
a.metadata.TableName)
|
|
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
|
err := a.db.QueryRowContext(ctx, stmt, req.Key).
|
|
Scan(&value, &isBinary, &etag)
|
|
cancel()
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return &state.GetResponse{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
if isBinary {
|
|
var n int
|
|
data := make([]byte, len(value))
|
|
n, err = base64.StdEncoding.Decode(data, value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &state.GetResponse{
|
|
Data: data[:n],
|
|
ETag: &etag,
|
|
Metadata: req.Metadata,
|
|
}, nil
|
|
}
|
|
|
|
return &state.GetResponse{
|
|
Data: value,
|
|
ETag: &etag,
|
|
Metadata: req.Metadata,
|
|
}, nil
|
|
}
|
|
|
|
func (a *sqliteDBAccess) Set(ctx context.Context, req *state.SetRequest) error {
|
|
a.lock.Lock()
|
|
defer a.lock.Unlock()
|
|
|
|
return a.doSet(ctx, a.db, req)
|
|
}
|
|
|
|
func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state.SetRequest) error {
|
|
err := state.CheckRequestOptions(req.Options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if req.Key == "" {
|
|
return errors.New("missing key in set option")
|
|
}
|
|
|
|
if v, ok := req.Value.(string); ok && v == "" {
|
|
return fmt.Errorf("empty string is not allowed in set operation")
|
|
}
|
|
|
|
// TTL
|
|
var ttlSeconds int
|
|
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
|
|
if ttlerr != nil {
|
|
return fmt.Errorf("error parsing TTL: %w", ttlerr)
|
|
}
|
|
if ttl != nil {
|
|
ttlSeconds = *ttl
|
|
}
|
|
|
|
// Encode the value
|
|
var requestValue string
|
|
byteArray, isBinary := req.Value.([]uint8)
|
|
if isBinary {
|
|
requestValue = base64.StdEncoding.EncodeToString(byteArray)
|
|
} else {
|
|
var bt []byte
|
|
bt, err = json.Marshal(req.Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
requestValue = string(bt)
|
|
}
|
|
|
|
// New ETag
|
|
etagObj, err := uuid.NewRandom()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
newEtag := etagObj.String()
|
|
|
|
// Also resets expiration time in case of an update
|
|
expiration := "NULL"
|
|
if ttlSeconds > 0 {
|
|
expiration = fmt.Sprintf("DATETIME(CURRENT_TIMESTAMP, '+%d seconds')", ttlSeconds)
|
|
}
|
|
|
|
// Only check for etag if FirstWrite specified (ref oracledatabaseaccess)
|
|
var res sql.Result
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
|
// And the same is for DATETIME function's seconds parameter (which is from an integer anyways).
|
|
if req.ETag == nil || *req.ETag == "" {
|
|
var op string
|
|
if req.Options.Concurrency == state.FirstWrite {
|
|
op = "INSERT"
|
|
} else {
|
|
op = "INSERT OR REPLACE"
|
|
}
|
|
stmt := fmt.Sprintf(
|
|
`%s INTO %s
|
|
(key, value, is_binary, etag, update_time, expiration_time)
|
|
VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP, %s)`,
|
|
op, a.metadata.TableName, expiration,
|
|
)
|
|
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.timeout)
|
|
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)
|
|
cancel()
|
|
} else {
|
|
stmt := fmt.Sprintf(
|
|
`UPDATE %s SET
|
|
value = ?,
|
|
etag = ?,
|
|
is_binary = ?,
|
|
update_time = CURRENT_TIMESTAMP,
|
|
expiration_time = %s
|
|
WHERE
|
|
key = ?
|
|
AND eTag = ?`,
|
|
a.metadata.TableName, expiration,
|
|
)
|
|
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.timeout)
|
|
res, err = db.ExecContext(ctx, stmt, requestValue, newEtag, isBinary, req.Key, *req.ETag)
|
|
cancel()
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if rows == 0 {
|
|
if req.ETag != nil && *req.ETag != "" {
|
|
return state.NewETagError(state.ETagMismatch, nil)
|
|
}
|
|
return errors.New("no item was updated")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *sqliteDBAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
|
|
a.lock.Lock()
|
|
defer a.lock.Unlock()
|
|
|
|
return a.doDelete(ctx, a.db, req)
|
|
}
|
|
|
|
func (a *sqliteDBAccess) ExecuteMulti(parentCtx context.Context, reqs []state.TransactionalStateOperation) error {
|
|
a.lock.Lock()
|
|
defer a.lock.Unlock()
|
|
|
|
tx, err := a.db.BeginTx(parentCtx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
for _, req := range reqs {
|
|
switch req.Operation {
|
|
case state.Upsert:
|
|
if setReq, ok := req.Request.(state.SetRequest); ok {
|
|
err = a.doSet(parentCtx, tx, &setReq)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
return fmt.Errorf("expecting set request")
|
|
}
|
|
case state.Delete:
|
|
if delReq, ok := req.Request.(state.DeleteRequest); ok {
|
|
err = a.doDelete(parentCtx, tx, &delReq)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
return fmt.Errorf("expecting delete request")
|
|
}
|
|
default:
|
|
// Do nothing
|
|
}
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Close implements io.Close.
|
|
func (a *sqliteDBAccess) Close() error {
|
|
if a.cancel != nil {
|
|
a.cancel()
|
|
}
|
|
if a.db != nil {
|
|
_ = a.db.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Create table if not exists.
|
|
func (a *sqliteDBAccess) ensureStateTable(parentCtx context.Context, stateTableName string) error {
|
|
exists, err := a.tableExists(parentCtx)
|
|
if err != nil || exists {
|
|
return err
|
|
}
|
|
|
|
a.logger.Infof("Creating SQLite state table '%s'", stateTableName)
|
|
|
|
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
|
defer cancel()
|
|
|
|
tx, err := a.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
stmt := fmt.Sprintf(
|
|
`CREATE TABLE %s (
|
|
key TEXT NOT NULL PRIMARY KEY,
|
|
value TEXT NOT NULL,
|
|
is_binary BOOLEAN NOT NULL,
|
|
etag TEXT NOT NULL,
|
|
expiration_time TIMESTAMP DEFAULT NULL,
|
|
update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
|
)`,
|
|
stateTableName,
|
|
)
|
|
_, err = tx.Exec(stmt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
stmt = fmt.Sprintf(`CREATE INDEX idx_%s_expiration_time ON %s (expiration_time)`, stateTableName, stateTableName)
|
|
_, err = tx.Exec(stmt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Check if table exists.
|
|
func (a *sqliteDBAccess) tableExists(parentCtx context.Context) (bool, error) {
|
|
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
|
defer cancel()
|
|
|
|
var exists string
|
|
// Returns 1 or 0 as a string if the table exists or not.
|
|
const q = `SELECT EXISTS (
|
|
SELECT name FROM sqlite_master WHERE type='table' AND name = ?
|
|
) AS 'exists'`
|
|
err := a.db.QueryRowContext(ctx, q, a.metadata.TableName).Scan(&exists)
|
|
return exists == "1", err
|
|
}
|
|
|
|
func (a *sqliteDBAccess) doDelete(parentCtx context.Context, db querier, req *state.DeleteRequest) error {
|
|
err := state.CheckRequestOptions(req.Options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if req.Key == "" {
|
|
return fmt.Errorf("missing key in delete operation")
|
|
}
|
|
|
|
var result sql.Result
|
|
if req.ETag == nil || *req.ETag == "" {
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
|
stmt := fmt.Sprintf("DELETE FROM %s WHERE key = ?", a.metadata.TableName)
|
|
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
|
result, err = db.ExecContext(ctx, stmt, req.Key)
|
|
cancel()
|
|
} else {
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
|
stmt := fmt.Sprintf("DELETE FROM %s WHERE key = ? AND etag = ?", a.metadata.TableName)
|
|
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
|
result, err = db.ExecContext(ctx, stmt, req.Key, *req.ETag)
|
|
cancel()
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if rows == 0 && req.ETag != nil && *req.ETag != "" {
|
|
return state.NewETagError(state.ETagMismatch, nil)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *sqliteDBAccess) scheduleCleanupExpiredData() {
|
|
if a.metadata.cleanupInterval == nil {
|
|
return
|
|
}
|
|
|
|
d := *a.metadata.cleanupInterval
|
|
a.logger.Infof("Schedule expired data clean up every %v", d)
|
|
|
|
ticker := time.NewTicker(d)
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
a.cleanupTimeout()
|
|
case <-a.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (a *sqliteDBAccess) cleanupTimeout() {
|
|
a.lock.Lock()
|
|
defer a.lock.Unlock()
|
|
|
|
ctx, cancel := context.WithTimeout(a.ctx, a.metadata.timeout)
|
|
defer cancel()
|
|
|
|
tx, err := a.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
a.logger.Errorf("Error removing expired data: failed to begin transaction: %v", err)
|
|
return
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names
|
|
//nolint:gosec
|
|
stmt := fmt.Sprintf(
|
|
`DELETE FROM %s
|
|
WHERE
|
|
expiration_time IS NOT NULL
|
|
AND expiration_time < CURRENT_TIMESTAMP`,
|
|
a.metadata.TableName,
|
|
)
|
|
res, err := tx.Exec(stmt)
|
|
if err != nil {
|
|
a.logger.Errorf("Error removing expired data: failed to execute query: %v", err)
|
|
return
|
|
}
|
|
|
|
cleaned, err := res.RowsAffected()
|
|
if err != nil {
|
|
a.logger.Errorf("Error removing expired data: failed to count affected rows: %v", err)
|
|
return
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
a.logger.Errorf("Error removing expired data: failed to commit transaction: %v", err)
|
|
return
|
|
}
|
|
|
|
a.logger.Debugf("Removed %d expired rows", cleaned)
|
|
}
|