diff --git a/state/postgresql/dbaccess.go b/state/postgresql/dbaccess.go index f26393516..c9bbbd323 100644 --- a/state/postgresql/dbaccess.go +++ b/state/postgresql/dbaccess.go @@ -24,5 +24,6 @@ type dbAccess interface { Get(req *state.GetRequest) (*state.GetResponse, error) Delete(req *state.DeleteRequest) error ExecuteMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error + Query(req *state.QueryRequest) (*state.QueryResponse, error) Close() error // io.Closer } diff --git a/state/postgresql/postgresdbaccess.go b/state/postgresql/postgresdbaccess.go index 930021dff..0ad6b3089 100644 --- a/state/postgresql/postgresdbaccess.go +++ b/state/postgresql/postgresdbaccess.go @@ -23,6 +23,7 @@ import ( "github.com/agrea/ptr" "github.com/dapr/components-contrib/state" + "github.com/dapr/components-contrib/state/query" "github.com/dapr/components-contrib/state/utils" "github.com/dapr/kit/logger" @@ -289,6 +290,28 @@ func (p *postgresDBAccess) ExecuteMulti(sets []state.SetRequest, deletes []state return err } +// Query executes a query against store. +func (p *postgresDBAccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { + p.logger.Debug("Getting query value from PostgreSQL") + q := &Query{ + query: "", + params: []interface{}{}, + } + qbuilder := query.NewQueryBuilder(q) + if err := qbuilder.BuildQuery(&req.Query); err != nil { + return &state.QueryResponse{}, err + } + data, token, err := q.execute(p.logger, p.db) + if err != nil { + return &state.QueryResponse{}, err + } + + return &state.QueryResponse{ + Results: data, + Token: token, + }, nil +} + // Close implements io.Close. func (p *postgresDBAccess) Close() error { if p.db != nil { diff --git a/state/postgresql/postgresql.go b/state/postgresql/postgresql.go index e879f28b9..4e4af33cb 100644 --- a/state/postgresql/postgresql.go +++ b/state/postgresql/postgresql.go @@ -121,6 +121,11 @@ func (p *PostgreSQL) Multi(request *state.TransactionalStateRequest) error { return nil } +// Query executes a query against store. +func (p *PostgreSQL) Query(req *state.QueryRequest) (*state.QueryResponse, error) { + return p.dbaccess.Query(req) +} + // Close implements io.Closer. func (p *PostgreSQL) Close() error { if p.dbaccess != nil { diff --git a/state/postgresql/postgresql_query.go b/state/postgresql/postgresql_query.go new file mode 100644 index 000000000..eac2fda7f --- /dev/null +++ b/state/postgresql/postgresql_query.go @@ -0,0 +1,209 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package postgresql + +import ( + "database/sql" + "fmt" + "strconv" + "strings" + + "github.com/agrea/ptr" + + "github.com/dapr/components-contrib/state" + "github.com/dapr/components-contrib/state/query" + "github.com/dapr/kit/logger" +) + +type Query struct { + query string + params []interface{} + limit int + skip *int64 +} + +func (q *Query) VisitEQ(f *query.EQ) (string, error) { + return q.whereFieldEqual(f.Key, f.Val), nil +} + +func (q *Query) VisitIN(f *query.IN) (string, error) { + if len(f.Vals) == 0 { + return "", fmt.Errorf("empty IN operator for key %q", f.Key) + } + + str := "(" + str += q.whereFieldEqual(f.Key, f.Vals[0]) + + for _, v := range f.Vals[1:] { + str += " OR " + str += q.whereFieldEqual(f.Key, v) + } + str += ")" + return str, nil +} + +func (q *Query) visitFilters(op string, filters []query.Filter) (string, error) { + var ( + arr []string + str string + err error + ) + + for _, fil := range filters { + switch f := fil.(type) { + case *query.EQ: + if str, err = q.VisitEQ(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.IN: + if str, err = q.VisitIN(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.OR: + if str, err = q.VisitOR(f); err != nil { + return "", err + } + arr = append(arr, str) + case *query.AND: + if str, err = q.VisitAND(f); err != nil { + return "", err + } + arr = append(arr, str) + default: + return "", fmt.Errorf("unsupported filter type %#v", f) + } + } + + sep := fmt.Sprintf(" %s ", op) + + return fmt.Sprintf("(%s)", strings.Join(arr, sep)), nil +} + +func (q *Query) VisitAND(f *query.AND) (string, error) { + return q.visitFilters("AND", f.Filters) +} + +func (q *Query) VisitOR(f *query.OR) (string, error) { + return q.visitFilters("OR", f.Filters) +} + +func (q *Query) Finalize(filters string, qq *query.Query) error { + q.query = fmt.Sprintf("SELECT key, value, xmin as etag FROM %s", tableName) + + if filters != "" { + q.query += fmt.Sprintf(" WHERE %s", filters) + } + + if len(qq.Sort) > 0 { + q.query += " ORDER BY " + + for sortIndex, sortItem := range qq.Sort { + if sortIndex > 0 { + q.query += ", " + } + q.query += translateFieldToFilter(sortItem.Key) + if sortItem.Order != "" { + q.query += fmt.Sprintf(" %s", sortItem.Order) + } + } + } + + if qq.Page.Limit > 0 { + q.query += fmt.Sprintf(" LIMIT %d", qq.Page.Limit) + q.limit = qq.Page.Limit + } + + if len(qq.Page.Token) != 0 { + skip, err := strconv.ParseInt(qq.Page.Token, 10, 64) + if err != nil { + return err + } + q.query += fmt.Sprintf(" OFFSET %d", skip) + q.skip = &skip + } + + return nil +} + +func (q *Query) execute(logger logger.Logger, db *sql.DB) ([]state.QueryItem, string, error) { + rows, err := db.Query(q.query, q.params...) + if err != nil { + return nil, "", err + } + defer rows.Close() + + ret := []state.QueryItem{} + for rows.Next() { + var ( + key string + data []byte + etag int + ) + if err = rows.Scan(&key, &data, &etag); err != nil { + return nil, "", err + } + result := state.QueryItem{ + Key: key, + Data: data, + ETag: ptr.String(strconv.Itoa(etag)), + } + ret = append(ret, result) + } + + if err = rows.Err(); err != nil { + return nil, "", err + } + + var token string + if q.limit != 0 { + var skip int64 + if q.skip != nil { + skip = *q.skip + } + token = strconv.FormatInt(skip+int64(len(ret)), 10) + } + + return ret, token, nil +} + +func (q *Query) addParamValueAndReturnPosition(value interface{}) int { + q.params = append(q.params, fmt.Sprintf("%v", value)) + return len(q.params) +} + +func translateFieldToFilter(key string) string { + fieldParts := strings.Split(key, ".") + filterField := fieldParts[0] + fieldParts = fieldParts[1:] + + for fieldIndex, fieldPart := range fieldParts { + filterField += "->" + + if fieldIndex+1 == len(fieldParts) { + filterField += ">" + } + + filterField += fmt.Sprintf("'%s'", fieldPart) + } + + return filterField +} + +func (q *Query) whereFieldEqual(key string, value interface{}) string { + position := q.addParamValueAndReturnPosition(value) + filterField := translateFieldToFilter(key) + query := fmt.Sprintf("%s=$%v", filterField, position) + return query +} diff --git a/state/postgresql/postgresql_query_test.go b/state/postgresql/postgresql_query_test.go new file mode 100644 index 000000000..a1727f90a --- /dev/null +++ b/state/postgresql/postgresql_query_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2022 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package postgresql + +import ( + "encoding/json" + "io/ioutil" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/dapr/components-contrib/state/query" +) + +func TestPostgresqlQueryBuildQuery(t *testing.T) { + tests := []struct { + input string + query string + }{ + { + input: "../../tests/state/query/q1.json", + query: "SELECT key, value, xmin as etag FROM state LIMIT 2", + }, + { + input: "../../tests/state/query/q2.json", + query: "SELECT key, value, xmin as etag FROM state WHERE state=$1 LIMIT 2", + }, + { + input: "../../tests/state/query/q3.json", + query: "SELECT key, value, xmin as etag FROM state WHERE (person->>'org'=$1 AND (state=$2 OR state=$3)) ORDER BY state DESC, person->>'name'", + }, + { + input: "../../tests/state/query/q4.json", + query: "SELECT key, value, xmin as etag FROM state WHERE (person->>'org'=$1 OR (person->>'org'=$2 AND (state=$3 OR state=$4))) ORDER BY state DESC, person->>'name' LIMIT 2", + }, + { + input: "../../tests/state/query/q5.json", + query: "SELECT key, value, xmin as etag FROM state WHERE (value->'person'->>'org'=$1 AND (value->'person'->>'name'=$2 OR (value->>'state'=$3 OR value->>'state'=$4))) ORDER BY value->>'state' DESC, value->'person'->>'name' LIMIT 2", + }, + { + input: "../../tests/state/query/q6.json", + query: "SELECT key, value, xmin as etag FROM state WHERE value->>'state'=$1 LIMIT 2 OFFSET 2", + }, + } + for _, test := range tests { + data, err := ioutil.ReadFile(test.input) + assert.NoError(t, err) + var qq query.Query + err = json.Unmarshal(data, &qq) + assert.NoError(t, err) + + q := &Query{} + qbuilder := query.NewQueryBuilder(q) + err = qbuilder.BuildQuery(&qq) + assert.NoError(t, err) + assert.Equal(t, test.query, q.query) + } +} diff --git a/state/postgresql/postgresql_test.go b/state/postgresql/postgresql_test.go index 36cec53d7..461511dda 100644 --- a/state/postgresql/postgresql_test.go +++ b/state/postgresql/postgresql_test.go @@ -59,6 +59,10 @@ func (m *fakeDBaccess) ExecuteMulti(sets []state.SetRequest, deletes []state.Del return nil } +func (m *fakeDBaccess) Query(req *state.QueryRequest) (*state.QueryResponse, error) { + return nil, nil +} + func (m *fakeDBaccess) Close() error { return nil } diff --git a/tests/config/state/tests.yml b/tests/config/state/tests.yml index 73e9115a0..68c9a3e21 100644 --- a/tests/config/state/tests.yml +++ b/tests/config/state/tests.yml @@ -16,7 +16,7 @@ components: operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "first-write" ] - component: postgresql allOperations: false - operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag" ] + operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag", "query" ] - component: mysql allOperations: false operations: [ "set", "get", "delete", "bulkset", "bulkdelete", "transaction", "etag" ] diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 2e7622f62..3c88bf757 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -261,8 +261,13 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.NoError(t, err) assert.Equal(t, len(scenario.results), len(resp.Results)) for i := range scenario.results { + var expected, actual interface{} + err = json.Unmarshal(scenario.results[i].Data, &expected) + assert.NoError(t, err) + err = json.Unmarshal(resp.Results[i].Data, &actual) + assert.NoError(t, err) assert.Equal(t, scenario.results[i].Key, resp.Results[i].Key) - assert.Equal(t, string(scenario.results[i].Data), string(resp.Results[i].Data)) + assert.Equal(t, expected, actual) } } }) diff --git a/tests/state/query/q5.json b/tests/state/query/q5.json new file mode 100644 index 000000000..96471d42d --- /dev/null +++ b/tests/state/query/q5.json @@ -0,0 +1,37 @@ +{ + "filter": { + "AND": [ + { + "EQ": { + "value.person.org": "A" + } + }, + { + "OR": [ + { + "EQ": { + "value.person.name": "B" + } + }, + { + "IN": { + "value.state": ["CA", "WA"] + } + } + ] + } + ] + }, + "sort": [ + { + "key": "value.state", + "order": "DESC" + }, + { + "key": "value.person.name" + } + ], + "page": { + "limit": 2 + } +} diff --git a/tests/state/query/q6.json b/tests/state/query/q6.json new file mode 100644 index 000000000..1adbf4712 --- /dev/null +++ b/tests/state/query/q6.json @@ -0,0 +1,12 @@ + +{ + "filter": { + "EQ": { + "value.state": "CA" + } + }, + "page": { + "limit": 2, + "token": "2" + } +}