403 lines
9.8 KiB
Go
403 lines
9.8 KiB
Go
/*
|
|
Copyright 2021 The Dapr Authors
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package redis
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
rediscomponent "github.com/dapr/components-contrib/common/component/redis"
|
|
"github.com/dapr/components-contrib/state"
|
|
"github.com/dapr/components-contrib/state/query"
|
|
)
|
|
|
|
var ErrMultipleSortBy error = errors.New("multiple SORTBY steps are not allowed. Sort multiple fields in a single step")
|
|
|
|
type Query struct {
|
|
schemaName string
|
|
aliases map[string]string
|
|
query []interface{}
|
|
limit int
|
|
offset int64
|
|
}
|
|
|
|
func NewQuery(schemaName string, aliases map[string]string) *Query {
|
|
return &Query{
|
|
schemaName: schemaName,
|
|
aliases: aliases,
|
|
}
|
|
}
|
|
|
|
func (q *Query) getAlias(jsonPath string) (string, error) {
|
|
alias, ok := q.aliases[jsonPath]
|
|
if !ok {
|
|
return "", fmt.Errorf("JSON path %q is not indexed", jsonPath)
|
|
}
|
|
return alias, nil
|
|
}
|
|
|
|
func (q *Query) VisitEQ(f *query.EQ) (string, error) {
|
|
// string: @<key>:(<val>)
|
|
// numeric: @<key>:[<val> <val>]
|
|
alias, err := q.getAlias(f.Key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch v := f.Val.(type) {
|
|
case string:
|
|
return fmt.Sprintf("@%s:(%s)", alias, v), nil
|
|
case float64, float32:
|
|
return fmt.Sprintf("@%s:[%f %f]", alias, v, v), nil
|
|
default:
|
|
return fmt.Sprintf("@%s:[%d %d]", alias, v, v), nil
|
|
}
|
|
}
|
|
|
|
func (q *Query) VisitNEQ(f *query.NEQ) (string, error) {
|
|
// string: @<key>:(<val>)
|
|
// numeric: @<key>:[<val> <val>]
|
|
alias, err := q.getAlias(f.Key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch v := f.Val.(type) {
|
|
case string:
|
|
return fmt.Sprintf("@%s:(%s)", alias, v), nil
|
|
case float64, float32:
|
|
return fmt.Sprintf("@%s:[%f %f]", alias, v, v), nil
|
|
default:
|
|
return fmt.Sprintf("@%s:[%d %d]", alias, v, v), nil
|
|
}
|
|
}
|
|
|
|
func (q *Query) VisitGT(f *query.GT) (string, error) {
|
|
// numeric: @<key>:[(<val> +inf]
|
|
alias, err := q.getAlias(f.Key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch v := f.Val.(type) {
|
|
case string:
|
|
return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val)
|
|
case float64, float32:
|
|
return fmt.Sprintf("@%s:[(%f +inf]", alias, v), nil
|
|
default:
|
|
return fmt.Sprintf("@%s:[(%d +inf]", alias, v), nil
|
|
}
|
|
}
|
|
|
|
func (q *Query) VisitGTE(f *query.GTE) (string, error) {
|
|
// numeric: @<key>:[<val> +inf]
|
|
alias, err := q.getAlias(f.Key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch v := f.Val.(type) {
|
|
case string:
|
|
return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val)
|
|
case float64, float32:
|
|
return fmt.Sprintf("@%s:[%f +inf]", alias, v), nil
|
|
default:
|
|
return fmt.Sprintf("@%s:[%d +inf]", alias, v), nil
|
|
}
|
|
}
|
|
|
|
func (q *Query) VisitLT(f *query.LT) (string, error) {
|
|
// numeric: @<key>:[-inf <val>)]
|
|
alias, err := q.getAlias(f.Key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
switch v := f.Val.(type) {
|
|
case string:
|
|
return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val)
|
|
case float64, float32:
|
|
return fmt.Sprintf("@%s:[-inf (%f]", alias, v), nil
|
|
default:
|
|
return fmt.Sprintf("@%s:[-inf (%d]", alias, v), nil
|
|
}
|
|
}
|
|
|
|
func (q *Query) VisitLTE(f *query.LTE) (string, error) {
|
|
// numeric: @<key>:[-inf <val>]
|
|
alias, err := q.getAlias(f.Key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
switch v := f.Val.(type) {
|
|
case string:
|
|
return "", fmt.Errorf("unsupported type of value %s; string type not permitted", f.Val)
|
|
case float64, float32:
|
|
return fmt.Sprintf("@%s:[-inf %f]", alias, v), nil
|
|
default:
|
|
return fmt.Sprintf("@%s:[-inf %d]", alias, v), nil
|
|
}
|
|
}
|
|
|
|
func (q *Query) VisitIN(f *query.IN) (string, error) {
|
|
// string: @<key>:(<val1>|<val2>...)
|
|
// numeric: replace with OR
|
|
n := len(f.Vals)
|
|
if n < 2 {
|
|
return "", fmt.Errorf("too few values in IN operator for key %q", f.Key)
|
|
}
|
|
|
|
switch f.Vals[0].(type) {
|
|
case string:
|
|
alias, err := q.getAlias(f.Key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
vals := make([]string, n)
|
|
for i := 0; i < n; i++ {
|
|
vals[i] = f.Vals[i].(string)
|
|
}
|
|
str := fmt.Sprintf("@%s:(%s)", alias, strings.Join(vals, "|"))
|
|
|
|
return str, nil
|
|
|
|
default:
|
|
or := &query.OR{
|
|
Filters: make([]query.Filter, n),
|
|
}
|
|
for i := 0; i < n; i++ {
|
|
or.Filters[i] = &query.EQ{
|
|
Key: f.Key,
|
|
Val: f.Vals[i],
|
|
}
|
|
}
|
|
|
|
return q.VisitOR(or)
|
|
}
|
|
}
|
|
|
|
func (q *Query) visitFilters(op string, filters []query.Filter) (string, error) {
|
|
var (
|
|
arr []string
|
|
str string
|
|
err error
|
|
)
|
|
for _, fil := range filters {
|
|
switch f := fil.(type) {
|
|
case *query.EQ:
|
|
if str, err = q.VisitEQ(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, fmt.Sprintf("(%s)", str))
|
|
case *query.NEQ:
|
|
if str, err = q.VisitNEQ(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, fmt.Sprintf("-(%s)", str))
|
|
case *query.GT:
|
|
if str, err = q.VisitGT(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, fmt.Sprintf("(%s)", str))
|
|
case *query.GTE:
|
|
if str, err = q.VisitGTE(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, fmt.Sprintf("(%s)", str))
|
|
case *query.LT:
|
|
if str, err = q.VisitLT(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, fmt.Sprintf("(%s)", str))
|
|
case *query.LTE:
|
|
if str, err = q.VisitLTE(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, fmt.Sprintf("(%s)", str))
|
|
case *query.IN:
|
|
if str, err = q.VisitIN(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, fmt.Sprintf("(%s)", str))
|
|
case *query.OR:
|
|
if str, err = q.VisitOR(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, str)
|
|
case *query.AND:
|
|
if str, err = q.VisitAND(f); err != nil {
|
|
return "", err
|
|
}
|
|
arr = append(arr, str)
|
|
default:
|
|
return "", fmt.Errorf("unsupported filter type %#v", f)
|
|
}
|
|
}
|
|
|
|
return fmt.Sprintf("(%s)", strings.Join(arr, op)), nil
|
|
}
|
|
|
|
func (q *Query) VisitAND(f *query.AND) (string, error) {
|
|
// ( <expression1> <expression2> ... <expressionN> )
|
|
return q.visitFilters(" ", f.Filters)
|
|
}
|
|
|
|
func (q *Query) VisitOR(f *query.OR) (string, error) {
|
|
// ( <expression1> | <expression2> | ... | <expressionN> )
|
|
return q.visitFilters("|", f.Filters)
|
|
}
|
|
|
|
func (q *Query) Finalize(filters string, qq *query.Query) error {
|
|
if len(filters) == 0 {
|
|
filters = "*"
|
|
}
|
|
q.query = []interface{}{filters}
|
|
|
|
// sorting
|
|
if len(qq.Sort) > 0 {
|
|
if len(qq.Sort) != 1 {
|
|
return ErrMultipleSortBy
|
|
}
|
|
alias, err := q.getAlias(qq.Sort[0].Key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
q.query = append(q.query, "SORTBY", alias)
|
|
if qq.Sort[0].Order == query.DESC {
|
|
q.query = append(q.query, "DESC")
|
|
}
|
|
}
|
|
// pagination
|
|
if qq.Page.Limit > 0 {
|
|
q.limit = qq.Page.Limit
|
|
q.offset = 0
|
|
if len(qq.Page.Token) != 0 {
|
|
var err error
|
|
q.offset, err = strconv.ParseInt(qq.Page.Token, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
q.query = append(q.query, "LIMIT", qq.Page.Token, strconv.Itoa(q.limit))
|
|
} else {
|
|
q.offset = 0
|
|
q.query = append(q.query, "LIMIT", "0", strconv.Itoa(q.limit))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (q *Query) execute(ctx context.Context, client rediscomponent.RedisClient) ([]state.QueryItem, string, error) {
|
|
query := append(append([]interface{}{"FT.SEARCH", q.schemaName}, q.query...), "RETURN", "2", "$.data", "$.version")
|
|
ret, err := client.DoRead(ctx, query...)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
res, ok, err := parseQueryResponsePost28(ret)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
if !ok {
|
|
res, err = parseQueryResponsePre28(ret)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
}
|
|
|
|
// set next query token only if limit is specified
|
|
var token string
|
|
if q.limit > 0 && len(res) > 0 {
|
|
token = strconv.FormatInt(q.offset+int64(len(res)), 10)
|
|
}
|
|
|
|
return res, token, err
|
|
}
|
|
|
|
// parseQueryResponsePost28 parses the query Do response from redisearch 2.8+.
|
|
func parseQueryResponsePost28(ret any) ([]state.QueryItem, bool, error) {
|
|
aarr, ok := ret.(map[any]any)
|
|
if !ok {
|
|
return nil, false, nil
|
|
}
|
|
|
|
var res []state.QueryItem
|
|
arr := aarr["results"].([]any)
|
|
if len(arr) == 0 {
|
|
return nil, false, errors.New("invalid output")
|
|
}
|
|
for i := 0; i < len(arr); i++ {
|
|
inner, ok := arr[i].(map[any]any)
|
|
if !ok {
|
|
return nil, false, fmt.Errorf("invalid output")
|
|
}
|
|
exattr, ok := inner["extra_attributes"].(map[any]any)
|
|
if !ok {
|
|
return nil, false, fmt.Errorf("invalid output")
|
|
}
|
|
item := state.QueryItem{
|
|
Key: inner["id"].(string),
|
|
}
|
|
if data, ok := exattr["$.data"].(string); ok {
|
|
item.Data = []byte(data)
|
|
} else {
|
|
item.Error = fmt.Sprintf("%#v is not string", exattr["$.data"])
|
|
}
|
|
if etag, ok := exattr["$.version"].(string); ok {
|
|
item.ETag = &etag
|
|
}
|
|
res = append(res, item)
|
|
}
|
|
|
|
return res, true, nil
|
|
}
|
|
|
|
// parseQueryResponsePre28 parses the query Do response from redisearch 2.8-.
|
|
func parseQueryResponsePre28(ret any) ([]state.QueryItem, error) {
|
|
arr, ok := ret.([]any)
|
|
if !ok {
|
|
return nil, errors.New("invalid output")
|
|
}
|
|
|
|
// arr[0] = number of matching elements in DB (ignoring pagination
|
|
// arr[2n] = key
|
|
// arr[2n+1][0] = "$.data"
|
|
// arr[2n+1][1] = value
|
|
// arr[2n+1][2] = "$.version"
|
|
// arr[2n+1][3] = etag
|
|
if len(arr)%2 != 1 {
|
|
return nil, fmt.Errorf("invalid output")
|
|
}
|
|
|
|
var res []state.QueryItem
|
|
for i := 1; i < len(arr); i += 2 {
|
|
item := state.QueryItem{
|
|
Key: arr[i].(string),
|
|
}
|
|
if data, ok := arr[i+1].([]interface{}); ok && len(data) == 4 && data[0] == "$.data" && data[2] == "$.version" {
|
|
item.Data = []byte(data[1].(string))
|
|
etag := data[3].(string)
|
|
item.ETag = &etag
|
|
} else {
|
|
item.Error = fmt.Sprintf("%#v is not []interface{}", arr[i+1])
|
|
}
|
|
res = append(res, item)
|
|
}
|
|
|
|
return res, nil
|
|
}
|