diff --git a/bindings/postgres/postgres.go b/bindings/postgres/postgres.go new file mode 100644 index 000000000..50594679b --- /dev/null +++ b/bindings/postgres/postgres.go @@ -0,0 +1,165 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +// ------------------------------------------------------------ + +package postgres + +import ( + "context" + "encoding/json" + "strconv" + "time" + + "github.com/dapr/components-contrib/bindings" + "github.com/dapr/dapr/pkg/logger" + "github.com/pkg/errors" + + "github.com/jackc/pgx/v4/pgxpool" +) + +// List of operations. +const ( + execOperation bindings.OperationKind = "exec" + queryOperation bindings.OperationKind = "query" + closeOperation bindings.OperationKind = "close" + + connectionURLKey = "url" + commandSQLKey = "sql" +) + +// Postgres represents PostgreSQL output binding +type Postgres struct { + logger logger.Logger + db *pgxpool.Pool +} + +var _ = bindings.OutputBinding(&Postgres{}) + +// NewPostgres returns a new PostgreSQL output binding +func NewPostgres(logger logger.Logger) *Postgres { + return &Postgres{logger: logger} +} + +// Init initializes the PostgreSql binding +func (p *Postgres) Init(metadata bindings.Metadata) error { + url, ok := metadata.Properties[connectionURLKey] + if !ok || url == "" { + return errors.Errorf("required metadata not set: %s", connectionURLKey) + } + + poolConfig, err := pgxpool.ParseConfig(url) + if err != nil { + return errors.Wrap(err, "error opening DB connection") + } + + p.db, err = pgxpool.ConnectConfig(context.Background(), poolConfig) + if err != nil { + return errors.Wrap(err, "unable to ping the DB") + } + + return nil +} + +// Operations returns list of operations supported by PostgreSql binding +func (p *Postgres) Operations() []bindings.OperationKind { + return []bindings.OperationKind{ + execOperation, + queryOperation, + closeOperation, + } +} + +// Invoke handles all invoke operations +func (p *Postgres) Invoke(req *bindings.InvokeRequest) (resp *bindings.InvokeResponse, err error) { + if req == nil { + return nil, errors.Errorf("invoke request required") + } + + if req.Operation == closeOperation { + p.db.Close() + return nil, nil + } + + if req.Metadata == nil { + return nil, errors.Errorf("metadata required") + } + p.logger.Debugf("operation: %v", req.Operation) + + sql, ok := req.Metadata[commandSQLKey] + if !ok || sql == "" { + return nil, errors.Errorf("required metadata not set: %s", commandSQLKey) + } + + startTime := time.Now().UTC() + resp = &bindings.InvokeResponse{ + Metadata: map[string]string{ + "operation": string(req.Operation), + "sql": sql, + "start-time": startTime.Format(time.RFC3339Nano), + }, + } + + switch req.Operation { + case execOperation: + r, err := p.exec(sql) + if err != nil { + resp.Metadata["error"] = err.Error() + } + resp.Metadata["rows-affected"] = strconv.FormatInt(r, 10) // 0 if error + + case queryOperation: + d, err := p.query(sql) + if err != nil { + resp.Metadata["error"] = err.Error() + } + resp.Data = d + + default: + return nil, errors.Errorf( + "invalid operation type: %s. Expected %s, %s, or %s", + req.Operation, execOperation, queryOperation, closeOperation, + ) + } + + endTime := time.Now().UTC() + resp.Metadata["end-time"] = endTime.Format(time.RFC3339Nano) + resp.Metadata["duration"] = endTime.Sub(startTime).String() + + return resp, nil +} + +func (p *Postgres) query(sql string) (result []byte, err error) { + p.logger.Debugf("select: %s", sql) + + rows, err := p.db.Query(context.Background(), sql) + if err != nil { + return nil, errors.Wrapf(err, "error executing query: %s", sql) + } + + rs := make([]interface{}, 0) + for rows.Next() { + val, rowErr := rows.Values() + if rowErr != nil { + return nil, errors.Wrapf(rowErr, "error parsing result: %v", rows.Err()) + } + rs = append(rs, val) + } + + if result, err = json.Marshal(rs); err != nil { + err = errors.Wrap(err, "error serializing results") + } + return +} + +func (p *Postgres) exec(sql string) (result int64, err error) { + p.logger.Debugf("exec: %s", sql) + + res, err := p.db.Exec(context.Background(), sql) + if err != nil { + return 0, errors.Wrapf(err, "error executing query: %s", sql) + } + + result = res.RowsAffected() + return +} diff --git a/bindings/postgres/postgres_test.go b/bindings/postgres/postgres_test.go new file mode 100644 index 000000000..6be191cf9 --- /dev/null +++ b/bindings/postgres/postgres_test.go @@ -0,0 +1,107 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +// ------------------------------------------------------------ + +package postgres + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/dapr/components-contrib/bindings" + "github.com/dapr/dapr/pkg/logger" + "github.com/stretchr/testify/assert" +) + +const ( + testTableDDL = `CREATE TABLE IF NOT EXISTS foo ( + id bigint NOT NULL, + v1 character varying(50) NOT NULL, + ts TIMESTAMP)` + testInsert = "INSERT INTO foo (id, v1, ts) VALUES (%d, 'test-%d', '%v')" + testDelete = "DELETE FROM foo" + testUpdate = "UPDATE foo SET ts = '%v' WHERE id = %d" + testSelect = "SELECT * FROM foo WHERE id < 3" +) + +func TestOperations(t *testing.T) { + t.Parallel() + b := NewPostgres(logger.NewLogger("test")) + assert.NotNil(t, b) + l := b.Operations() + assert.Equal(t, 3, len(l)) +} + +// SETUP TESTS +// 1. `createdb daprtest` +// 2. `createuser daprtest` +// 3. `psql=# grant all privileges on database daprtest to daprtest;`` +// 4. `export POSTGRES_TEST_CONN_URL="postgres://daprtest@localhost:5432/daprtest?application_name=test&connect_timeout=5"`` +// 5. `go test -v -count=1 ./bindings/postgres -run ^TestPostgresIntegration` + +func TestPostgresIntegration(t *testing.T) { + url := os.Getenv("POSTGRES_TEST_CONN_URL") + if url == "" { + t.SkipNow() + } + + // live DB test + b := NewPostgres(logger.NewLogger("test")) + err := b.Init(bindings.Metadata{Properties: map[string]string{connectionURLKey: url}}) + assert.NoError(t, err) + + // create table + req := &bindings.InvokeRequest{ + Operation: execOperation, + Metadata: map[string]string{commandSQLKey: testTableDDL}, + } + res, err := b.Invoke(req) + assertResponse(t, res, err) + + // delete all previous records if any + req.Metadata[commandSQLKey] = testDelete + res, err = b.Invoke(req) + assertResponse(t, res, err) + + // insert recrods + for i := 0; i < 10; i++ { + req.Metadata[commandSQLKey] = fmt.Sprintf(testInsert, i, i, time.Now().Format(time.RFC3339)) + res, err = b.Invoke(req) + assertResponse(t, res, err) + } + + // update recrods + for i := 0; i < 10; i++ { + req.Metadata[commandSQLKey] = fmt.Sprintf(testUpdate, time.Now().Format(time.RFC3339), i) + res, err = b.Invoke(req) + assertResponse(t, res, err) + } + + // select records + req.Operation = queryOperation + req.Metadata[commandSQLKey] = testSelect + res, err = b.Invoke(req) + assertResponse(t, res, err) + t.Logf("result data: %v", string(res.Data)) + + // delete records + req.Operation = execOperation + req.Metadata[commandSQLKey] = testDelete + res, err = b.Invoke(req) + assertResponse(t, res, err) + + // close connection + req.Operation = closeOperation + _, err = b.Invoke(req) + assert.NoError(t, err) +} + +func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) { + assert.NoError(t, err) + assert.NotNil(t, res) + assert.NotNil(t, res.Metadata) + t.Logf("result meta: %v", res.Metadata) +} diff --git a/go.mod b/go.mod index 976140220..86610719c 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/influxdata/influxdb-client-go v1.4.0 github.com/jackc/pgx/v4 v4.6.0 github.com/json-iterator/go v1.1.8 + github.com/lib/pq v1.8.0 // indirect github.com/mitchellh/mapstructure v1.3.2 // indirect github.com/nats-io/go-nats v1.7.2 github.com/nats-io/nats-streaming-server v0.17.0 // indirect diff --git a/go.sum b/go.sum index f69aa05ad..77bbac5e6 100644 --- a/go.sum +++ b/go.sum @@ -513,6 +513,7 @@ github.com/jackc/pgx/v4 v4.6.0 h1:Fh0O9GdlG4gYpjpwOqjdEodJUQM9jzN3Hdv7PN0xmm0= github.com/jackc/pgx/v4 v4.6.0/go.mod h1:vPh43ZzxijXUVJ+t/EmXBtFmbFVO72cuneCT9oAlxAg= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.0 h1:musOWczZC/rSbqut475Vfcczg7jJsdUQf0D6oKPLgNU= github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03 h1:FUwcHNlEqkqLjLBdCp5PRlCFijNjvcYANOZXzCfXwCM= github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= @@ -588,6 +589,8 @@ github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.8.0 h1:9xohqzkUwzR4Ga4ivdTcawVS89YSDVxXMa3xJX3cGzg= +github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de/go.mod h1:zAbeS9B/r2mtpb6U+EI2rYA5OAXxsYw6wTamcNW+zcE= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=