Move SQLite "auth" metadata to a separate package `internal/authentication/sqlite` (#3135)
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Co-authored-by: Bernd Verst <github@bernd.dev>
This commit is contained in:
parent
a874485a32
commit
e7db4cf3ad
|
|
@ -0,0 +1,180 @@
|
|||
/*
|
||||
Copyright 2023 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dapr/kit/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultTimeout = 20 * time.Second // Default timeout for database requests, in seconds
|
||||
DefaultBusyTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// SqliteAuthMetadata contains the auth metadata for a SQLite component.
|
||||
type SqliteAuthMetadata struct {
|
||||
ConnectionString string `mapstructure:"connectionString" mapstructurealiases:"url"`
|
||||
Timeout time.Duration `mapstructure:"timeout" mapstructurealiases:"timeoutInSeconds"`
|
||||
BusyTimeout time.Duration `mapstructure:"busyTimeout"`
|
||||
DisableWAL bool `mapstructure:"disableWAL"` // Disable WAL journaling. You should not use WAL if the database is stored on a network filesystem (or data corruption may happen). This is ignored if the database is in-memory.
|
||||
}
|
||||
|
||||
// Reset the object
|
||||
func (m *SqliteAuthMetadata) Reset() {
|
||||
m.ConnectionString = ""
|
||||
m.Timeout = DefaultTimeout
|
||||
m.BusyTimeout = DefaultBusyTimeout
|
||||
m.DisableWAL = false
|
||||
}
|
||||
|
||||
func (m *SqliteAuthMetadata) Validate() error {
|
||||
// Validate and sanitize input
|
||||
if m.ConnectionString == "" {
|
||||
return errors.New("missing connection string")
|
||||
}
|
||||
if m.Timeout < time.Second {
|
||||
return errors.New("invalid value for 'timeout': must be greater than 1s")
|
||||
}
|
||||
|
||||
// Busy timeout
|
||||
// Truncate values to milliseconds. Values <= 0 do not set any timeout
|
||||
m.BusyTimeout = m.BusyTimeout.Truncate(time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SqliteAuthMetadata) GetConnectionString(log logger.Logger) (string, error) {
|
||||
// Check if we're using the in-memory database
|
||||
lc := strings.ToLower(m.ConnectionString)
|
||||
isMemoryDB := strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:")
|
||||
|
||||
// Get the "query string" from the connection string if present
|
||||
idx := strings.IndexRune(m.ConnectionString, '?')
|
||||
var qs url.Values
|
||||
if idx > 0 {
|
||||
qs, _ = url.ParseQuery(m.ConnectionString[(idx + 1):])
|
||||
}
|
||||
if len(qs) == 0 {
|
||||
qs = make(url.Values, 2)
|
||||
}
|
||||
|
||||
// If the database is in-memory, we must ensure that cache=shared is set
|
||||
if isMemoryDB {
|
||||
qs["cache"] = []string{"shared"}
|
||||
}
|
||||
|
||||
// Check if the database is read-only or immutable
|
||||
isReadOnly := false
|
||||
if len(qs["mode"]) > 0 {
|
||||
// Keep the first value only
|
||||
qs["mode"] = []string{
|
||||
qs["mode"][0],
|
||||
}
|
||||
if qs["mode"][0] == "ro" {
|
||||
isReadOnly = true
|
||||
}
|
||||
}
|
||||
if len(qs["immutable"]) > 0 {
|
||||
// Keep the first value only
|
||||
qs["immutable"] = []string{
|
||||
qs["immutable"][0],
|
||||
}
|
||||
if qs["immutable"][0] == "1" {
|
||||
isReadOnly = true
|
||||
}
|
||||
}
|
||||
|
||||
// We do not want to override a _txlock if set, but we'll show a warning if it's not "immediate"
|
||||
if len(qs["_txlock"]) > 0 {
|
||||
// Keep the first value only
|
||||
qs["_txlock"] = []string{
|
||||
strings.ToLower(qs["_txlock"][0]),
|
||||
}
|
||||
if qs["_txlock"][0] != "immediate" {
|
||||
log.Warn("Database connection is being created with a _txlock different from the recommended value 'immediate'")
|
||||
}
|
||||
} else {
|
||||
qs["_txlock"] = []string{"immediate"}
|
||||
}
|
||||
|
||||
// Add pragma values
|
||||
if len(qs["_pragma"]) == 0 {
|
||||
qs["_pragma"] = make([]string, 0, 2)
|
||||
} else {
|
||||
for _, p := range qs["_pragma"] {
|
||||
p = strings.ToLower(p)
|
||||
if strings.HasPrefix(p, "busy_timeout") {
|
||||
log.Error("Cannot set `_pragma=busy_timeout` option in the connection string; please use the `busyTimeout` metadata property instead")
|
||||
return "", errors.New("found forbidden option '_pragma=busy_timeout' in the connection string")
|
||||
} else if strings.HasPrefix(p, "journal_mode") {
|
||||
log.Error("Cannot set `_pragma=journal_mode` option in the connection string; please use the `disableWAL` metadata property instead")
|
||||
return "", errors.New("found forbidden option '_pragma=journal_mode' in the connection string")
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.BusyTimeout > 0 {
|
||||
qs["_pragma"] = append(qs["_pragma"], fmt.Sprintf("busy_timeout(%d)", m.BusyTimeout.Milliseconds()))
|
||||
}
|
||||
if isMemoryDB {
|
||||
// For in-memory databases, set the journal to MEMORY, the only allowed option besides OFF (which would make transactions ineffective)
|
||||
qs["_pragma"] = append(qs["_pragma"], "journal_mode(MEMORY)")
|
||||
} else if m.DisableWAL || isReadOnly {
|
||||
// Set the journaling mode to "DELETE" (the default) if WAL is disabled or if the database is read-only
|
||||
qs["_pragma"] = append(qs["_pragma"], "journal_mode(DELETE)")
|
||||
} else {
|
||||
// Enable WAL
|
||||
qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)")
|
||||
}
|
||||
|
||||
// Build the final connection string
|
||||
connString := m.ConnectionString
|
||||
if idx > 0 {
|
||||
connString = connString[:idx]
|
||||
}
|
||||
connString += "?" + qs.Encode()
|
||||
|
||||
// If the connection string doesn't begin with "file:", add the prefix
|
||||
if !strings.HasPrefix(lc, "file:") {
|
||||
log.Debug("prefix 'file:' added to the connection string")
|
||||
connString = "file:" + connString
|
||||
}
|
||||
|
||||
return connString, nil
|
||||
}
|
||||
|
||||
// Validates an identifier, such as table or DB name.
|
||||
func ValidIdentifier(v string) bool {
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Loop through the string as byte slice as we only care about ASCII characters
|
||||
b := []byte(v)
|
||||
for i := 0; i < len(b); i++ {
|
||||
if (b[i] >= '0' && b[i] <= '9') ||
|
||||
(b[i] >= 'a' && b[i] <= 'z') ||
|
||||
(b[i] >= 'A' && b[i] <= 'Z') ||
|
||||
b[i] == '_' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
Copyright 2023 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
)
|
||||
|
||||
func TestSqliteMetadata(t *testing.T) {
|
||||
stateMetadata := func(props map[string]string) state.Metadata {
|
||||
return state.Metadata{Base: metadata.Base{Properties: props}}
|
||||
}
|
||||
|
||||
t.Run("default options", func(t *testing.T) {
|
||||
md := &SqliteAuthMetadata{}
|
||||
md.Reset()
|
||||
|
||||
err := metadata.DecodeMetadata(stateMetadata(map[string]string{
|
||||
"connectionString": "file:data.db",
|
||||
}), &md)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = md.Validate()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "file:data.db", md.ConnectionString)
|
||||
assert.Equal(t, DefaultTimeout, md.Timeout)
|
||||
assert.Equal(t, DefaultBusyTimeout, md.BusyTimeout)
|
||||
assert.False(t, md.DisableWAL)
|
||||
})
|
||||
|
||||
t.Run("empty connection string", func(t *testing.T) {
|
||||
md := &SqliteAuthMetadata{}
|
||||
md.Reset()
|
||||
|
||||
err := metadata.DecodeMetadata(stateMetadata(map[string]string{}), &md)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = md.Validate()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "missing connection string")
|
||||
})
|
||||
|
||||
t.Run("invalid timeout", func(t *testing.T) {
|
||||
md := &SqliteAuthMetadata{}
|
||||
md.Reset()
|
||||
|
||||
err := metadata.DecodeMetadata(stateMetadata(map[string]string{
|
||||
"connectionString": "file:data.db",
|
||||
"timeout": "500ms",
|
||||
}), &md)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = md.Validate()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "timeout")
|
||||
})
|
||||
|
||||
t.Run("aliases", func(t *testing.T) {
|
||||
md := &SqliteAuthMetadata{}
|
||||
md.Reset()
|
||||
|
||||
err := metadata.DecodeMetadata(stateMetadata(map[string]string{
|
||||
"url": "file:data.db",
|
||||
"timeoutinseconds": "1200",
|
||||
}), &md)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = md.Validate()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "file:data.db", md.ConnectionString)
|
||||
assert.Equal(t, 20*time.Minute, md.Timeout)
|
||||
})
|
||||
}
|
||||
|
|
@ -158,7 +158,7 @@ func DecodeMetadata(input any, result any) error {
|
|||
}
|
||||
|
||||
// Handle aliases
|
||||
err = resolveAliases(inputMap, result)
|
||||
err = resolveAliases(inputMap, reflect.TypeOf(result))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve aliases: %w", err)
|
||||
}
|
||||
|
|
@ -183,7 +183,7 @@ func DecodeMetadata(input any, result any) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func resolveAliases(md map[string]string, result any) error {
|
||||
func resolveAliases(md map[string]string, t reflect.Type) error {
|
||||
// Get the list of all keys in the map
|
||||
keys := make(map[string]string, len(md))
|
||||
for k := range md {
|
||||
|
|
@ -199,7 +199,6 @@ func resolveAliases(md map[string]string, result any) error {
|
|||
}
|
||||
|
||||
// Error if result is not pointer to struct, or pointer to pointer to struct
|
||||
t := reflect.TypeOf(result)
|
||||
if t.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("not a pointer: %s", t.Kind().String())
|
||||
}
|
||||
|
|
@ -211,7 +210,14 @@ func resolveAliases(md map[string]string, result any) error {
|
|||
return fmt.Errorf("not a struct: %s", t.Kind().String())
|
||||
}
|
||||
|
||||
// Iterate through all the properties of result to see if anyone has the "mapstructurealiases" property
|
||||
// Iterate through all the properties, possibly recursively
|
||||
resolveAliasesInType(md, keys, t)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveAliasesInType(md map[string]string, keys map[string]string, t reflect.Type) {
|
||||
// Iterate through all the properties of the type to see if anyone has the "mapstructurealiases" property
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
currentField := t.Field(i)
|
||||
|
||||
|
|
@ -221,6 +227,12 @@ func resolveAliases(md map[string]string, result any) error {
|
|||
continue
|
||||
}
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if mapstructureTag == ",squash" {
|
||||
resolveAliasesInType(md, keys, currentField.Type)
|
||||
continue
|
||||
}
|
||||
|
||||
// If the current property has a value in the metadata, then we don't need to handle aliases
|
||||
_, ok := keys[strings.ToLower(mapstructureTag)]
|
||||
if ok {
|
||||
|
|
@ -246,8 +258,6 @@ func resolveAliases(md map[string]string, result any) error {
|
|||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toTruthyBoolHookFunc() mapstructure.DecodeHookFunc {
|
||||
|
|
|
|||
|
|
@ -98,7 +98,13 @@ func TestTryGetContentType(t *testing.T) {
|
|||
|
||||
func TestMetadataDecode(t *testing.T) {
|
||||
t.Run("Test metadata decoding", func(t *testing.T) {
|
||||
type TestEmbedded struct {
|
||||
MyEmbedded string `mapstructure:"embedded"`
|
||||
MyEmbeddedAliased string `mapstructure:"embalias" mapstructurealiases:"embalias2"`
|
||||
}
|
||||
type testMetadata struct {
|
||||
TestEmbedded `mapstructure:",squash"`
|
||||
|
||||
Mystring string `mapstructure:"mystring"`
|
||||
Myduration Duration `mapstructure:"myduration"`
|
||||
Myinteger int `mapstructure:"myinteger"`
|
||||
|
|
@ -139,6 +145,8 @@ func TestMetadataDecode(t *testing.T) {
|
|||
"aliasA2": "hello",
|
||||
"aliasB1": "ciao",
|
||||
"aliasB2": "bonjour",
|
||||
"embedded": "hi",
|
||||
"embalias2": "ciao",
|
||||
}
|
||||
|
||||
err := DecodeMetadata(testData, &m)
|
||||
|
|
@ -159,6 +167,8 @@ func TestMetadataDecode(t *testing.T) {
|
|||
assert.Equal(t, []time.Duration{}, *m.MyDurationArrayPointerEmpty)
|
||||
assert.Equal(t, "hello", m.AliasedFieldA)
|
||||
assert.Equal(t, "ciao", m.AliasedFieldB)
|
||||
assert.Equal(t, "hi", m.TestEmbedded.MyEmbedded)
|
||||
assert.Equal(t, "ciao", m.TestEmbedded.MyEmbeddedAliased)
|
||||
})
|
||||
|
||||
t.Run("Test metadata decode hook for truthy values", func(t *testing.T) {
|
||||
|
|
@ -346,6 +356,10 @@ func TestMetadataStructToStringMap(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestResolveAliases(t *testing.T) {
|
||||
type Embedded struct {
|
||||
Hello string `mapstructure:"hello" mapstructurealiases:"ciao"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
md map[string]string
|
||||
|
|
@ -497,11 +511,27 @@ func TestResolveAliases(t *testing.T) {
|
|||
"bonjour": "monde",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "aliases in embedded struct",
|
||||
md: map[string]string{
|
||||
"ciao": "mondo",
|
||||
"bonjour": "monde",
|
||||
},
|
||||
result: &struct {
|
||||
Embedded `mapstructure:",squash"`
|
||||
Bonjour string `mapstructure:"bonjour"`
|
||||
}{},
|
||||
wantMd: map[string]string{
|
||||
"bonjour": "monde",
|
||||
"ciao": "mondo",
|
||||
"hello": "mondo",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
md := maps.Clone(tt.md)
|
||||
err := resolveAliases(md, tt.result)
|
||||
err := resolveAliases(md, reflect.TypeOf(tt.result))
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -78,7 +77,7 @@ func (a *sqliteDBAccess) Init(ctx context.Context, md state.Metadata) error {
|
|||
return err
|
||||
}
|
||||
|
||||
connString, err := a.getConnectionString()
|
||||
connString, err := a.metadata.GetConnectionString(a.logger)
|
||||
if err != nil {
|
||||
// Already logged
|
||||
return err
|
||||
|
|
@ -137,107 +136,8 @@ func (a *sqliteDBAccess) CleanupExpired() error {
|
|||
return a.gc.CleanupExpired()
|
||||
}
|
||||
|
||||
func (a *sqliteDBAccess) getConnectionString() (string, error) {
|
||||
// Check if we're using the in-memory database
|
||||
lc := strings.ToLower(a.metadata.ConnectionString)
|
||||
isMemoryDB := strings.HasPrefix(lc, ":memory:") || strings.HasPrefix(lc, "file::memory:")
|
||||
|
||||
// Get the "query string" from the connection string if present
|
||||
idx := strings.IndexRune(a.metadata.ConnectionString, '?')
|
||||
var qs url.Values
|
||||
if idx > 0 {
|
||||
qs, _ = url.ParseQuery(a.metadata.ConnectionString[(idx + 1):])
|
||||
}
|
||||
if len(qs) == 0 {
|
||||
qs = make(url.Values, 2)
|
||||
}
|
||||
|
||||
// If the database is in-memory, we must ensure that cache=shared is set
|
||||
if isMemoryDB {
|
||||
qs["cache"] = []string{"shared"}
|
||||
}
|
||||
|
||||
// Check if the database is read-only or immutable
|
||||
isReadOnly := false
|
||||
if len(qs["mode"]) > 0 {
|
||||
// Keep the first value only
|
||||
qs["mode"] = []string{
|
||||
qs["mode"][0],
|
||||
}
|
||||
if qs["mode"][0] == "ro" {
|
||||
isReadOnly = true
|
||||
}
|
||||
}
|
||||
if len(qs["immutable"]) > 0 {
|
||||
// Keep the first value only
|
||||
qs["immutable"] = []string{
|
||||
qs["immutable"][0],
|
||||
}
|
||||
if qs["immutable"][0] == "1" {
|
||||
isReadOnly = true
|
||||
}
|
||||
}
|
||||
|
||||
// We do not want to override a _txlock if set, but we'll show a warning if it's not "immediate"
|
||||
if len(qs["_txlock"]) > 0 {
|
||||
// Keep the first value only
|
||||
qs["_txlock"] = []string{
|
||||
strings.ToLower(qs["_txlock"][0]),
|
||||
}
|
||||
if qs["_txlock"][0] != "immediate" {
|
||||
a.logger.Warn("Database connection is being created with a _txlock different from the recommended value 'immediate'")
|
||||
}
|
||||
} else {
|
||||
qs["_txlock"] = []string{"immediate"}
|
||||
}
|
||||
|
||||
// Add pragma values
|
||||
if len(qs["_pragma"]) == 0 {
|
||||
qs["_pragma"] = make([]string, 0, 2)
|
||||
} else {
|
||||
for _, p := range qs["_pragma"] {
|
||||
p = strings.ToLower(p)
|
||||
if strings.HasPrefix(p, "busy_timeout") {
|
||||
a.logger.Error("Cannot set `_pragma=busy_timeout` option in the connection string; please use the `busyTimeout` metadata property instead")
|
||||
return "", errors.New("found forbidden option '_pragma=busy_timeout' in the connection string")
|
||||
} else if strings.HasPrefix(p, "journal_mode") {
|
||||
a.logger.Error("Cannot set `_pragma=journal_mode` option in the connection string; please use the `disableWAL` metadata property instead")
|
||||
return "", errors.New("found forbidden option '_pragma=journal_mode' in the connection string")
|
||||
}
|
||||
}
|
||||
}
|
||||
if a.metadata.BusyTimeout > 0 {
|
||||
qs["_pragma"] = append(qs["_pragma"], fmt.Sprintf("busy_timeout(%d)", a.metadata.BusyTimeout.Milliseconds()))
|
||||
}
|
||||
if isMemoryDB {
|
||||
// For in-memory databases, set the journal to MEMORY, the only allowed option besides OFF (which would make transactions ineffective)
|
||||
qs["_pragma"] = append(qs["_pragma"], "journal_mode(MEMORY)")
|
||||
} else if a.metadata.DisableWAL || isReadOnly {
|
||||
// Set the journaling mode to "DELETE" (the default) if WAL is disabled or if the database is read-only
|
||||
qs["_pragma"] = append(qs["_pragma"], "journal_mode(DELETE)")
|
||||
} else {
|
||||
// Enable WAL
|
||||
qs["_pragma"] = append(qs["_pragma"], "journal_mode(WAL)")
|
||||
}
|
||||
|
||||
// Build the final connection string
|
||||
connString := a.metadata.ConnectionString
|
||||
if idx > 0 {
|
||||
connString = connString[:idx]
|
||||
}
|
||||
connString += "?" + qs.Encode()
|
||||
|
||||
// If the connection string doesn't begin with "file:", add the prefix
|
||||
if !strings.HasPrefix(lc, "file:") {
|
||||
a.logger.Debug("prefix 'file:' added to the connection string")
|
||||
connString = "file:" + connString
|
||||
}
|
||||
|
||||
return connString, nil
|
||||
}
|
||||
|
||||
func (a *sqliteDBAccess) Ping(parentCtx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout)
|
||||
err := a.db.PingContext(ctx)
|
||||
cancel()
|
||||
return err
|
||||
|
|
@ -253,7 +153,7 @@ func (a *sqliteDBAccess) Get(parentCtx context.Context, req *state.GetRequest) (
|
|||
WHERE
|
||||
key = ?
|
||||
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout)
|
||||
defer cancel()
|
||||
row := a.db.QueryRowContext(ctx, stmt, req.Key)
|
||||
_, value, etag, expireTime, err := readRow(row)
|
||||
|
|
@ -296,7 +196,7 @@ func (a *sqliteDBAccess) BulkGet(parentCtx context.Context, req []state.GetReque
|
|||
WHERE
|
||||
key IN (` + inClause + `)
|
||||
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout)
|
||||
defer cancel()
|
||||
rows, err := a.db.QueryContext(ctx, stmt, params...)
|
||||
if err != nil {
|
||||
|
|
@ -475,7 +375,7 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
|
|||
stmt = "INSERT OR REPLACE INTO " + a.metadata.TableName + `
|
||||
(key, value, is_binary, etag, update_time, expiration_time)
|
||||
VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP, ` + expiration + `)`
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
|
||||
defer cancel()
|
||||
res, err = db.ExecContext(ctx, stmt, req.Key, requestValue, isBinary, newEtag, req.Key)
|
||||
} else {
|
||||
|
|
@ -489,7 +389,7 @@ func (a *sqliteDBAccess) doSet(parentCtx context.Context, db querier, req *state
|
|||
key = ?
|
||||
AND etag = ?
|
||||
AND (expiration_time IS NULL OR expiration_time > CURRENT_TIMESTAMP)`
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), a.metadata.Timeout)
|
||||
defer cancel()
|
||||
res, err = db.ExecContext(ctx, stmt, requestValue, newEtag, isBinary, req.Key, *req.ETag)
|
||||
}
|
||||
|
|
@ -573,7 +473,7 @@ func (a *sqliteDBAccess) doDelete(parentCtx context.Context, db querier, req *st
|
|||
return fmt.Errorf("missing key in delete operation")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, a.metadata.Timeout)
|
||||
defer cancel()
|
||||
var result sql.Result
|
||||
if !req.HasETag() {
|
||||
|
|
|
|||
|
|
@ -678,7 +678,7 @@ func testInitConfiguration(t *testing.T) {
|
|||
{
|
||||
name: "Empty",
|
||||
props: map[string]string{},
|
||||
expectedErr: errMissingConnectionString,
|
||||
expectedErr: "missing connection string",
|
||||
},
|
||||
{
|
||||
name: "Valid connection string",
|
||||
|
|
@ -703,10 +703,10 @@ func testInitConfiguration(t *testing.T) {
|
|||
|
||||
err := p.Init(context.Background(), metadata)
|
||||
if tt.expectedErr == "" {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, err.Error(), tt.expectedErr)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,11 +14,10 @@ limitations under the License.
|
|||
package sqlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
authSqlite "github.com/dapr/components-contrib/internal/authentication/sqlite"
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
)
|
||||
|
|
@ -27,30 +26,18 @@ const (
|
|||
defaultTableName = "state"
|
||||
defaultMetadataTableName = "metadata"
|
||||
defaultCleanupInternal = time.Duration(0) // Disabled by default
|
||||
defaultTimeout = 20 * time.Second // Default timeout for database requests, in seconds
|
||||
defaultBusyTimeout = 2 * time.Second
|
||||
|
||||
errMissingConnectionString = "missing connection string"
|
||||
errInvalidIdentifier = "invalid identifier: %s" // specify identifier type, e.g. "table name"
|
||||
)
|
||||
|
||||
type sqliteMetadataStruct struct {
|
||||
ConnectionString string `json:"connectionString" mapstructure:"connectionString"`
|
||||
TableName string `json:"tableName" mapstructure:"tableName"`
|
||||
MetadataTableName string `json:"metadataTableName" mapstructure:"metadataTableName"`
|
||||
TimeoutInSeconds string `json:"timeoutInSeconds" mapstructure:"timeoutInSeconds"`
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" mapstructure:"cleanupInterval"`
|
||||
BusyTimeout time.Duration `json:"busyTimeout" mapstructure:"busyTimeout"`
|
||||
DisableWAL bool `json:"disableWAL" mapstructure:"disableWAL"` // Disable WAL journaling. You should not use WAL if the database is stored on a network filesystem (or data corruption may happen). This is ignored if the database is in-memory.
|
||||
authSqlite.SqliteAuthMetadata `mapstructure:",squash"`
|
||||
|
||||
// Deprecated properties, maintained for backwards-compatibility
|
||||
CleanupIntervalInSeconds string `json:"cleanupIntervalInSeconds" mapstructure:"cleanupIntervalInSeconds"`
|
||||
|
||||
// Internal properties
|
||||
timeout time.Duration
|
||||
TableName string `mapstructure:"tableName"`
|
||||
MetadataTableName string `mapstructure:"metadataTableName"`
|
||||
CleanupInterval time.Duration `mapstructure:"cleanupInterval" mapstructurealiases:"cleanupIntervalInSeconds"`
|
||||
}
|
||||
|
||||
func (m *sqliteMetadataStruct) InitWithMetadata(meta state.Metadata) error {
|
||||
// Reset the object
|
||||
m.reset()
|
||||
|
||||
// Decode the metadata
|
||||
|
|
@ -60,78 +47,25 @@ func (m *sqliteMetadataStruct) InitWithMetadata(meta state.Metadata) error {
|
|||
}
|
||||
|
||||
// Validate and sanitize input
|
||||
if m.ConnectionString == "" {
|
||||
return errors.New(errMissingConnectionString)
|
||||
err = m.SqliteAuthMetadata.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !validIdentifier(m.TableName) {
|
||||
return fmt.Errorf(errInvalidIdentifier, m.TableName)
|
||||
if !authSqlite.ValidIdentifier(m.TableName) {
|
||||
return fmt.Errorf("invalid identifier: %s", m.TableName)
|
||||
}
|
||||
|
||||
// Timeout
|
||||
if m.TimeoutInSeconds != "" {
|
||||
timeoutInSec, err := strconv.ParseInt(m.TimeoutInSeconds, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for 'timeoutInSeconds': %s", m.TimeoutInSeconds)
|
||||
}
|
||||
if timeoutInSec < 1 {
|
||||
return errors.New("invalid value for 'timeoutInSeconds': must be greater than 0")
|
||||
}
|
||||
|
||||
m.timeout = time.Duration(timeoutInSec) * time.Second
|
||||
if !authSqlite.ValidIdentifier(m.MetadataTableName) {
|
||||
return fmt.Errorf("invalid identifier: %s", m.MetadataTableName)
|
||||
}
|
||||
|
||||
// Legacy "CleanupIntervalInSeconds" property
|
||||
// Non-positive duration means never clean up expired data
|
||||
if v := meta.Properties["cleanupInterval"]; v == "" && m.CleanupIntervalInSeconds != "" {
|
||||
cleanupIntervalInSec, err := strconv.ParseInt(m.CleanupIntervalInSeconds, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for 'cleanupIntervalInSeconds': %s", m.CleanupIntervalInSeconds)
|
||||
}
|
||||
|
||||
// Non-positive value from meta means disable auto cleanup.
|
||||
if cleanupIntervalInSec > 0 {
|
||||
m.CleanupInterval = time.Duration(cleanupIntervalInSec) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
// Busy timeout
|
||||
// Truncate values to milliseconds. Values <= 0 do not set any timeout
|
||||
m.BusyTimeout = m.BusyTimeout.Truncate(time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset the object
|
||||
func (m *sqliteMetadataStruct) reset() {
|
||||
m.ConnectionString = ""
|
||||
m.SqliteAuthMetadata.Reset()
|
||||
|
||||
m.TableName = defaultTableName
|
||||
m.MetadataTableName = defaultMetadataTableName
|
||||
m.TimeoutInSeconds = ""
|
||||
m.CleanupInterval = defaultCleanupInternal
|
||||
m.BusyTimeout = defaultBusyTimeout
|
||||
m.DisableWAL = false
|
||||
|
||||
m.CleanupIntervalInSeconds = ""
|
||||
|
||||
m.timeout = defaultTimeout
|
||||
}
|
||||
|
||||
// Validates an identifier, such as table or DB name.
|
||||
func validIdentifier(v string) bool {
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Loop through the string as byte slice as we only care about ASCII characters
|
||||
b := []byte(v)
|
||||
for i := 0; i < len(b); i++ {
|
||||
if (b[i] >= '0' && b[i] <= '9') ||
|
||||
(b[i] >= 'a' && b[i] <= 'z') ||
|
||||
(b[i] >= 'A' && b[i] <= 'Z') ||
|
||||
b[i] == '_' {
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,103 @@
|
|||
/*
|
||||
Copyright 2023 The Dapr Authors
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
authSqlite "github.com/dapr/components-contrib/internal/authentication/sqlite"
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
)
|
||||
|
||||
func TestSqliteMetadata(t *testing.T) {
|
||||
stateMetadata := func(props map[string]string) state.Metadata {
|
||||
return state.Metadata{Base: metadata.Base{Properties: props}}
|
||||
}
|
||||
|
||||
t.Run("default options", func(t *testing.T) {
|
||||
md := &sqliteMetadataStruct{}
|
||||
err := md.InitWithMetadata(stateMetadata(map[string]string{
|
||||
"connectionString": "file:data.db",
|
||||
}))
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "file:data.db", md.ConnectionString)
|
||||
assert.Equal(t, defaultTableName, md.TableName)
|
||||
assert.Equal(t, defaultMetadataTableName, md.MetadataTableName)
|
||||
assert.Equal(t, authSqlite.DefaultTimeout, md.Timeout)
|
||||
assert.Equal(t, defaultCleanupInternal, md.CleanupInterval)
|
||||
assert.Equal(t, authSqlite.DefaultBusyTimeout, md.BusyTimeout)
|
||||
assert.False(t, md.DisableWAL)
|
||||
})
|
||||
|
||||
t.Run("empty connection string", func(t *testing.T) {
|
||||
md := &sqliteMetadataStruct{}
|
||||
err := md.InitWithMetadata(stateMetadata(map[string]string{}))
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "missing connection string")
|
||||
})
|
||||
|
||||
t.Run("invalid state table name", func(t *testing.T) {
|
||||
md := &sqliteMetadataStruct{}
|
||||
err := md.InitWithMetadata(stateMetadata(map[string]string{
|
||||
"connectionstring": "file:data.db",
|
||||
"tablename": "not.valid",
|
||||
}))
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid identifier")
|
||||
})
|
||||
|
||||
t.Run("invalid metadata table name", func(t *testing.T) {
|
||||
md := &sqliteMetadataStruct{}
|
||||
err := md.InitWithMetadata(stateMetadata(map[string]string{
|
||||
"connectionstring": "file:data.db",
|
||||
"metadatatablename": "not.valid",
|
||||
}))
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid identifier")
|
||||
})
|
||||
|
||||
t.Run("invalid timeout", func(t *testing.T) {
|
||||
md := &sqliteMetadataStruct{}
|
||||
err := md.InitWithMetadata(stateMetadata(map[string]string{
|
||||
"connectionString": "file:data.db",
|
||||
"timeout": "500ms",
|
||||
}))
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "timeout")
|
||||
})
|
||||
|
||||
t.Run("aliases", func(t *testing.T) {
|
||||
md := &sqliteMetadataStruct{}
|
||||
err := md.InitWithMetadata(stateMetadata(map[string]string{
|
||||
"url": "file:data.db",
|
||||
"timeoutinseconds": "1200",
|
||||
"cleanupintervalinseconds": "22",
|
||||
}))
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "file:data.db", md.ConnectionString)
|
||||
assert.Equal(t, 20*time.Minute, md.Timeout)
|
||||
assert.Equal(t, 22*time.Second, md.CleanupInterval)
|
||||
})
|
||||
}
|
||||
|
|
@ -48,7 +48,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file:test.db"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -64,7 +64,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "test.db"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -82,7 +82,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = ":memory:"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -103,7 +103,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file:test.db?_txlock=immediate"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -121,7 +121,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file:test.db?_txlock=deferred"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -141,7 +141,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file:test.db?_pragma=busy_timeout(50)"
|
||||
|
||||
_, err := db.getConnectionString()
|
||||
_, err := db.metadata.GetConnectionString(log)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "found forbidden option '_pragma=busy_timeout' in the connection string")
|
||||
})
|
||||
|
|
@ -150,7 +150,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file:test.db?_pragma=journal_mode(WAL)"
|
||||
|
||||
_, err := db.getConnectionString()
|
||||
_, err := db.metadata.GetConnectionString(log)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "found forbidden option '_pragma=journal_mode' in the connection string")
|
||||
})
|
||||
|
|
@ -162,7 +162,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.ConnectionString = "file:test.db"
|
||||
db.metadata.BusyTimeout = time.Second
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -179,7 +179,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.ConnectionString = "file:test.db"
|
||||
db.metadata.DisableWAL = false
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -195,7 +195,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.ConnectionString = "file:test.db"
|
||||
db.metadata.DisableWAL = true
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -210,7 +210,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file::memory:"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -226,7 +226,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file:test.db?mode=ro"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
@ -242,7 +242,7 @@ func TestGetConnectionString(t *testing.T) {
|
|||
db.metadata.reset()
|
||||
db.metadata.ConnectionString = "file:test.db?immutable=1"
|
||||
|
||||
connString, err := db.getConnectionString()
|
||||
connString, err := db.metadata.GetConnectionString(log)
|
||||
require.NoError(t, err)
|
||||
|
||||
values := url.Values{
|
||||
|
|
|
|||
Loading…
Reference in New Issue