Added garbage collection
Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
parent
c54286a60d
commit
2e7d5e7df6
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
@ -35,18 +36,21 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
connectionStringKey = "connectionString"
|
||||
errMissingConnectionString = "missing connection string"
|
||||
defaultTableName = "state"
|
||||
defaultTableName = "state"
|
||||
cleanupIntervalKey = "cleanupIntervalInSeconds"
|
||||
defaultCleanupInternal = 3600 // In seconds = 1 hour
|
||||
)
|
||||
|
||||
var errMissingConnectionString = errors.New("missing connection string")
|
||||
|
||||
// postgresDBAccess implements dbaccess.
|
||||
type postgresDBAccess struct {
|
||||
logger logger.Logger
|
||||
metadata postgresMetadataStruct
|
||||
db *sql.DB
|
||||
connectionString string
|
||||
tableName string
|
||||
logger logger.Logger
|
||||
metadata postgresMetadataStruct
|
||||
cleanupInterval *time.Duration
|
||||
db *sql.DB
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// newPostgresDBAccess creates a new instance of postgresAccess.
|
||||
|
@ -61,31 +65,24 @@ func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess {
|
|||
type postgresMetadataStruct struct {
|
||||
ConnectionString string
|
||||
ConnectionMaxIdleTime time.Duration
|
||||
TableName string
|
||||
TableName string // Could be in the format "schema.table" or just "table"
|
||||
}
|
||||
|
||||
// Init sets up Postgres connection and ensures that the state table exists.
|
||||
func (p *postgresDBAccess) Init(meta state.Metadata) error {
|
||||
p.logger.Debug("Initializing Postgres state store")
|
||||
m := postgresMetadataStruct{
|
||||
TableName: defaultTableName,
|
||||
}
|
||||
err := metadata.DecodeMetadata(meta.Properties, &m)
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
||||
|
||||
err := p.parseMetadata(meta)
|
||||
if err != nil {
|
||||
p.logger.Errorf("Failed to parse metadata: %v", err)
|
||||
return err
|
||||
}
|
||||
p.metadata = m
|
||||
|
||||
if m.ConnectionString == "" {
|
||||
p.logger.Error("Missing Postgres connection string")
|
||||
return errors.New(errMissingConnectionString)
|
||||
}
|
||||
p.connectionString = m.ConnectionString
|
||||
|
||||
db, err := sql.Open("pgx", p.connectionString)
|
||||
db, err := sql.Open("pgx", p.metadata.ConnectionString)
|
||||
if err != nil {
|
||||
p.logger.Error(err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -96,16 +93,49 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error {
|
|||
return pingErr
|
||||
}
|
||||
|
||||
p.db.SetConnMaxIdleTime(m.ConnectionMaxIdleTime)
|
||||
p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ensureStateTable(m.TableName)
|
||||
err = p.ensureStateTable(p.metadata.TableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.tableName = m.TableName
|
||||
|
||||
p.scheduleCleanupExpiredData()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *postgresDBAccess) parseMetadata(meta state.Metadata) error {
|
||||
m := postgresMetadataStruct{
|
||||
TableName: defaultTableName,
|
||||
}
|
||||
err := metadata.DecodeMetadata(meta.Properties, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.metadata = m
|
||||
|
||||
if m.ConnectionString == "" {
|
||||
return errMissingConnectionString
|
||||
}
|
||||
|
||||
s, ok := meta.Properties[cleanupIntervalKey]
|
||||
if ok && s != "" {
|
||||
cleanupIntervalInSec, err := strconv.ParseInt(s, 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for '%s': %s", cleanupIntervalKey, s)
|
||||
}
|
||||
|
||||
// Non-positive value from meta means disable auto cleanup.
|
||||
if cleanupIntervalInSec > 0 {
|
||||
p.cleanupInterval = ptr.Of(time.Duration(cleanupIntervalInSec) * time.Second)
|
||||
}
|
||||
} else {
|
||||
p.cleanupInterval = ptr.Of(defaultCleanupInternal * time.Second)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -198,7 +228,7 @@ func (p *postgresDBAccess) Set(req *state.SetRequest) error {
|
|||
} else {
|
||||
queryExpiredate = "NULL"
|
||||
}
|
||||
result, err = p.db.Exec(fmt.Sprintf(query, p.tableName, queryExpiredate), params...)
|
||||
result, err = p.db.Exec(fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...)
|
||||
|
||||
if err != nil {
|
||||
if req.ETag != nil && *req.ETag != "" {
|
||||
|
@ -259,7 +289,7 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error
|
|||
key = $1
|
||||
AND expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP`
|
||||
err := p.db.
|
||||
QueryRow(fmt.Sprintf(query, p.tableName), req.Key).
|
||||
QueryRow(fmt.Sprintf(query, p.metadata.TableName), req.Key).
|
||||
Scan(&value, &isBinary, &etag)
|
||||
if err != nil {
|
||||
// If no rows exist, return an empty response, otherwise return the error.
|
||||
|
@ -410,7 +440,7 @@ func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse,
|
|||
q := &Query{
|
||||
query: "",
|
||||
params: []any{},
|
||||
tableName: p.tableName,
|
||||
tableName: p.metadata.TableName,
|
||||
}
|
||||
qbuilder := query.NewQueryBuilder(q)
|
||||
if err := qbuilder.BuildQuery(&req.Query); err != nil {
|
||||
|
@ -427,8 +457,66 @@ func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse,
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (p *postgresDBAccess) scheduleCleanupExpiredData() {
|
||||
if p.cleanupInterval == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Infof("Schedule expired data clean up every %d seconds", int(p.cleanupInterval.Seconds()))
|
||||
|
||||
ticker := time.NewTicker(*p.cleanupInterval)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
p.cleanupTimeout()
|
||||
case <-p.ctx.Done():
|
||||
p.logger.Debug("Stopped background cleanup of expired data")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *postgresDBAccess) cleanupTimeout() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tx, err := p.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
p.logger.Errorf("Error removing expired data: failed to begin transaction: %v", err)
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
|
||||
res, err := tx.Exec(stmt)
|
||||
if err != nil {
|
||||
p.logger.Errorf("Error removing expired data: failed to execute query: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cleaned, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
p.logger.Errorf("Error removing expired data: failed to count affected rows: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
p.logger.Errorf("Error removing expired data: failed to commit transaction: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debugf("Removed %d expired rows", cleaned)
|
||||
}
|
||||
|
||||
// Close implements io.Close.
|
||||
func (p *postgresDBAccess) Close() error {
|
||||
if p.cancel != nil {
|
||||
p.cancel()
|
||||
p.cancel = nil
|
||||
}
|
||||
if p.db != nil {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
|
|
@ -17,12 +17,14 @@ package postgresql
|
|||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/dapr/components-contrib/metadata"
|
||||
"github.com/dapr/components-contrib/state"
|
||||
"github.com/dapr/kit/logger"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mocks struct {
|
||||
|
@ -461,3 +463,95 @@ func mockDatabase(t *testing.T) (*mocks, error) {
|
|||
pgDba: dba,
|
||||
}, err
|
||||
}
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
t.Run("missing connection string", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, errMissingConnectionString)
|
||||
})
|
||||
|
||||
t.Run("has connection string", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("default table name", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, p.metadata.TableName, defaultTableName)
|
||||
})
|
||||
|
||||
t.Run("custom table name", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"tableName": "mytable",
|
||||
}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, p.metadata.TableName, "mytable")
|
||||
})
|
||||
|
||||
t.Run("default cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
_ = assert.NotNil(t, p.cleanupInterval) &&
|
||||
assert.Equal(t, *p.cleanupInterval, defaultCleanupInternal*time.Second)
|
||||
})
|
||||
|
||||
t.Run("invalid cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"cleanupIntervalInSeconds": "NaN",
|
||||
}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("positive cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"cleanupIntervalInSeconds": "42",
|
||||
}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
_ = assert.NotNil(t, p.cleanupInterval) &&
|
||||
assert.Equal(t, *p.cleanupInterval, 42*time.Second)
|
||||
})
|
||||
|
||||
t.Run("zero cleanupIntervalInSeconds", func(t *testing.T) {
|
||||
p := &postgresDBAccess{}
|
||||
props := map[string]string{
|
||||
"connectionString": "foo",
|
||||
"cleanupIntervalInSeconds": "0",
|
||||
}
|
||||
|
||||
err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, p.cleanupInterval)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func TestPostgreSQLIntegration(t *testing.T) {
|
|||
})
|
||||
|
||||
metadata := state.Metadata{
|
||||
Base: metadata.Base{Properties: map[string]string{connectionStringKey: connectionString}},
|
||||
Base: metadata.Base{Properties: map[string]string{"connectionString": connectionString}},
|
||||
}
|
||||
|
||||
pgs := NewPostgreSQLStateStore(logger.NewLogger("test")).(*PostgreSQL)
|
||||
|
@ -476,7 +476,7 @@ func testInitConfiguration(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
props map[string]string
|
||||
expectedErr string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Empty",
|
||||
|
@ -485,8 +485,8 @@ func testInitConfiguration(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "Valid connection string",
|
||||
props: map[string]string{connectionStringKey: getConnectionString()},
|
||||
expectedErr: "",
|
||||
props: map[string]string{"connectionString": getConnectionString()},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -500,11 +500,11 @@ func testInitConfiguration(t *testing.T) {
|
|||
}
|
||||
|
||||
err := p.Init(metadata)
|
||||
if tt.expectedErr == "" {
|
||||
assert.Nil(t, err)
|
||||
if tt.expectedErr == nil {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, err.Error(), tt.expectedErr)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -106,7 +106,7 @@ func createPostgreSQL(t *testing.T) *PostgreSQL {
|
|||
assert.NotNil(t, pgs)
|
||||
|
||||
metadata := &state.Metadata{
|
||||
Base: metadata.Base{Properties: map[string]string{connectionStringKey: fakeConnectionString}},
|
||||
Base: metadata.Base{Properties: map[string]string{"connectionString": fakeConnectionString}},
|
||||
}
|
||||
|
||||
err := pgs.Init(*metadata)
|
||||
|
|
Loading…
Reference in New Issue