From 2e7d5e7df678030e556d42d93d12f81712f93e87 Mon Sep 17 00:00:00 2001 From: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> Date: Tue, 22 Nov 2022 02:05:07 +0000 Subject: [PATCH] Added garbage collection Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com> --- state/postgresql/postgresdbaccess.go | 144 ++++++++++++++---- state/postgresql/postgresdbaccess_test.go | 98 +++++++++++- .../postgresql/postgresql_integration_test.go | 16 +- state/postgresql/postgresql_test.go | 2 +- 4 files changed, 221 insertions(+), 39 deletions(-) diff --git a/state/postgresql/postgresdbaccess.go b/state/postgresql/postgresdbaccess.go index 8e1652488..59f4f8af0 100644 --- a/state/postgresql/postgresdbaccess.go +++ b/state/postgresql/postgresdbaccess.go @@ -14,6 +14,7 @@ limitations under the License. package postgresql import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -35,18 +36,21 @@ import ( ) const ( - connectionStringKey = "connectionString" - errMissingConnectionString = "missing connection string" - defaultTableName = "state" + defaultTableName = "state" + cleanupIntervalKey = "cleanupIntervalInSeconds" + defaultCleanupInternal = 3600 // In seconds = 1 hour ) +var errMissingConnectionString = errors.New("missing connection string") + // postgresDBAccess implements dbaccess. type postgresDBAccess struct { - logger logger.Logger - metadata postgresMetadataStruct - db *sql.DB - connectionString string - tableName string + logger logger.Logger + metadata postgresMetadataStruct + cleanupInterval *time.Duration + db *sql.DB + ctx context.Context + cancel context.CancelFunc } // newPostgresDBAccess creates a new instance of postgresAccess. 
@@ -61,31 +65,24 @@ func newPostgresDBAccess(logger logger.Logger) *postgresDBAccess { type postgresMetadataStruct struct { ConnectionString string ConnectionMaxIdleTime time.Duration - TableName string + TableName string // Could be in the format "schema.table" or just "table" } // Init sets up Postgres connection and ensures that the state table exists. func (p *postgresDBAccess) Init(meta state.Metadata) error { p.logger.Debug("Initializing Postgres state store") - m := postgresMetadataStruct{ - TableName: defaultTableName, - } - err := metadata.DecodeMetadata(meta.Properties, &m) + + p.ctx, p.cancel = context.WithCancel(context.Background()) + + err := p.parseMetadata(meta) if err != nil { + p.logger.Errorf("Failed to parse metadata: %v", err) return err } - p.metadata = m - if m.ConnectionString == "" { - p.logger.Error("Missing Postgres connection string") - return errors.New(errMissingConnectionString) - } - p.connectionString = m.ConnectionString - - db, err := sql.Open("pgx", p.connectionString) + db, err := sql.Open("pgx", p.metadata.ConnectionString) if err != nil { p.logger.Error(err) - return err } @@ -96,16 +93,49 @@ func (p *postgresDBAccess) Init(meta state.Metadata) error { return pingErr } - p.db.SetConnMaxIdleTime(m.ConnectionMaxIdleTime) + p.db.SetConnMaxIdleTime(p.metadata.ConnectionMaxIdleTime) if err != nil { return err } - err = p.ensureStateTable(m.TableName) + err = p.ensureStateTable(p.metadata.TableName) if err != nil { return err } - p.tableName = m.TableName + + p.scheduleCleanupExpiredData() + + return nil +} + +func (p *postgresDBAccess) parseMetadata(meta state.Metadata) error { + m := postgresMetadataStruct{ + TableName: defaultTableName, + } + err := metadata.DecodeMetadata(meta.Properties, &m) + if err != nil { + return err + } + p.metadata = m + + if m.ConnectionString == "" { + return errMissingConnectionString + } + + s, ok := meta.Properties[cleanupIntervalKey] + if ok && s != "" { + cleanupIntervalInSec, err := 
strconv.ParseInt(s, 10, 0) + if err != nil { + return fmt.Errorf("invalid value for '%s': %s", cleanupIntervalKey, s) + } + + // Non-positive value from meta means disable auto cleanup. + if cleanupIntervalInSec > 0 { + p.cleanupInterval = ptr.Of(time.Duration(cleanupIntervalInSec) * time.Second) + } + } else { + p.cleanupInterval = ptr.Of(defaultCleanupInternal * time.Second) + } return nil } @@ -198,7 +228,7 @@ func (p *postgresDBAccess) Set(req *state.SetRequest) error { } else { queryExpiredate = "NULL" } - result, err = p.db.Exec(fmt.Sprintf(query, p.tableName, queryExpiredate), params...) + result, err = p.db.Exec(fmt.Sprintf(query, p.metadata.TableName, queryExpiredate), params...) if err != nil { if req.ETag != nil && *req.ETag != "" { @@ -259,7 +289,7 @@ func (p *postgresDBAccess) Get(req *state.GetRequest) (*state.GetResponse, error key = $1 AND expiredate IS NULL OR expiredate >= CURRENT_TIMESTAMP` err := p.db. - QueryRow(fmt.Sprintf(query, p.tableName), req.Key). + QueryRow(fmt.Sprintf(query, p.metadata.TableName), req.Key). Scan(&value, &isBinary, &etag) if err != nil { // If no rows exist, return an empty response, otherwise return the error. 
@@ -410,7 +440,7 @@ func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse,
 	q := &Query{
 		query:     "",
 		params:    []any{},
-		tableName: p.tableName,
+		tableName: p.metadata.TableName,
 	}
 	qbuilder := query.NewQueryBuilder(q)
 	if err := qbuilder.BuildQuery(&req.Query); err != nil {
@@ -427,8 +457,71 @@ func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse,
 	}, nil
 }
 
+// scheduleCleanupExpiredData starts a background goroutine that removes
+// expired rows once per cleanupInterval. It does nothing when cleanupInterval
+// is nil. The goroutine stops, and releases its ticker, when p.ctx is
+// canceled.
+func (p *postgresDBAccess) scheduleCleanupExpiredData() {
+	if p.cleanupInterval == nil {
+		return
+	}
+
+	p.logger.Infof("Schedule expired data clean up every %d seconds", int(p.cleanupInterval.Seconds()))
+
+	// Capture the context now so the goroutine does not read the p.ctx field
+	// concurrently with any later write to it.
+	ctx := p.ctx
+	ticker := time.NewTicker(*p.cleanupInterval)
+	go func() {
+		// Release the ticker's resources when the loop exits.
+		defer ticker.Stop()
+
+		for {
+			select {
+			case <-ticker.C:
+				p.cleanupTimeout()
+			case <-ctx.Done():
+				p.logger.Debug("Stopped background cleanup of expired data")
+				return
+			}
+		}
+	}()
+}
+
+// cleanupTimeout deletes every row whose expiredate is set and already in the
+// past. Errors are logged rather than returned; it is invoked from the
+// background cleanup goroutine, which has nobody to return them to.
+func (p *postgresDBAccess) cleanupTimeout() {
+	// Tie the statement to p.ctx so that canceling it aborts a cleanup that is
+	// still running; fall back to Background if the context was never set.
+	ctx := p.ctx
+	if ctx == nil {
+		ctx = context.Background()
+	}
+
+	// A single DELETE is atomic on its own, so no explicit transaction is used.
+	stmt := fmt.Sprintf(`DELETE FROM %s WHERE expiredate IS NOT NULL AND expiredate < CURRENT_TIMESTAMP`, p.metadata.TableName)
+	res, err := p.db.ExecContext(ctx, stmt)
+	if err != nil {
+		p.logger.Errorf("Error removing expired data: failed to execute query: %v", err)
+		return
+	}
+
+	cleaned, err := res.RowsAffected()
+	if err != nil {
+		p.logger.Errorf("Error removing expired data: failed to count affected rows: %v", err)
+		return
+	}
+
+	p.logger.Debugf("Removed %d expired rows", cleaned)
+}
+
 // Close implements io.Close.
func (p *postgresDBAccess) Close() error { + if p.cancel != nil { + p.cancel() + p.cancel = nil + } if p.db != nil { return p.db.Close() } diff --git a/state/postgresql/postgresdbaccess_test.go b/state/postgresql/postgresdbaccess_test.go index 82057f501..5ea420bad 100644 --- a/state/postgresql/postgresdbaccess_test.go +++ b/state/postgresql/postgresdbaccess_test.go @@ -17,12 +17,14 @@ package postgresql import ( "database/sql" "testing" + "time" "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - + "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/assert" ) type mocks struct { @@ -461,3 +463,95 @@ func mockDatabase(t *testing.T) (*mocks, error) { pgDba: dba, }, err } + +func TestParseMetadata(t *testing.T) { + t.Run("missing connection string", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{} + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.Error(t, err) + assert.ErrorIs(t, err, errMissingConnectionString) + }) + + t.Run("has connection string", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{ + "connectionString": "foo", + } + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.NoError(t, err) + }) + + t.Run("default table name", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{ + "connectionString": "foo", + } + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.NoError(t, err) + assert.Equal(t, p.metadata.TableName, defaultTableName) + }) + + t.Run("custom table name", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{ + "connectionString": "foo", + "tableName": "mytable", + } + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.NoError(t, 
err) + assert.Equal(t, p.metadata.TableName, "mytable") + }) + + t.Run("default cleanupIntervalInSeconds", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{ + "connectionString": "foo", + } + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.NoError(t, err) + _ = assert.NotNil(t, p.cleanupInterval) && + assert.Equal(t, *p.cleanupInterval, defaultCleanupInternal*time.Second) + }) + + t.Run("invalid cleanupIntervalInSeconds", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{ + "connectionString": "foo", + "cleanupIntervalInSeconds": "NaN", + } + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.Error(t, err) + }) + + t.Run("positive cleanupIntervalInSeconds", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{ + "connectionString": "foo", + "cleanupIntervalInSeconds": "42", + } + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.NoError(t, err) + _ = assert.NotNil(t, p.cleanupInterval) && + assert.Equal(t, *p.cleanupInterval, 42*time.Second) + }) + + t.Run("zero cleanupIntervalInSeconds", func(t *testing.T) { + p := &postgresDBAccess{} + props := map[string]string{ + "connectionString": "foo", + "cleanupIntervalInSeconds": "0", + } + + err := p.parseMetadata(state.Metadata{Base: metadata.Base{Properties: props}}) + assert.NoError(t, err) + assert.Nil(t, p.cleanupInterval) + }) +} diff --git a/state/postgresql/postgresql_integration_test.go b/state/postgresql/postgresql_integration_test.go index 6cd7f1f67..f504132f4 100644 --- a/state/postgresql/postgresql_integration_test.go +++ b/state/postgresql/postgresql_integration_test.go @@ -48,7 +48,7 @@ func TestPostgreSQLIntegration(t *testing.T) { }) metadata := state.Metadata{ - Base: metadata.Base{Properties: map[string]string{connectionStringKey: connectionString}}, + Base: metadata.Base{Properties: 
map[string]string{"connectionString": connectionString}}, } pgs := NewPostgreSQLStateStore(logger.NewLogger("test")).(*PostgreSQL) @@ -476,7 +476,7 @@ func testInitConfiguration(t *testing.T) { tests := []struct { name string props map[string]string - expectedErr string + expectedErr error }{ { name: "Empty", @@ -485,8 +485,8 @@ func testInitConfiguration(t *testing.T) { }, { name: "Valid connection string", - props: map[string]string{connectionStringKey: getConnectionString()}, - expectedErr: "", + props: map[string]string{"connectionString": getConnectionString()}, + expectedErr: nil, }, } @@ -500,11 +500,11 @@ func testInitConfiguration(t *testing.T) { } err := p.Init(metadata) - if tt.expectedErr == "" { - assert.Nil(t, err) + if tt.expectedErr == nil { + assert.NoError(t, err) } else { - assert.NotNil(t, err) - assert.Equal(t, err.Error(), tt.expectedErr) + assert.Error(t, err) + assert.ErrorIs(t, err, tt.expectedErr) } }) } diff --git a/state/postgresql/postgresql_test.go b/state/postgresql/postgresql_test.go index 99ab088ef..950632030 100644 --- a/state/postgresql/postgresql_test.go +++ b/state/postgresql/postgresql_test.go @@ -106,7 +106,7 @@ func createPostgreSQL(t *testing.T) *PostgreSQL { assert.NotNil(t, pgs) metadata := &state.Metadata{ - Base: metadata.Base{Properties: map[string]string{connectionStringKey: fakeConnectionString}}, + Base: metadata.Base{Properties: map[string]string{"connectionString": fakeConnectionString}}, } err := pgs.Init(*metadata)