diff --git a/backend/api/v2beta1/filter.proto b/backend/api/v2beta1/filter.proto index ac5a39bd8b..d03c6ca865 100644 --- a/backend/api/v2beta1/filter.proto +++ b/backend/api/v2beta1/filter.proto @@ -34,7 +34,7 @@ service DummyFilterService { // filter { // predicate { // key: "status" -// op: EQUALS +// operation: EQUALS // string_value: "Running" // } // } @@ -43,12 +43,12 @@ service DummyFilterService { // filter { // predicate { // key: "status" -// op: EQUALS +// operation: EQUALS // string_value: "Succeeded" // } // predicate { // key: "created_at" -// op: GREATER_THAN +// operation: GREATER_THAN // timestamp_value { // seconds: 1543651200 // } @@ -60,7 +60,7 @@ service DummyFilterService { // filter { // predicate { // key: "label" -// op: IN +// operation: IN // string_values { // value: 'label_1' // value: 'label_2' diff --git a/backend/api/v2beta1/go_client/filter.pb.go b/backend/api/v2beta1/go_client/filter.pb.go index 2b62b51972..6cfd69ab5b 100644 --- a/backend/api/v2beta1/go_client/filter.pb.go +++ b/backend/api/v2beta1/go_client/filter.pb.go @@ -126,7 +126,7 @@ func (Predicate_Operation) EnumDescriptor() ([]byte, []int) { // filter { // predicate { // key: "status" -// op: EQUALS +// operation: EQUALS // string_value: "Running" // } // } @@ -135,12 +135,12 @@ func (Predicate_Operation) EnumDescriptor() ([]byte, []int) { // filter { // predicate { // key: "status" -// op: EQUALS +// operation: EQUALS // string_value: "Succeeded" // } // predicate { // key: "created_at" -// op: GREATER_THAN +// operation: GREATER_THAN // timestamp_value { // seconds: 1543651200 // } @@ -152,7 +152,7 @@ func (Predicate_Operation) EnumDescriptor() ([]byte, []int) { // filter { // predicate { // key: "label" -// op: IN +// operation: IN // string_values { // value: 'label_1' // value: 'label_2' diff --git a/backend/api/v2beta1/python_http_client/docs/V2beta1Filter.md b/backend/api/v2beta1/python_http_client/docs/V2beta1Filter.md index f8a8a48334..ee42b44ea1 100644 --- a/backend/api/v2beta1/python_http_client/docs/V2beta1Filter.md +++ b/backend/api/v2beta1/python_http_client/docs/V2beta1Filter.md @@ -1,6 +1,6 @@ # V2beta1Filter -Filter is used to filter resources returned from a ListXXX request. Example filters: 1) Filter runs with status = 'Running' filter { predicate { key: \"status\" op: EQUALS string_value: \"Running\" } } 2) Filter runs that succeeded since Dec 1, 2018 filter { predicate { key: \"status\" op: EQUALS string_value: \"Succeeded\" } predicate { key: \"created_at\" op: GREATER_THAN timestamp_value { seconds: 1543651200 } } } 3) Filter runs with one of labels 'label_1' or 'label_2' filter { predicate { key: \"label\" op: IN string_values { value: 'label_1' value: 'label_2' } } } +Filter is used to filter resources returned from a ListXXX request. Example filters: 1) Filter runs with status = 'Running' filter { predicate { key: \"status\" operation: EQUALS string_value: \"Running\" } } 2) Filter runs that succeeded since Dec 1, 2018 filter { predicate { key: \"status\" operation: EQUALS string_value: \"Succeeded\" } predicate { key: \"created_at\" operation: GREATER_THAN timestamp_value { seconds: 1543651200 } } } 3) Filter runs with one of labels 'label_1' or 'label_2' filter { predicate { key: \"label\" operation: IN string_values { value: 'label_1' value: 'label_2' } } } ## Properties Name | Type | Description | Notes ------------ | ------------- | ------------- | ------------- diff --git a/backend/api/v2beta1/swagger/filter.swagger.json b/backend/api/v2beta1/swagger/filter.swagger.json index 8ddc83038e..7e02c29163 100644 --- a/backend/api/v2beta1/swagger/filter.swagger.json +++ b/backend/api/v2beta1/swagger/filter.swagger.json @@ -65,7 +65,7 @@ "description": "All predicates are AND-ed when this filter is applied." } }, - "description": "Filter is used to filter resources returned from a ListXXX request.\n\nExample filters:\n1) Filter runs with status = 'Running'\nfilter {\n predicate {\n key: \"status\"\n op: EQUALS\n string_value: \"Running\"\n }\n}\n\n2) Filter runs that succeeded since Dec 1, 2018\nfilter {\n predicate {\n key: \"status\"\n op: EQUALS\n string_value: \"Succeeded\"\n }\n predicate {\n key: \"created_at\"\n op: GREATER_THAN\n timestamp_value {\n seconds: 1543651200\n }\n }\n}\n\n3) Filter runs with one of labels 'label_1' or 'label_2'\n\nfilter {\n predicate {\n key: \"label\"\n op: IN\n string_values {\n value: 'label_1'\n value: 'label_2'\n }\n }\n}" + "description": "Filter is used to filter resources returned from a ListXXX request.\n\nExample filters:\n1) Filter runs with status = 'Running'\nfilter {\n predicate {\n key: \"status\"\n operation: EQUALS\n string_value: \"Running\"\n }\n}\n\n2) Filter runs that succeeded since Dec 1, 2018\nfilter {\n predicate {\n key: \"status\"\n operation: EQUALS\n string_value: \"Succeeded\"\n }\n predicate {\n key: \"created_at\"\n operation: GREATER_THAN\n timestamp_value {\n seconds: 1543651200\n }\n }\n}\n\n3) Filter runs with one of labels 'label_1' or 'label_2'\n\nfilter {\n predicate {\n key: \"label\"\n operation: IN\n string_values {\n value: 'label_1'\n value: 'label_2'\n }\n }\n}" }, "v2beta1Predicate": { "type": "object", diff --git a/backend/api/v2beta1/swagger/kfp_api_single_file.swagger.json b/backend/api/v2beta1/swagger/kfp_api_single_file.swagger.json index a78b7763da..75ebcc92fe 100644 --- a/backend/api/v2beta1/swagger/kfp_api_single_file.swagger.json +++ b/backend/api/v2beta1/swagger/kfp_api_single_file.swagger.json @@ -1570,7 +1570,7 @@ "description": "All predicates are AND-ed when this filter is applied." } }, - "description": "Filter is used to filter resources returned from a ListXXX request.\n\nExample filters:\n1) Filter runs with status = 'Running'\nfilter {\n predicate {\n key: \"status\"\n op: EQUALS\n string_value: \"Running\"\n }\n}\n\n2) Filter runs that succeeded since Dec 1, 2018\nfilter {\n predicate {\n key: \"status\"\n op: EQUALS\n string_value: \"Succeeded\"\n }\n predicate {\n key: \"created_at\"\n op: GREATER_THAN\n timestamp_value {\n seconds: 1543651200\n }\n }\n}\n\n3) Filter runs with one of labels 'label_1' or 'label_2'\n\nfilter {\n predicate {\n key: \"label\"\n op: IN\n string_values {\n value: 'label_1'\n value: 'label_2'\n }\n }\n}" + "description": "Filter is used to filter resources returned from a ListXXX request.\n\nExample filters:\n1) Filter runs with status = 'Running'\nfilter {\n predicate {\n key: \"status\"\n operation: EQUALS\n string_value: \"Running\"\n }\n}\n\n2) Filter runs that succeeded since Dec 1, 2018\nfilter {\n predicate {\n key: \"status\"\n operation: EQUALS\n string_value: \"Succeeded\"\n }\n predicate {\n key: \"created_at\"\n operation: GREATER_THAN\n timestamp_value {\n seconds: 1543651200\n }\n }\n}\n\n3) Filter runs with one of labels 'label_1' or 'label_2'\n\nfilter {\n predicate {\n key: \"label\"\n operation: IN\n string_values {\n value: 'label_1'\n value: 'label_2'\n }\n }\n}" }, "v2beta1Predicate": { "type": "object", diff --git a/backend/src/apiserver/filter/filter.go b/backend/src/apiserver/filter/filter.go index 09a57a5dfc..423b19ba10 100644 --- a/backend/src/apiserver/filter/filter.go +++ b/backend/src/apiserver/filter/filter.go @@ -21,17 +21,22 @@ import ( "fmt" "github.com/Masterminds/squirrel" - "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/ptypes" - api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/common/util" ) +// Internal representation of a predicate. +type Predicate struct { + operation string + key string + value interface{} +} + // Filter represents a filter that can be applied when querying an arbitrary API // resource. type Filter struct { - filterProto *api.Filter - eq map[string][]interface{} neq map[string][]interface{} gt map[string][]interface{} @@ -47,8 +52,6 @@ type Filter struct { // filterForMarshaling is a helper struct for marshaling Filter into JSON. This // is needed as we don't want to export the fields in Filter. type filterForMarshaling struct { - FilterProto string - EQ map[string][]interface{} NEQ map[string][]interface{} GT map[string][]interface{} @@ -63,21 +66,15 @@ type filterForMarshaling struct { // MarshalJSON implements JSON Marshaler for Filter. func (f *Filter) MarshalJSON() ([]byte, error) { - m := &jsonpb.Marshaler{} - s, err := m.MarshalToString(f.filterProto) - if err != nil { - return nil, util.Wrap(err, "Failed to marshal filter proto into a string") - } return json.Marshal(&filterForMarshaling{ - FilterProto: s, - EQ: f.eq, - NEQ: f.neq, - GT: f.gt, - GTE: f.gte, - LT: f.lt, - LTE: f.lte, - IN: f.in, - SUBSTRING: f.substring, + EQ: f.eq, + NEQ: f.neq, + GT: f.gt, + GTE: f.gte, + LT: f.lt, + LTE: f.lte, + IN: f.in, + SUBSTRING: f.substring, }) } @@ -89,12 +86,6 @@ func (f *Filter) UnmarshalJSON(b []byte) error { return err } - f.filterProto = &api.Filter{} - err = jsonpb.UnmarshalString(ffm.FilterProto, f.filterProto) - if err != nil { - return util.Wrap(err, "Failed to unmarshal filter proto") - } - f.eq = ffm.EQ f.neq = ffm.NEQ f.gt = ffm.GT @@ -108,23 +99,12 @@ func (f *Filter) UnmarshalJSON(b []byte) error { } // New creates a new Filter from parsing the API filter protocol buffer. -func New(filterProto *api.Filter) (*Filter, error) { - f := &Filter{ - filterProto: filterProto, - eq: make(map[string][]interface{}, 0), - neq: make(map[string][]interface{}, 0), - gt: make(map[string][]interface{}, 0), - gte: make(map[string][]interface{}, 0), - lt: make(map[string][]interface{}, 0), - lte: make(map[string][]interface{}, 0), - in: make(map[string][]interface{}, 0), - substring: make(map[string][]interface{}, 0), - } - - if err := f.parseFilterProto(); err != nil { +func New(filterProto interface{}) (*Filter, error) { + predicates, err := toPredicates(filterProto) + if err != nil { return nil, err } - return f, nil + return NewFromPredicate(predicates) } // NewWithKeyMap is like New, but takes an additional map and model name for mapping key names @@ -132,21 +112,99 @@ func New(filterProto *api.Filter) (*Filter, error) { // model. For example, if the API name of a field is "name", the model name is "pipelines", and // the equivalent column name is "Name", then filterProto with predicates against key "name" // will be parsed as if the key value was "pipelines.Name". -func NewWithKeyMap(filterProto *api.Filter, keyMap map[string]string, modelName string) (*Filter, error) { +func NewWithKeyMap(filterProto interface{}, keyMap map[string]string, modelName string) (*Filter, error) { // Fully qualify column name to avoid "ambiguous column name" error. var modelNamePrefix string if modelName != "" { modelNamePrefix = modelName + "." } - for _, pred := range filterProto.Predicates { - k, ok := keyMap[pred.Key] - if !ok { - return nil, util.NewInvalidInputError("no support for filtering on unrecognized field %q", pred.Key) - } - pred.Key = modelNamePrefix + k + predicates, err := toPredicates(filterProto) + if err != nil { + return nil, err } - return New(filterProto) + + for _, pred := range predicates { + k, ok := keyMap[pred.key] + if !ok { + return nil, util.NewInvalidInputError("no support for filtering on unrecognized field %q", pred.key) + } + pred.key = modelNamePrefix + k + } + return NewFromPredicate(predicates) +} + +// New creates a new Filter from parsed predicates. +func NewFromPredicate(predicates []*Predicate) (*Filter, error) { + if len(predicates) == 0 { + return nil, nil + } + + f := &Filter{ + eq: make(map[string][]interface{}, 0), + neq: make(map[string][]interface{}, 0), + gt: make(map[string][]interface{}, 0), + gte: make(map[string][]interface{}, 0), + lt: make(map[string][]interface{}, 0), + lte: make(map[string][]interface{}, 0), + in: make(map[string][]interface{}, 0), + substring: make(map[string][]interface{}, 0), + } + + if err := f.parsePredicates(predicates); err != nil { + return nil, err + } + return f, nil +} + +// Replaces and adds a prefix to the keys for an existing filter. +// This is useful when someone wants to extend the filter with a table name. +func (f *Filter) ReplaceKeys(keyMap map[string]string, prefix string) error { + if prefix != "" { + prefix = prefix + "." + } + if err := replaceMapKeys(f.eq, keyMap, prefix); err != nil { + return err + } + if err := replaceMapKeys(f.neq, keyMap, prefix); err != nil { + return err + } + if err := replaceMapKeys(f.gt, keyMap, prefix); err != nil { + return err + } + if err := replaceMapKeys(f.gte, keyMap, prefix); err != nil { + return err + } + if err := replaceMapKeys(f.lt, keyMap, prefix); err != nil { + return err + } + if err := replaceMapKeys(f.lte, keyMap, prefix); err != nil { + return err + } + if err := replaceMapKeys(f.in, keyMap, prefix); err != nil { + return err + } + if err := replaceMapKeys(f.substring, keyMap, prefix); err != nil { + return err + } + return nil +} + +// Replaces string keys in a map and adds a prefix. +func replaceMapKeys(m map[string][]interface{}, keyMap map[string]string, prefix string) error { + keys := make([]string, 0) + for k := range m { + keys = append(keys, k) + } + for _, k := range keys { + newKey, ok := keyMap[k] + if !ok { + return util.NewInvalidInputError("no support for filtering on unrecognized field %q", k) + } + m[prefix+newKey] = m[k] + delete(m, k) + } + return nil } // AddToSelect builds a WHERE clause from the Filter f, adds it to the supplied @@ -215,104 +273,199 @@ func (f *Filter) AddToSelect(sb squirrel.SelectBuilder) squirrel.SelectBuilder { return sb } -func checkPredicate(p *api.Predicate) error { - switch p.Op { - case api.Predicate_IN: - switch t := p.Value.(type) { - case *api.Predicate_IntValue, *api.Predicate_LongValue, *api.Predicate_StringValue, *api.Predicate_TimestampValue: +func checkPredicate(p *Predicate) error { + switch p.operation { + case apiv1beta1.Predicate_IN.String(), apiv2beta1.Predicate_IN.String(): + switch t := p.value.(type) { + case int32, int64, string: return util.NewInvalidInputError("cannot use IN operator with scalar type %T", t) } - - case api.Predicate_EQUALS, api.Predicate_NOT_EQUALS, api.Predicate_GREATER_THAN, api.Predicate_GREATER_THAN_EQUALS, api.Predicate_LESS_THAN, api.Predicate_LESS_THAN_EQUALS: - switch t := p.Value.(type) { - case *api.Predicate_IntValues, *api.Predicate_LongValues, *api.Predicate_StringValues: - return util.NewInvalidInputError("cannot use scalar operator %v on array type %T", p.Op, t) + case apiv1beta1.Predicate_EQUALS.String(), apiv1beta1.Predicate_NOT_EQUALS.String(), apiv1beta1.Predicate_GREATER_THAN.String(), apiv1beta1.Predicate_GREATER_THAN_EQUALS.String(), apiv1beta1.Predicate_LESS_THAN.String(), apiv1beta1.Predicate_LESS_THAN_EQUALS.String(), apiv2beta1.Predicate_EQUALS.String(), apiv2beta1.Predicate_NOT_EQUALS.String(), apiv2beta1.Predicate_GREATER_THAN.String(), apiv2beta1.Predicate_GREATER_THAN_EQUALS.String(), apiv2beta1.Predicate_LESS_THAN.String(), apiv2beta1.Predicate_LESS_THAN_EQUALS.String(): + switch t := p.value.(type) { + case []int32, []int64, []string: + return util.NewInvalidInputError("cannot use scalar operator %v on array type %T", p.operation, t) } - - case api.Predicate_IS_SUBSTRING: - switch t := p.Value.(type) { - case *api.Predicate_StringValue: + case apiv1beta1.Predicate_IS_SUBSTRING.String(), apiv2beta1.Predicate_IS_SUBSTRING.String(): + switch t := p.value.(type) { + case string: return nil default: - return util.NewInvalidInputError("cannot use non string value type %T with operator %v", p.Op, t) + return util.NewInvalidInputError("cannot use non string value type %T with operator %v", p.operation, t) } - default: - return util.NewInvalidInputError("invalid predicate operation: %v", p.Op) + return util.NewInvalidInputError("invalid predicate operation: %v", p.operation) } return nil } -func (f *Filter) parseFilterProto() error { - for _, pred := range f.filterProto.Predicates { +func (f *Filter) parsePredicates(preds []*Predicate) error { + for _, pred := range preds { if err := checkPredicate(pred); err != nil { return err } - - var m map[string][]interface{} - switch pred.Op { - case api.Predicate_EQUALS: - m = f.eq - case api.Predicate_NOT_EQUALS: - m = f.neq - case api.Predicate_GREATER_THAN: - m = f.gt - case api.Predicate_GREATER_THAN_EQUALS: - m = f.gte - case api.Predicate_LESS_THAN: - m = f.lt - case api.Predicate_LESS_THAN_EQUALS: - m = f.lte - case api.Predicate_IN: - m = f.in - case api.Predicate_IS_SUBSTRING: - m = f.substring + switch pred.operation { + case "EQUALS": + f.eq[pred.key] = append(f.eq[pred.key], pred.value) + case "NOT_EQUALS": + f.neq[pred.key] = append(f.neq[pred.key], pred.value) + case "GREATER_THAN": + f.gt[pred.key] = append(f.gt[pred.key], pred.value) + case "GREATER_THAN_EQUALS": + f.gte[pred.key] = append(f.gte[pred.key], pred.value) + case "LESS_THAN": + f.lt[pred.key] = append(f.lt[pred.key], pred.value) + case "LESS_THAN_EQUALS": + f.lte[pred.key] = append(f.lte[pred.key], pred.value) + case "IN": + f.in[pred.key] = append(f.in[pred.key], pred.value) + case "IS_SUBSTRING": + f.substring[pred.key] = append(f.substring[pred.key], pred.value) default: - return util.NewInvalidInputError("invalid predicate operation: %v", pred.Op) - } - - if err := addPredicateValue(m, pred); err != nil { - return err + return util.NewInvalidInputError("invalid predicate operation: %v", pred.operation) } } return nil } -func addPredicateValue(m map[string][]interface{}, p *api.Predicate) error { - switch t := p.Value.(type) { - case *api.Predicate_IntValue: - m[p.Key] = append(m[p.Key], p.GetIntValue()) - case *api.Predicate_LongValue: - m[p.Key] = append(m[p.Key], p.GetLongValue()) - case *api.Predicate_StringValue: - m[p.Key] = append(m[p.Key], p.GetStringValue()) - case *api.Predicate_TimestampValue: - ts, err := ptypes.Timestamp(p.GetTimestampValue()) - if err != nil { - return util.NewInvalidInputError("invalid timestamp: %v", err) +func toPredicates(filterProto interface{}) ([]*Predicate, error) { + if filterProto == nil { + return nil, nil + } + predicates := make([]*Predicate, 0) + switch filterProto := filterProto.(type) { + case *apiv2beta1.Filter: + for _, p := range filterProto.GetPredicates() { + if pred, err := toPredicate(p); err != nil { + return nil, err + } else { + predicates = append(predicates, pred) + } } - m[p.Key] = append(m[p.Key], ts.Unix()) + case *apiv1beta1.Filter: + for _, p := range filterProto.GetPredicates() { + if pred, err := toPredicate(p); err != nil { + return nil, err + } else { + predicates = append(predicates, pred) + } + } + default: + return nil, util.NewUnknownApiVersionError("Filter", filterProto) + } + return predicates, nil +} - case *api.Predicate_IntValues: - v := p.GetIntValues().GetValues() - m[p.Key] = append(m[p.Key], v) +func toPredicate(p interface{}) (*Predicate, error) { + if p == nil { + return nil, nil + } + operation := "" + key := "" + var value interface{} + switch p := p.(type) { + case *apiv2beta1.Predicate: + key = p.GetKey() + if temp, err := toOperation(p.GetOperation()); err != nil { + return nil, err + } else { + operation = temp + } + if temp, err := toValue(p.GetValue()); err != nil { + return nil, err + } else { + value = temp + } + case *apiv1beta1.Predicate: + key = p.GetKey() + if temp, err := toOperation(p.GetOp()); err != nil { + return nil, err + } else { + operation = temp + } + if temp, err := toValue(p.GetValue()); err != nil { + return nil, err + } else { + value = temp + } + default: + return nil, util.NewUnknownApiVersionError("Filter.Predicate", p) + } + if key == "" { + return nil, util.NewInvalidInputError("Predicate key cannot be empty for operation %v and value %v", operation, value) + } + return &Predicate{ + operation: operation, + key: key, + value: value, + }, nil +} - case *api.Predicate_LongValues: - v := p.GetLongValues().GetValues() - m[p.Key] = append(m[p.Key], v) +func toOperation(o interface{}) (string, error) { + switch o { + case apiv2beta1.Predicate_EQUALS, apiv1beta1.Predicate_EQUALS: + return "EQUALS", nil + case apiv2beta1.Predicate_NOT_EQUALS, apiv1beta1.Predicate_NOT_EQUALS: + return "NOT_EQUALS", nil + case apiv2beta1.Predicate_GREATER_THAN, apiv1beta1.Predicate_GREATER_THAN: + return "GREATER_THAN", nil + case apiv2beta1.Predicate_GREATER_THAN_EQUALS, apiv1beta1.Predicate_GREATER_THAN_EQUALS: + return "GREATER_THAN_EQUALS", nil + case apiv2beta1.Predicate_LESS_THAN, apiv1beta1.Predicate_LESS_THAN: + return "LESS_THAN", nil + case apiv2beta1.Predicate_LESS_THAN_EQUALS, apiv1beta1.Predicate_LESS_THAN_EQUALS: + return "LESS_THAN_EQUALS", nil + case apiv2beta1.Predicate_IN, apiv1beta1.Predicate_IN: + return "IN", nil + case apiv2beta1.Predicate_IS_SUBSTRING, apiv1beta1.Predicate_IS_SUBSTRING: + return "IS_SUBSTRING", nil + default: + return "", util.NewUnknownApiVersionError("Filter.Predicate.Operation", o) + } +} - case *api.Predicate_StringValues: - v := p.GetStringValues().GetValues() - m[p.Key] = append(m[p.Key], v) +func toValue(v interface{}) (interface{}, error) { + switch v := v.(type) { + case *apiv2beta1.Predicate_IntValue: + return v.IntValue, nil + case *apiv2beta1.Predicate_LongValue: + return v.LongValue, nil + case *apiv2beta1.Predicate_StringValue: + return v.StringValue, nil + case *apiv2beta1.Predicate_TimestampValue: + ts, err := ptypes.Timestamp(v.TimestampValue) + if err != nil { + return nil, util.NewInvalidInputError("invalid timestamp: %v", err) + } + return ts.Unix(), nil + case *apiv2beta1.Predicate_IntValues_: + return v.IntValues.GetValues(), nil + case *apiv2beta1.Predicate_StringValues_: + return v.StringValues.GetValues(), nil + case *apiv2beta1.Predicate_LongValues_: + return v.LongValues.GetValues(), nil - case nil: - return util.NewInvalidInputError("no value set for predicate on key %q", p.Key) + case *apiv1beta1.Predicate_IntValue: + return v.IntValue, nil + case *apiv1beta1.Predicate_LongValue: + return v.LongValue, nil + case *apiv1beta1.Predicate_StringValue: + return v.StringValue, nil + case *apiv1beta1.Predicate_TimestampValue: + ts, err := ptypes.Timestamp(v.TimestampValue) + if err != nil { + return nil, util.NewInvalidInputError("invalid timestamp: %v", err) + } + return ts.Unix(), nil + case *apiv1beta1.Predicate_IntValues: + return v.IntValues.GetValues(), nil + case *apiv1beta1.Predicate_StringValues: + return v.StringValues.GetValues(), nil + case *apiv1beta1.Predicate_LongValues: + return v.LongValues.GetValues(), nil default: - return util.NewInvalidInputError("unknown value type in Filter for predicate key %q: %T", p.Key, t) + return nil, util.NewUnknownApiVersionError("Filter.Predicate.Value", v) } - - return nil } diff --git a/backend/src/apiserver/filter/filter_test.go b/backend/src/apiserver/filter/filter_test.go index b1a74214a3..105472cac9 100644 --- a/backend/src/apiserver/filter/filter_test.go +++ b/backend/src/apiserver/filter/filter_test.go @@ -22,17 +22,16 @@ import ( "github.com/golang/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/model" + "github.com/stretchr/testify/assert" "google.golang.org/protobuf/testing/protocmp" ) -func TestValidNewFilters(t *testing.T) { +func TestValidNewFiltersV1(t *testing.T) { opts := []cmp.Option{ cmp.AllowUnexported(Filter{}), - cmp.FilterPath(func(p cmp.Path) bool { - return p.String() == "filterProto" - }, cmp.Ignore()), cmpopts.EquateEmpty(), } @@ -90,7 +89,7 @@ func TestValidNewFilters(t *testing.T) { } for _, test := range tests { - filterProto := &api.Filter{} + filterProto := &apiv1beta1.Filter{} if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) continue @@ -103,12 +102,82 @@ func TestValidNewFilters(t *testing.T) { } } -func TestValidNewFiltersWithKeyMap(t *testing.T) { +func TestValidNewFilters(t *testing.T) { + opts := []cmp.Option{ + cmp.AllowUnexported(Filter{}), + cmpopts.EquateEmpty(), + } + + tests := []struct { + protoStr string + want *Filter + }{ + { + `predicates { key: "status" operation: EQUALS string_value: "Running" }`, + &Filter{eq: map[string][]interface{}{"status": {"Running"}}}, + }, + { + `predicates { key: "status" operation: NOT_EQUALS string_value: "Running" }`, + &Filter{neq: map[string][]interface{}{"status": {"Running"}}}, + }, + { + `predicates { key: "total" operation: GREATER_THAN int_value: 10 }`, + &Filter{gt: map[string][]interface{}{"total": {int32(10)}}}, + }, + { + `predicates { key: "total" operation: GREATER_THAN_EQUALS long_value: 10 }`, + &Filter{gte: map[string][]interface{}{"total": {int64(10)}}}, + }, + { + `predicates { key: "total" operation: LESS_THAN timestamp_value { seconds: 10 }}`, + &Filter{lt: map[string][]interface{}{"total": {int64(10)}}}, + }, + { + `predicates { key: "total" operation: LESS_THAN_EQUALS timestamp_value { seconds: 10 }}`, + &Filter{lte: map[string][]interface{}{"total": {int64(10)}}}, + }, + { + `predicates { + key: "label" operation: IN + string_values { values: 'label_1' values: 'label_2' } }`, + &Filter{in: map[string][]interface{}{"label": {[]string{"label_1", "label_2"}}}}, + }, + { + `predicates { + key: "intvalues" operation: IN + int_values { values: 10 values: 20 } }`, + &Filter{in: map[string][]interface{}{"intvalues": {[]int32{10, 20}}}}, + }, + { + `predicates { + key: "longvalues" operation: IN + long_values { values: 100 values: 200 } }`, + &Filter{in: map[string][]interface{}{"longvalues": {[]int64{100, 200}}}}, + }, + { + `predicates { + key: "label" operation: IS_SUBSTRING string_value: "label_substring" }`, + &Filter{substring: map[string][]interface{}{"label": {"label_substring"}}}, + }, + } + + for _, test := range tests { + filterProto := &apiv2beta1.Filter{} + if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { + t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) + continue + } + + got, err := New(filterProto) + if !cmp.Equal(got, test.want, opts...) || err != nil { + t.Errorf("New(%+v) = %+v, %v\nWant %+v, nil", filterProto, got, err, test.want) + } + } +} + +func TestValidNewFiltersWithKeyMapV1(t *testing.T) { opts := []cmp.Option{ cmp.AllowUnexported(Filter{}), - cmp.FilterPath(func(p cmp.Path) bool { - return p.String() == "filterProto" - }, cmp.Ignore()), cmpopts.EquateEmpty(), } @@ -138,7 +207,7 @@ func TestValidNewFiltersWithKeyMap(t *testing.T) { } for _, test := range tests { - filterProto := &api.Filter{} + filterProto := &apiv1beta1.Filter{} if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) continue @@ -158,7 +227,59 @@ func TestValidNewFiltersWithKeyMap(t *testing.T) { } } -func TestInvalidFilters(t *testing.T) { +func TestValidNewFiltersWithKeyMap(t *testing.T) { + opts := []cmp.Option{ + cmp.AllowUnexported(Filter{}), + cmpopts.EquateEmpty(), + } + + tests := []struct { + protoStr string + want *Filter + }{ + { + `predicates { key: "name" operation: EQUALS string_value: "pipeline" }`, + &Filter{eq: map[string][]interface{}{"pipelines.Name": {"pipeline"}}}, + }, + { + `predicates { key: "name" operation: NOT_EQUALS string_value: "pipeline" }`, + &Filter{neq: map[string][]interface{}{"pipelines.Name": {"pipeline"}}}, + }, + { + `predicates { + key: "name" operation: IN + string_values { values: 'pipeline_1' values: 'pipeline_2' } }`, + &Filter{in: map[string][]interface{}{"pipelines.Name": {[]string{"pipeline_1", "pipeline_2"}}}}, + }, + { + `predicates { + key: "name" operation: IS_SUBSTRING string_value: "pipeline" }`, + &Filter{substring: map[string][]interface{}{"pipelines.Name": {"pipeline"}}}, + }, + } + + for _, test := range tests { + filterProto := &apiv2beta1.Filter{} + if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { + t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) + continue + } + + keyMap := map[string]string{ + "id": "UUID", + "name": "Name", + "created_at": "CreatedAtInSec", + "description": "Description", + } + modelName := "pipelines" + got, err := NewWithKeyMap(filterProto, keyMap, modelName) + if !cmp.Equal(got, test.want, opts...) || err != nil { + t.Errorf("New(%+v) = %+v, %v\nWant %+v, nil", filterProto, got, err, test.want) + } + } +} + +func TestInvalidFiltersV1(t *testing.T) { tests := []struct { protoStr string }{ @@ -223,7 +344,7 @@ func TestInvalidFilters(t *testing.T) { } for _, test := range tests { - filterProto := &api.Filter{} + filterProto := &apiv1beta1.Filter{} if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) continue @@ -236,7 +357,85 @@ func TestInvalidFilters(t *testing.T) { } } -func TestAddToSelect(t *testing.T) { +func TestInvalidFilters(t *testing.T) { + tests := []struct { + protoStr string + }{ + { + `predicates { key: "status" operation: EQUALS + string_values { values: "v1" values: "v2" }} `, + }, + { + `predicates { key: "status" operation: NOT_EQUALS + string_values { values: "v1" values: "v2"} }`, + }, + { + `predicates { key: "total" operation: GREATER_THAN + int_values { values: 10 values: 20} }`, + }, + { + `predicates { key: "total" operation: GREATER_THAN_EQUALS + long_values { values: 10 values: 20} }`, + }, + { + `predicates { key: "total" operation: LESS_THAN + int_values { values: 10 values: 20} }`, + }, + { + `predicates { key: "total" operation: LESS_THAN_EQUALS + long_values { values: 10 values: 20} }`, + }, + { + `predicates { key: "total" operation: IS_SUBSTRING + long_values { values: 10 values: 20} }`, + }, + { + `predicates { key: "total" operation: IS_SUBSTRING + int_values { values: 10 values: 20} }`, + }, + + { + `predicates { key: "total" operation: IN int_value: 10 }`, + }, + { + `predicates { key: "total" operation: IN long_value: 200}`, + }, + { + `predicates { key: "total" operation: IN string_value: "value"}`, + }, + { + `predicates { key: "total" operation: IN timestamp_value { seconds: 10 }}`, + }, + // Invalid predicate + { + `predicates { key: "total" timestamp_value { seconds: 10 }}`, + }, + // No value + { + `predicates { key: "total" operation: IN }`, + }, + // Bad timestamp + { + `predicates { key: "total" operation: LESS_THAN + timestamp_value { seconds: -100000000000 }}`, + }, + } + + for _, test := range tests { + filterProto := &apiv2beta1.Filter{} + if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { + t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) + continue + } + + got, err := New(filterProto) + if err == nil { + t.Errorf("New(%+v) = %+v, \nWant non-nil error ", filterProto, got) + } + } +} + +func TestAddToSelectV1(t *testing.T) { tests := []struct { protoStr string wantSQL string @@ -305,7 +504,96 @@ func TestAddToSelect(t *testing.T) { } for _, test := range tests { - filterProto := &api.Filter{} + filterProto := &apiv1beta1.Filter{} + if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { + t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) + continue + } + + filter, err := New(filterProto) + if err != nil { + t.Errorf("New(%+v) = %+v, %v\nWant nil error", filterProto, filter, err) + continue + } + + sb := squirrel.Select("mycolumn") + gotSQL, gotArgs, err := filter.AddToSelect(sb).ToSql() + if !cmp.Equal(gotSQL, test.wantSQL) || !cmp.Equal(gotArgs, test.wantArgs) || err != nil { + t.Errorf("Filter.AddToSelect(%+v).ToSql() =\nGot: %+v, %v, %v\nWant: %+v, %+v, ", filter, gotSQL, gotArgs, err, test.wantSQL, test.wantArgs) + } + } +} + +func TestAddToSelect(t *testing.T) { + tests := []struct { + protoStr string + wantSQL string + wantArgs []interface{} + }{ + { + `predicates { key: "status" operation: EQUALS string_value: "Running" }`, + "SELECT mycolumn WHERE status = ?", + []interface{}{"Running"}, + }, + { + `predicates { key: "status" operation: EQUALS string_value: "Running" } + predicates { key: "status" operation: EQUALS string_value: "Stopped" }`, + "SELECT mycolumn WHERE status = ? AND status = ?", + []interface{}{"Running", "Stopped"}, + }, + { + `predicates { key: "status" operation: EQUALS string_value: "Running" }`, + "SELECT mycolumn WHERE status = ?", + []interface{}{"Running"}, + }, + { + `predicates { key: "status" operation: EQUALS string_value: "Running" } + predicates { key: "total" operation: GREATER_THAN_EQUALS long_value: 100 }`, + "SELECT mycolumn WHERE status = ? AND total >= ?", + []interface{}{"Running", int64(100)}, + }, + { + `predicates { key: "status" operation: NOT_EQUALS string_value: "Running" } + predicates { key: "total" operation: GREATER_THAN long_value: 100 }`, + "SELECT mycolumn WHERE status <> ? AND total > ?", + []interface{}{"Running", int64(100)}, + }, + { + `predicates { key: "date" operation: LESS_THAN timestamp_value { seconds: 10 } } + predicates { key: "total" operation: LESS_THAN_EQUALS int_value: 100 }`, + "SELECT mycolumn WHERE date < ? AND total <= ?", + []interface{}{int64(10), int32(100)}, + }, + { + `predicates { key: "total" operation: IN int_values {values: 1 values: 2 values: 3} }`, + "SELECT mycolumn WHERE total IN (?,?,?)", + []interface{}{int32(1), int32(2), int32(3)}, + }, + { + `predicates { key: "runs" operation: IN long_values {values: 100 values: 200}}`, + "SELECT mycolumn WHERE runs IN (?,?)", + []interface{}{int64(100), int64(200)}, + }, + { + `predicates { key: "label" operation: IN string_values {values: "l1" values: "l2"}}`, + "SELECT mycolumn WHERE label IN (?,?)", + []interface{}{"l1", "l2"}, + }, + { + `predicates { key: "label" operation: IS_SUBSTRING string_value: "label_substring" }`, + "SELECT mycolumn WHERE label LIKE ?", + []interface{}{"%label_substring%"}, + }, + { + `predicates { key: "label" operation: IS_SUBSTRING string_value: "label_substring1" } + predicates { key: "label" operation: IS_SUBSTRING string_value: "label_substring2" }`, + "SELECT mycolumn WHERE label LIKE ? AND label LIKE ?", + []interface{}{"%label_substring1%", "%label_substring2%"}, + }, + } + + for _, test := range tests { + filterProto := &apiv2beta1.Filter{} if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil { t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err) continue @@ -327,18 +615,10 @@ func TestAddToSelect(t *testing.T) { func TestMarshalJSON(t *testing.T) { f := &Filter{ - filterProto: &api.Filter{ - Predicates: []*api.Predicate{ - { - Key: "Name", Op: api.Predicate_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "SomeName"}, - }, - }, - }, eq: map[string][]interface{}{"name": {"SomeName"}}, } - want := `{"FilterProto":"{\"predicates\":[{\"op\":\"EQUALS\",\"key\":\"Name\",\"stringValue\":\"SomeName\"}]}","EQ":{"name":["SomeName"]},"NEQ":null,"GT":null,"GTE":null,"LT":null,"LTE":null,"IN":null,"SUBSTRING":null}` + want := `{"EQ":{"name":["SomeName"]},"NEQ":null,"GT":null,"GTE":null,"LT":null,"LTE":null,"IN":null,"SUBSTRING":null}` got, err := json.Marshal(f) if err != nil || string(got) != want { @@ -347,17 +627,9 @@ func TestMarshalJSON(t *testing.T) { } func TestUnmarshalJSON(t *testing.T) { - in := `{"FilterProto":"{\"predicates\":[{\"op\":\"EQUALS\",\"key\":\"Name\",\"stringValue\":\"SomeName\"}]}","EQ":{"name":["SomeName"]},"NEQ":null,"GT":null,"GTE":null,"LT":null,"LTE":null,"IN":null,"SUBSTRING":null}` + in := `{"EQ":{"name":["SomeName"]},"NEQ":null,"GT":null,"GTE":null,"LT":null,"LTE":null,"IN":null,"SUBSTRING":null}` want := &Filter{ - filterProto: &api.Filter{ - Predicates: []*api.Predicate{ - { - Key: "Name", Op: api.Predicate_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "SomeName"}, - }, - }, - }, eq: map[string][]interface{}{"name": {"SomeName"}}, } @@ -369,25 +641,17 @@ func TestUnmarshalJSON(t *testing.T) { } func TestNewWithKeyMap(t *testing.T) { - filterProto := &api.Filter{ - Predicates: []*api.Predicate{ + filterProto := &apiv1beta1.Filter{ + Predicates: []*apiv1beta1.Predicate{ { Key: "finished_at", - Op: api.Predicate_GREATER_THAN, - Value: &api.Predicate_StringValue{StringValue: "SomeTime"}, + Op: apiv1beta1.Predicate_GREATER_THAN, + Value: &apiv1beta1.Predicate_StringValue{StringValue: "SomeTime"}, }, }, } want := &Filter{ - filterProto: &api.Filter{ - Predicates: []*api.Predicate{ - { - Key: "runs.FinishedAtInSec", Op: api.Predicate_GREATER_THAN, - Value: &api.Predicate_StringValue{StringValue: "SomeTime"}, - }, - }, - }, gt: map[string][]interface{}{"runs.FinishedAtInSec": {"SomeTime"}}, } @@ -397,3 +661,60 @@ func TestNewWithKeyMap(t *testing.T) { t.Errorf("NewWithKeyMap(%+v):\nGot: %+v, Error: %v\nWant:\n%+v, Error: nil\n", filterProto, got, err, want) } } + +func TestFilter_ReplaceKeys(t *testing.T) { + argEQ := make(map[string][]interface{}) + argEQ["namespace"] = append(argEQ["namespace"], "kubeflow") + argEQ["created_at"] = append(argEQ["created_at"], int64(100)) + + argIN := make(map[string][]interface{}) + argIN["name"] = append(argIN["name"], "MyPipeline") + argIN["name"] = append(argIN["name"], "Default") + + expectedEQ := make(map[string][]interface{}) + expectedEQ["pipelines.Namespace"] = append(expectedEQ["pipelines.Namespace"], "kubeflow") + expectedEQ["pipelines.CreatedAtInSec"] = append(expectedEQ["pipelines.CreatedAtInSec"], int64(100)) + + expectedIN := make(map[string][]interface{}) + expectedIN["pipelines.Name"] = append(expectedIN["pipelines.Name"], "MyPipeline") + expectedIN["pipelines.Name"] = append(expectedIN["pipelines.Name"], "Default") + + tests := []struct { + name string + filter *Filter + replaceMap map[string]string + prefix string + want *Filter + }{ + { + "valid - pipelines", + &Filter{ + eq: argEQ, + in: argIN, + }, + map[string]string{ + "id": "UUID", + "pipeline_id": "UUID", + "name": "Name", + "display_name": "Name", + "created_at": "CreatedAtInSec", + "description": "Description", + "namespace": "Namespace", + }, + "pipelines", + &Filter{ + eq: expectedEQ, + in: expectedIN, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.filter.ReplaceKeys(tt.replaceMap, tt.prefix) + assert.Nil(t, err) + if err != nil || !cmp.Equal(tt.filter, tt.want, cmpopts.EquateEmpty(), protocmp.Transform(), cmp.AllowUnexported(Filter{})) { + t.Errorf("ReplaceKeys: Got: %v, Error: %v Want: %v", tt.filter, err.Error(), tt.want) + } + }) + } +} diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index f92d90f971..0d8c2eb4ee 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -26,7 +26,6 @@ import ( "strings" sq "github.com/Masterminds/squirrel" - api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/filter" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -127,7 +126,7 @@ func NewOptionsFromToken(nextPageToken string, pageSize int) (*Options, error) { // NewOptions creates a new Options struct for the given listable. It uses // sorting and filtering criteria parsed from sortBy and filterProto // respectively. -func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api.Filter) (*Options, error) { +func NewOptions(listable Listable, pageSize int, sortBy string, filter *filter.Filter) (*Options, error) { pageSize, err := validatePageSize(pageSize) if err != nil { return nil, err @@ -163,14 +162,12 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api } // Filtering. - if filterProto != nil { - f, err := filter.NewWithKeyMap(filterProto, listable.APIToModelFieldMap(), listable.GetModelName()) - if err != nil { + if filter != nil { + if err := filter.ReplaceKeys(listable.APIToModelFieldMap(), listable.GetModelName()); err != nil { return nil, err } - token.Filter = f + token.Filter = filter } - return &Options{PageSize: pageSize, token: token}, nil } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 1a630e539e..1806e158ee 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -451,6 +451,7 @@ func TestNewOptions_ValidFilter(t *testing.T) { }, }, } + newFilter, _ := filter.New(protoFilter) protoFilterWithRightKeyNames := &api.Filter{ Predicates: []*api.Predicate{ @@ -467,7 +468,7 @@ func TestNewOptions_ValidFilter(t *testing.T) { t.Fatalf("failed to parse filter proto %+v: %v", protoFilter, err) } - got, err := NewOptions(&fakeListable{}, 10, "timestamp", protoFilter) + got, err := NewOptions(&fakeListable{}, 10, "timestamp", newFilter) want := &Options{ PageSize: 10, token: &token{ @@ -502,8 +503,9 @@ func TestNewOptions_InvalidFilter(t *testing.T) { }, }, } + newFilter, _ := filter.New(protoFilter) - got, err := NewOptions(&fakeListable{}, 10, "timestamp", protoFilter) + got, err := NewOptions(&fakeListable{}, 10, "timestamp", newFilter) if err == nil { t.Errorf("NewOptions(protoFilter=%+v) =\nGot: %+v, \nWant error", protoFilter, got) } @@ -519,6 +521,7 @@ func TestNewOptions_ModelFilter(t *testing.T) { }, }, } + newFilter, _ := filter.New(protoFilter) protoFilterWithRightKeyNames := &api.Filter{ Predicates: []*api.Predicate{ @@ -535,7 +538,7 @@ func TestNewOptions_ModelFilter(t *testing.T) { t.Fatalf("failed to parse filter proto %+v: %v", protoFilter, err) } - got, err := NewOptions(&model.Run{}, 10, "name", protoFilter) + got, err := NewOptions(&model.Run{}, 10, "name", newFilter) want := &Options{ PageSize: 10, token: &token{ @@ -1020,7 +1023,8 @@ func TestAddSortingToSelectWithPipelineVersionModel(t *testing.T) { CodeSourceUrl: "", } protoFilter := &api.Filter{} - listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + newFilter, _ := filter.New(protoFilter) + listableOptions, err := NewOptions(listable, 10, "name", newFilter) assert.Nil(t, err) sqlBuilder := sq.Select("*").From("pipeline_versions") sql, _, err := listableOptions.AddSortingToSelect(sqlBuilder).ToSql() @@ -1048,7 +1052,8 @@ func TestAddStatusFilterToSelectWithRunModel(t *testing.T) { Value: &api.Predicate_StringValue{StringValue: "Succeeded"}, }, } - listableOptions, err := NewOptions(listable, 10, "name", protoFilter) + newFilter, _ := filter.New(protoFilter) + listableOptions, err := NewOptions(listable, 10, "name", newFilter) assert.Nil(t, err) sqlBuilder := sq.Select("*").From("run_details") sql, args, err := listableOptions.AddFilterToSelect(sqlBuilder).ToSql() @@ -1064,7 +1069,8 @@ func TestAddStatusFilterToSelectWithRunModel(t *testing.T) { Value: &api.Predicate_StringValue{StringValue: "somevalue"}, }, } - listableOptions, err = NewOptions(listable, 10, "name", notEqualProtoFilter) + newNotEqualFilter, _ := filter.New(notEqualProtoFilter) + listableOptions, err = NewOptions(listable, 10, "name", newNotEqualFilter) assert.Nil(t, err) sqlBuilder = sq.Select("*").From("run_details") sql, args, err = listableOptions.AddFilterToSelect(sqlBuilder).ToSql() diff --git a/backend/src/apiserver/server/experiment_server.go b/backend/src/apiserver/server/experiment_server.go index b83b993bab..0b3746e8d7 100644 --- a/backend/src/apiserver/server/experiment_server.go +++ b/backend/src/apiserver/server/experiment_server.go @@ -21,6 +21,7 @@ import ( apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -196,7 +197,7 @@ func (s *ExperimentServer) GetExperiment(ctx context.Context, request *apiv2beta return apiExperiment, nil } -func (s *ExperimentServer) listExperiments(ctx context.Context, pageToken string, pageSize int32, sortBy string, filter string, namespace string) ([]*model.Experiment, int32, string, error) { +func (s *ExperimentServer) listExperiments(ctx context.Context, pageToken string, pageSize int32, sortBy string, opts *list.Options, namespace string) ([]*model.Experiment, int32, string, error) { namespace = s.resourceManager.ReplaceNamespace(namespace) resourceAttributes := &authorizationv1.ResourceAttributes{ Namespace: namespace, @@ -210,10 +211,6 @@ func (s *ExperimentServer) listExperiments(ctx context.Context, pageToken string ReferenceKey: &model.ReferenceKey{Type: model.NamespaceResourceType, ID: namespace}, } - opts, err := validatedListOptions(&model.Experiment{}, pageToken, int(pageSize), sortBy, filter) - if err != nil { - return nil, 0, "", util.Wrap(err, "Failed to create list options") - } experiments, totalSize, nextPageToken, err := s.resourceManager.ListExperiments(filterContext, opts) if err != nil { return nil, 0, "", util.Wrap(err, "List experiments failed") @@ -241,12 +238,17 @@ func (s *ExperimentServer) ListExperimentsV1(ctx context.Context, request *apiv1 } } + opts, err := validatedListOptions(&model.Experiment{}, request.GetPageToken(), int(request.GetPageSize()), request.GetSortBy(), request.GetFilter(), "v1beta1") + if err != nil { + return nil, util.Wrap(err, "Failed to create list options") + } + experiments, totalSize, nextPageToken, err := s.listExperiments( ctx, request.GetPageToken(), request.GetPageSize(), request.GetSortBy(), - request.GetFilter(), + opts, namespace, ) if err != nil { @@ -266,7 +268,12 @@ func (s *ExperimentServer) ListExperiments(ctx context.Context, request *apiv2be listExperimentsV1Requests.Inc() } - experiments, totalSize, nextPageToken, err := s.listExperiments(ctx, request.GetPageToken(), request.GetPageSize(), request.GetSortBy(), request.GetFilter(), request.GetNamespace()) + opts, err := validatedListOptions(&model.Experiment{}, request.GetPageToken(), int(request.GetPageSize()), request.GetSortBy(), request.GetFilter(), "v2beta1") + if err != nil { + return nil, util.Wrap(err, "Failed to create list options") + } + + experiments, totalSize, nextPageToken, err := s.listExperiments(ctx, request.GetPageToken(), request.GetPageSize(), request.GetSortBy(), opts, request.GetNamespace()) if err != nil { return nil, util.Wrap(err, "List experiments failed") } diff --git a/backend/src/apiserver/server/job_server.go b/backend/src/apiserver/server/job_server.go index 403f5da28d..1f4b95acae 100644 --- a/backend/src/apiserver/server/job_server.go +++ b/backend/src/apiserver/server/job_server.go @@ -21,6 +21,7 @@ import ( apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -173,7 +174,7 @@ func (s *JobServer) GetJob(ctx context.Context, request *apiv1beta1.GetJobReques return apiJob, nil } -func (s *JobServer) listJobs(ctx context.Context, pageToken string, pageSize int, sortBy string, filter string, namespace string, experimentId string) ([]*model.Job, int, string, error) { +func (s *JobServer) listJobs(ctx context.Context, pageToken string, pageSize int, sortBy string, opts *list.Options, namespace string, experimentId string) ([]*model.Job, int, string, error) { namespace = s.resourceManager.ReplaceNamespace(namespace) if experimentId != "" { ns, err := s.resourceManager.GetNamespaceFromExperimentId(experimentId) @@ -191,10 +192,6 @@ func (s *JobServer) listJobs(ctx context.Context, pageToken string, pageSize int return nil, 0, "", util.Wrapf(err, "Failed to list recurring runs due to authorization error. Check if you have permission to access namespace %s", namespace) } - opts, err := validatedListOptions(&model.Job{}, pageToken, pageSize, sortBy, filter) - if err != nil { - return nil, 0, "", util.Wrap(err, "Failed to create list options") - } filterContext := &model.FilterContext{ ReferenceKey: &model.ReferenceKey{Type: model.NamespaceResourceType, ID: namespace}, } @@ -233,7 +230,13 @@ func (s *JobServer) ListJobs(ctx context.Context, r *apiv1beta1.ListJobsRequest) experimentId = filterContext.ReferenceKey.ID } } - jobs, total_size, nextPageToken, err := s.listJobs(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), namespace, experimentId) + + opts, err := validatedListOptions(&model.Job{}, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), "v1beta1") + if err != nil { + return nil, util.Wrap(err, "Failed to list jobs due to error parsing the listing options") + } + + jobs, total_size, nextPageToken, err := s.listJobs(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), opts, namespace, experimentId) if err != nil { return nil, util.Wrap(err, "Failed to list jobs") } @@ -357,7 +360,12 @@ func (s *JobServer) ListRecurringRuns(ctx context.Context, r *apiv2beta1.ListRec listJobRequests.Inc() } - jobs, total_size, nextPageToken, err := s.listJobs(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), r.GetNamespace(), r.GetExperimentId()) + opts, err := validatedListOptions(&model.Job{}, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), "v2beta1") + if err != nil { + return nil, util.Wrap(err, "Failed to list recurring runs due to error parsing the listing options") + } + + jobs, total_size, nextPageToken, err := s.listJobs(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), opts, r.GetNamespace(), r.GetExperimentId()) if err != nil { return nil, util.Wrap(err, "Failed to list jobs") } diff --git a/backend/src/apiserver/server/list_request_util.go b/backend/src/apiserver/server/list_request_util.go index 196b8a9a8b..c5bc8c71cf 100644 --- a/backend/src/apiserver/server/list_request_util.go +++ b/backend/src/apiserver/server/list_request_util.go @@ -22,8 +22,10 @@ import ( "strings" "github.com/golang/protobuf/jsonpb" - api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "github.com/kubeflow/pipelines/backend/src/apiserver/filter" "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -34,7 +36,7 @@ const ( maxPageSize = 200 ) -func validateFilterV1(referenceKey *api.ResourceKey) (*model.FilterContext, error) { +func validateFilterV1(referenceKey *apiv1beta1.ResourceKey) (*model.FilterContext, error) { filterContext := &model.FilterContext{} if referenceKey != nil { refType, err := toModelResourceTypeV1(referenceKey.Type) @@ -135,39 +137,50 @@ func deserializePageToken(pageToken string) (*common.Token, error) { // parseAPIFilter attempts to decode a url-encoded JSON-stringified api // filter object. An empty string is considered valid input, and equivalent to // the nil filter, which trivially does nothing. -func parseAPIFilter(encoded string) (*api.Filter, error) { +func parseAPIFilter(encoded string, apiVersion string) (interface{}, error) { if encoded == "" { return nil, nil } - - errorF := func(err error) (*api.Filter, error) { - return nil, util.NewInvalidInputError("failed to parse valid filter from %q: %v", encoded, err) - } - decoded, err := url.QueryUnescape(encoded) if err != nil { - return errorF(err) + return nil, util.NewInvalidInputError("failed to parse valid filter from %q: %v", encoded, err) } - - f := &api.Filter{} - if err := jsonpb.UnmarshalString(decoded, f); err != nil { - return errorF(err) + switch apiVersion { + case "v2beta1": + f := &apiv2beta1.Filter{} + if err := jsonpb.UnmarshalString(decoded, f); err != nil { + return nil, util.NewInvalidInputError("failed to parse valid filter from %q: %v", encoded, err) + } + return f, nil + case "v1beta1": + f := &apiv1beta1.Filter{} + if err := jsonpb.UnmarshalString(decoded, f); err != nil { + return nil, util.NewInvalidInputError("failed to parse valid filter from %q: %v", encoded, err) + } + return f, nil + default: + return nil, util.NewUnknownApiVersionError("filter "+apiVersion, encoded) } - return f, nil } -func validatedListOptions(listable list.Listable, pageToken string, pageSize int, sortBy string, filterSpec string) (*list.Options, error) { +// Validates list options for a given resource and listing parameters. +// apiVersion cat be set to "v1beta1" or "v2beta1". Depending on the value, +// the corresponding API filter message will be used when parsing filterSpec. +func validatedListOptions(listable list.Listable, pageToken string, pageSize int, sortBy string, filterSpec string, apiVersion string) (*list.Options, error) { defaultOpts := func() (*list.Options, error) { if listable == nil { return nil, util.NewInvalidInputError("Please specify a valid type to list. E.g., list runs or list jobs") } - - f, err := parseAPIFilter(filterSpec) + f, err := parseAPIFilter(filterSpec, apiVersion) + if err != nil { + return nil, err + } + newFilter, err := filter.New(f) if err != nil { return nil, err } - return list.NewOptions(listable, pageSize, sortBy, f) + return list.NewOptions(listable, pageSize, sortBy, newFilter) } if pageToken == "" { diff --git a/backend/src/apiserver/server/list_request_util_test.go b/backend/src/apiserver/server/list_request_util_test.go index 792b372c9d..98efaaf7a6 100644 --- a/backend/src/apiserver/server/list_request_util_test.go +++ b/backend/src/apiserver/server/list_request_util_test.go @@ -21,7 +21,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" @@ -48,14 +49,14 @@ func getFakeModelToken() string { } func TestValidateFilterV1(t *testing.T) { - referenceKey := &api.ResourceKey{Type: api.ResourceType_EXPERIMENT, Id: "123"} + referenceKey := &apiv1beta1.ResourceKey{Type: apiv1beta1.ResourceType_EXPERIMENT, Id: "123"} ctx, err := validateFilterV1(referenceKey) assert.Nil(t, err) assert.Equal(t, &model.FilterContext{ReferenceKey: &model.ReferenceKey{Type: model.ExperimentResourceType, ID: "123"}}, ctx) } func TestValidateFilterV1_ToModelResourceTypeFailed(t *testing.T) { - referenceKey := &api.ResourceKey{Type: api.ResourceType_UNKNOWN_RESOURCE_TYPE, Id: "123"} + referenceKey := &apiv1beta1.ResourceKey{Type: apiv1beta1.ResourceType_UNKNOWN_RESOURCE_TYPE, Id: "123"} _, err := validateFilterV1(referenceKey) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Unrecognized resource reference type") @@ -185,20 +186,26 @@ func TestParseSortByQueryString_StringTooLong(t *testing.T) { } func TestParseAPIFilter_EmptyStringYieldsNilFilter(t *testing.T) { - f, err := parseAPIFilter("") + f, err := parseAPIFilter("", "v1beta1") + assert.Nil(t, err) + assert.Nil(t, f) + f, err = parseAPIFilter("", "v2beta1") assert.Nil(t, err) assert.Nil(t, f) } func TestParseAPIFilter_InvalidStringYieldsError(t *testing.T) { - f, err := parseAPIFilter("lkjlkjlkj") + f, err := parseAPIFilter("lkjlkjlkj", "v1beta1") + assert.NotNil(t, err) + assert.Nil(t, f) + f, err = parseAPIFilter("lkjlkjlkj", "v2beta1") assert.NotNil(t, err) assert.Nil(t, f) } -func TestParseAPIFilter_DecodesEncodedString(t *testing.T) { +func TestParseAPIFilter_DecodesEncodedStringV1(t *testing.T) { // in was generated by calling JSON.stringify followed by encodeURIComponent in // the browser on the following JSON string: // {"predicates":[{"op":"EQUALS","key":"testkey","stringValue":"testvalue"}]} @@ -206,16 +213,40 @@ func TestParseAPIFilter_DecodesEncodedString(t *testing.T) { in := "%7B%22predicates%22%3A%5B%7B%22op%22%3A%22EQUALS%22%2C%22key%22%3A%22testkey%22%2C%22stringValue%22%3A%22testvalue%22%7D%5D%7D" // The above should correspond the following filter: - want := &api.Filter{ - Predicates: []*api.Predicate{ + want := &apiv1beta1.Filter{ + Predicates: []*apiv1beta1.Predicate{ { - Key: "testkey", Op: api.Predicate_EQUALS, - Value: &api.Predicate_StringValue{StringValue: "testvalue"}, + Key: "testkey", Op: apiv1beta1.Predicate_EQUALS, + Value: &apiv1beta1.Predicate_StringValue{StringValue: "testvalue"}, }, }, } - got, err := parseAPIFilter(in) + got, err := parseAPIFilter(in, "v1beta1") + if !cmp.Equal(got, want, cmpopts.EquateEmpty(), protocmp.Transform()) || err != nil { + t.Errorf("parseAPIString(%q) =\nGot %+v, %v\n Want %+v, \nDiff: %s", + in, got, err, want, cmp.Diff(want, got)) + } +} + +func TestParseAPIFilter_DecodesEncodedString(t *testing.T) { + // in was generated by calling JSON.stringify followed by encodeURIComponent in + // the browser on the following JSON string: + // {"predicates":[{"operation":"EQUALS","key":"testkey","stringValue":"testvalue"}]} + + in := "%7B%22predicates%22%3A%5B%7B%22operation%22%3A%22EQUALS%22%2C%22key%22%3A%22testkey%22%2C%22stringValue%22%3A%22testvalue%22%7D%5D%7D" + + // The above should correspond the following filter: + want := &apiv2beta1.Filter{ + Predicates: []*apiv2beta1.Predicate{ + { + Key: "testkey", Operation: apiv2beta1.Predicate_EQUALS, + Value: &apiv2beta1.Predicate_StringValue{StringValue: "testvalue"}, + }, + }, + } + + got, err := parseAPIFilter(in, "v2beta1") if !cmp.Equal(got, want, cmpopts.EquateEmpty(), protocmp.Transform()) || err != nil { t.Errorf("parseAPIString(%q) =\nGot %+v, %v\n Want %+v, \nDiff: %s", in, got, err, want, cmp.Diff(want, got)) @@ -289,12 +320,22 @@ func TestValidatedListOptions_Errors(t *testing.T) { t.Fatalf("opt.NextPageToken() = _, %+v; Want nil error", err) } - _, err = validatedListOptions(&fakeListable{}, npt, 10, "name asc", "") + _, err = validatedListOptions(&fakeListable{}, npt, 10, "name asc", "", "v1beta1") if err != nil { t.Fatalf("validatedListOptions(fakeListable, 10, \"name asc\") = _, %+v; Want nil error", err) } - _, err = validatedListOptions(&fakeListable{}, npt, 10, "name desc", "") + _, err = validatedListOptions(&fakeListable{}, npt, 10, "name asc", "", "v2beta1") + if err != nil { + t.Fatalf("validatedListOptions(fakeListable, 10, \"name asc\") = _, %+v; Want nil error", err) + } + + _, err = validatedListOptions(&fakeListable{}, npt, 10, "name desc", "", "v1beta1") + if err == nil { + t.Fatalf("validatedListOptions(fakeListable, 10, \"name desc\") = _, %+v; Want error", err) + } + + _, err = validatedListOptions(&fakeListable{}, npt, 10, "name desc", "", "v2beta1") if err == nil { t.Fatalf("validatedListOptions(fakeListable, 10, \"name desc\") = _, %+v; Want error", err) } diff --git a/backend/src/apiserver/server/pipeline_server.go b/backend/src/apiserver/server/pipeline_server.go index fe2495725b..850b1c7bfa 100644 --- a/backend/src/apiserver/server/pipeline_server.go +++ b/backend/src/apiserver/server/pipeline_server.go @@ -25,6 +25,7 @@ import ( apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -359,7 +360,7 @@ func (s *PipelineServer) GetPipelineByName(ctx context.Context, request *apiv2be // Fetches an array of pipelines and an array of pipeline versions for given search query parameters. // Applies common logic on v1beta1 and v2beta1 API. -func (s *PipelineServer) listPipelines(ctx context.Context, namespace string, pageToken string, pageSize int32, sortBy string, filter string, apiRequestVersion string) ([]*model.Pipeline, []*model.PipelineVersion, int, string, error) { +func (s *PipelineServer) listPipelines(ctx context.Context, namespace string, pageToken string, pageSize int32, sortBy string, opts *list.Options, apiRequestVersion string) ([]*model.Pipeline, []*model.PipelineVersion, int, string, error) { // Fill in the default namespace namespace = s.resourceManager.ReplaceNamespace(namespace) if common.IsMultiUserMode() { @@ -375,12 +376,6 @@ func (s *PipelineServer) listPipelines(ctx context.Context, namespace string, pa ReferenceKey: &model.ReferenceKey{Type: model.NamespaceResourceType, ID: namespace}, } - // Validate list options - opts, err := validatedListOptions(&model.Pipeline{}, pageToken, int(pageSize), sortBy, filter) - if err != nil { - return nil, nil, 0, "", util.Wrapf(err, "Failed to list pipelines due invalid list options: pageToken: %v, pageSize: %v, sortBy: %v, filter: %v", pageToken, int(pageSize), sortBy, filter) - } - // List pipelines switch apiRequestVersion { case "v1beta1": @@ -427,7 +422,13 @@ func (s *PipelineServer) ListPipelinesV1(ctx context.Context, request *apiv1beta sortBy := request.GetSortBy() filter := request.GetFilter() - pipelines, pipelineVersions, totalSize, nextPageToken, err := s.listPipelines(ctx, namespace, pageToken, pageSize, sortBy, filter, "v1beta1") + // Validate list options + opts, err := validatedListOptions(&model.Pipeline{}, pageToken, int(pageSize), sortBy, filter, "v1beta1") + if err != nil { + return nil, util.Wrapf(err, "Failed to list pipelines due invalid list options: pageToken: %v, pageSize: %v, sortBy: %v, filter: %v", pageToken, int(pageSize), sortBy, filter) + } + + pipelines, pipelineVersions, totalSize, nextPageToken, err := s.listPipelines(ctx, namespace, pageToken, pageSize, sortBy, opts, "v1beta1") if err != nil { return nil, util.Wrapf(err, "Failed to list pipelines (v1beta1) in namespace %s. Check error stack", namespace) } @@ -449,7 +450,13 @@ func (s *PipelineServer) ListPipelines(ctx context.Context, request *apiv2beta1. sortBy := request.GetSortBy() filter := request.GetFilter() - pipelines, _, totalSize, nextPageToken, err := s.listPipelines(ctx, namespace, pageToken, pageSize, sortBy, filter, "v2beta1") + // Validate list options + opts, err := validatedListOptions(&model.Pipeline{}, pageToken, int(pageSize), sortBy, filter, "v2beta1") + if err != nil { + return nil, util.Wrapf(err, "Failed to list pipelines due invalid list options: pageToken: %v, pageSize: %v, sortBy: %v, filter: %v", pageToken, int(pageSize), sortBy, filter) + } + + pipelines, _, totalSize, nextPageToken, err := s.listPipelines(ctx, namespace, pageToken, pageSize, sortBy, opts, "v2beta1") if err != nil { return nil, util.Wrapf(err, "Failed to list pipelines in namespace %s. Check error stack", namespace) } @@ -767,24 +774,12 @@ func (s *PipelineServer) GetPipelineVersion(ctx context.Context, request *apiv2b // Fetches an array of pipeline versions for given search query parameters. // Applies common logic on v1beta1 and v2beta1 API. -func (s *PipelineServer) listPipelineVersions(ctx context.Context, pipelineId string, pageToken string, pageSize int32, sortBy string, filter string) ([]*model.PipelineVersion, int, string, error) { +func (s *PipelineServer) listPipelineVersions(ctx context.Context, pipelineId string, pageToken string, pageSize int32, sortBy string, opts *list.Options) ([]*model.PipelineVersion, int, string, error) { // Fail fast of pipeline id or namespace are missing if pipelineId == "" { return nil, 0, "", util.NewInvalidInputError("Failed to list pipeline versions. Pipeline id cannot be empty") } - // Validate query parameters - opts, err := validatedListOptions( - &model.PipelineVersion{}, - pageToken, - int(pageSize), - sortBy, - filter, - ) - if err != nil { - return nil, 0, "", util.Wrapf(err, "Failed to list pipeline versions due invalid list options: pageToken: %v, pageSize: %v, sortBy: %v, filter: %v", pageToken, int(pageSize), sortBy, filter) - } - // Check authorization if common.IsMultiUserMode() { namespace, err := s.resourceManager.FetchNamespaceFromPipelineId(pipelineId) @@ -823,7 +818,13 @@ func (s *PipelineServer) ListPipelineVersionsV1(ctx context.Context, request *ap sortBy := request.GetSortBy() filter := request.GetFilter() - pipelineVersions, totalSize, nextPageToken, err := s.listPipelineVersions(ctx, pipelineId, pageToken, pageSize, sortBy, filter) + // Validate query parameters + opts, err := validatedListOptions(&model.PipelineVersion{}, pageToken, int(pageSize), sortBy, filter, "v1beta1") + if err != nil { + return nil, util.Wrapf(err, "Failed to list pipeline versions due invalid list options: pageToken: %v, pageSize: %v, sortBy: %v, filter: %v", pageToken, int(pageSize), sortBy, filter) + } + + pipelineVersions, totalSize, nextPageToken, err := s.listPipelineVersions(ctx, pipelineId, pageToken, pageSize, sortBy, opts) if err != nil { return nil, util.Wrapf(err, "Failed to list pipeline versions (v1beta1) with pipeline id %s. Check error stack", pipelineId) } @@ -852,7 +853,13 @@ func (s *PipelineServer) ListPipelineVersions(ctx context.Context, request *apiv sortBy := request.GetSortBy() filter := request.GetFilter() - pipelineVersions, totalSize, nextPageToken, err := s.listPipelineVersions(ctx, pipelineId, pageToken, pageSize, sortBy, filter) + // Validate query parameters + opts, err := validatedListOptions(&model.PipelineVersion{}, pageToken, int(pageSize), sortBy, filter, "v2beta1") + if err != nil { + return nil, util.Wrapf(err, "Failed to list pipeline versions due invalid list options: pageToken: %v, pageSize: %v, sortBy: %v, filter: %v", pageToken, int(pageSize), sortBy, filter) + } + + pipelineVersions, totalSize, nextPageToken, err := s.listPipelineVersions(ctx, pipelineId, pageToken, pageSize, sortBy, opts) if err != nil { return nil, util.Wrapf(err, "Failed to list pipeline versions for pipeline %s", pipelineId) } diff --git a/backend/src/apiserver/server/run_server.go b/backend/src/apiserver/server/run_server.go index 2328ddf5b6..86e48679f7 100644 --- a/backend/src/apiserver/server/run_server.go +++ b/backend/src/apiserver/server/run_server.go @@ -22,6 +22,7 @@ import ( apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" "github.com/kubeflow/pipelines/backend/src/apiserver/common" + "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/apiserver/resource" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -194,7 +195,7 @@ func (s *RunServer) GetRunV1(ctx context.Context, request *apiv1beta1.GetRunRequ // Fetches all runs that conform to the specified filter and listing options. // Applies common logic on v1beta1 and v2beta1 API. -func (s *RunServer) listRuns(ctx context.Context, pageToken string, pageSize int, sortBy string, filter string, namespace string, experimentId string) ([]*model.Run, int, string, error) { +func (s *RunServer) listRuns(ctx context.Context, pageToken string, pageSize int, sortBy string, opts *list.Options, namespace string, experimentId string) ([]*model.Run, int, string, error) { namespace = s.resourceManager.ReplaceNamespace(namespace) if experimentId != "" { ns, err := s.resourceManager.GetNamespaceFromExperimentId(experimentId) @@ -212,10 +213,6 @@ func (s *RunServer) listRuns(ctx context.Context, pageToken string, pageSize int return nil, 0, "", util.Wrapf(err, "Failed to list runs due to authorization error. Check if you have permission to access namespace %s", namespace) } - opts, err := validatedListOptions(&model.Run{}, pageToken, pageSize, sortBy, filter) - if err != nil { - return nil, 0, "", util.Wrap(err, "Failed to create list options") - } filterContext := &model.FilterContext{ ReferenceKey: &model.ReferenceKey{Type: model.NamespaceResourceType, ID: namespace}, } @@ -256,7 +253,13 @@ func (s *RunServer) ListRunsV1(ctx context.Context, r *apiv1beta1.ListRunsReques experimentId = filterContext.ReferenceKey.ID } } - runs, runsCount, nextPageToken, err := s.listRuns(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), namespace, experimentId) + + opts, err := validatedListOptions(&model.Run{}, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), "v1beta1") + if err != nil { + return nil, util.Wrap(err, "Failed to create list options") + } + + runs, runsCount, nextPageToken, err := s.listRuns(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), opts, namespace, experimentId) if err != nil { return nil, util.Wrap(err, "Failed to list v1beta1 runs") } @@ -588,26 +591,17 @@ func (s *RunServer) ListRuns(ctx context.Context, r *apiv2beta1.ListRunsRequest) if s.options.CollectMetrics { listRunRequests.Inc() } - runs, runsCount, nextPageToken, err := s.listRuns(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), r.GetNamespace(), r.GetExperimentId()) + opts, err := validatedListOptions(&model.Run{}, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), "v2beta1") + if err != nil { + return nil, util.Wrap(err, "Failed to create list options") + } + runs, runsCount, nextPageToken, err := s.listRuns(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), opts, r.GetNamespace(), r.GetExperimentId()) if err != nil { return nil, util.Wrap(err, "Failed to list runs") } return &apiv2beta1.ListRunsResponse{Runs: toApiRuns(runs), TotalSize: int32(runsCount), NextPageToken: nextPageToken}, nil } -// Fetches runs across all experiments given query parameters. -// Supports v2beta1 behavior. -func (s *RunServer) ListAllRuns(ctx context.Context, r *apiv2beta1.ListRunsRequest) (*apiv2beta1.ListRunsResponse, error) { - if s.options.CollectMetrics { - listRunRequests.Inc() - } - runs, runsCount, nextPageToken, err := s.listRuns(ctx, r.GetPageToken(), int(r.GetPageSize()), r.GetSortBy(), r.GetFilter(), r.GetNamespace(), "") - if err != nil { - return nil, util.Wrap(err, "Failed to list all runs") - } - return &apiv2beta1.ListRunsResponse{Runs: toApiRuns(runs), TotalSize: int32(runsCount), NextPageToken: nextPageToken}, nil -} - // Archives a run. // Supports v2beta1 behavior. func (s *RunServer) ArchiveRun(ctx context.Context, request *apiv2beta1.ArchiveRunRequest) (*empty.Empty, error) { diff --git a/backend/src/apiserver/server/task_server.go b/backend/src/apiserver/server/task_server.go index a2e811f591..8de983929b 100644 --- a/backend/src/apiserver/server/task_server.go +++ b/backend/src/apiserver/server/task_server.go @@ -93,7 +93,7 @@ func (s *TaskServer) validateCreateTaskRequest(request *api.CreateTaskRequest) e func (s *TaskServer) ListTasksV1(ctx context.Context, request *api.ListTasksRequest) ( *api.ListTasksResponse, error, ) { - opts, err := validatedListOptions(&model.Task{}, request.PageToken, int(request.PageSize), request.SortBy, request.Filter) + opts, err := validatedListOptions(&model.Task{}, request.PageToken, int(request.PageSize), request.SortBy, request.Filter, "v1beta1") if err != nil { return nil, util.Wrap(err, "Failed to create list options") } diff --git a/backend/src/apiserver/storage/experiment_store_test.go b/backend/src/apiserver/storage/experiment_store_test.go index 5b967a3950..291a1816f5 100644 --- a/backend/src/apiserver/storage/experiment_store_test.go +++ b/backend/src/apiserver/storage/experiment_store_test.go @@ -20,6 +20,7 @@ import ( apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client" + "github.com/kubeflow/pipelines/backend/src/apiserver/filter" "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -392,8 +393,9 @@ func TestListExperiments_Filtering(t *testing.T) { }, }, } + newFilter, _ := filter.New(filterProto) - opts, err := list.NewOptions(&model.Experiment{}, 2, "id", filterProto) + opts, err := list.NewOptions(&model.Experiment{}, 2, "id", newFilter) assert.Nil(t, err) experiments, total_size, nextPageToken, err := experimentStore.ListExperiments(&model.FilterContext{}, opts) diff --git a/backend/src/apiserver/storage/job_store_test.go b/backend/src/apiserver/storage/job_store_test.go index ab3609bfa8..7834f8c36d 100644 --- a/backend/src/apiserver/storage/job_store_test.go +++ b/backend/src/apiserver/storage/job_store_test.go @@ -19,6 +19,7 @@ import ( "time" api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + "github.com/kubeflow/pipelines/backend/src/apiserver/filter" "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -188,7 +189,7 @@ func TestListJobs_TotalSizeWithFilter(t *testing.T) { defer db.Close() // Add a filter - opts, _ := list.NewOptions(&model.Job{}, 4, "name", &api.Filter{ + protoFilter := &api.Filter{ Predicates: []*api.Predicate{ { Key: "name", @@ -200,7 +201,9 @@ func TestListJobs_TotalSizeWithFilter(t *testing.T) { }, }, }, - }) + } + newFilter, _ := filter.New(protoFilter) + opts, _ := list.NewOptions(&model.Job{}, 4, "name", newFilter) jobs, total_size, _, err := jobStore.ListJobs(&model.FilterContext{}, opts) assert.Nil(t, err) assert.Equal(t, 1, len(jobs)) diff --git a/backend/src/apiserver/storage/pipeline_store_test.go b/backend/src/apiserver/storage/pipeline_store_test.go index 10070b138d..31a8da9693 100644 --- a/backend/src/apiserver/storage/pipeline_store_test.go +++ b/backend/src/apiserver/storage/pipeline_store_test.go @@ -18,6 +18,7 @@ import ( "testing" api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + "github.com/kubeflow/pipelines/backend/src/apiserver/filter" "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -224,7 +225,8 @@ func TestListPipelines_WithFilter(t *testing.T) { }, }, } - opts, err := list.NewOptions(&model.Pipeline{}, 10, "id", filterProto) + newFilter, _ := filter.New(filterProto) + opts, err := list.NewOptions(&model.Pipeline{}, 10, "id", newFilter) assert.Nil(t, err) pipelines, _, totalSize, nextPageToken, err := pipelineStore.ListPipelinesV1(&model.FilterContext{}, opts) @@ -1524,6 +1526,7 @@ func TestListPipelineVersions_WithFilter(t *testing.T) { }, }, } + equalFilter, _ := filter.New(equalFilterProto) // Filter for name prefix being pipeline_version prefixFilterProto := &api.Filter{ @@ -1535,9 +1538,10 @@ func TestListPipelineVersions_WithFilter(t *testing.T) { }, }, } + prefixFilter, _ := filter.New(prefixFilterProto) // Only return 1 pipeline version with equal filter. - opts, err := list.NewOptions(&model.PipelineVersion{}, 10, "id", equalFilterProto) + opts, err := list.NewOptions(&model.PipelineVersion{}, 10, "id", equalFilter) assert.Nil(t, err) _, totalSize, nextPageToken, err := pipelineStore.ListPipelineVersions(DefaultFakePipelineId, opts) assert.Nil(t, err) @@ -1553,7 +1557,7 @@ func TestListPipelineVersions_WithFilter(t *testing.T) { assert.Equal(t, 2, totalSize) // Return 2 pipeline versions with prefix filter. - opts, err = list.NewOptions(&model.PipelineVersion{}, 10, "id", prefixFilterProto) + opts, err = list.NewOptions(&model.PipelineVersion{}, 10, "id", prefixFilter) assert.Nil(t, err) _, totalSize, nextPageToken, err = pipelineStore.ListPipelineVersions(DefaultFakePipelineId, opts) assert.Nil(t, err) diff --git a/backend/src/apiserver/storage/run_store_test.go b/backend/src/apiserver/storage/run_store_test.go index a4de27910c..debeb69dd8 100644 --- a/backend/src/apiserver/storage/run_store_test.go +++ b/backend/src/apiserver/storage/run_store_test.go @@ -22,6 +22,7 @@ import ( sq "github.com/Masterminds/squirrel" api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client" + "github.com/kubeflow/pipelines/backend/src/apiserver/filter" "github.com/kubeflow/pipelines/backend/src/apiserver/list" "github.com/kubeflow/pipelines/backend/src/apiserver/model" "github.com/kubeflow/pipelines/backend/src/common/util" @@ -390,7 +391,7 @@ func TestListRuns_TotalSizeWithFilter(t *testing.T) { defer db.Close() // Add a filter - opts, _ := list.NewOptions(&model.Run{}, 4, "", &api.Filter{ + filterProto := &api.Filter{ Predicates: []*api.Predicate{ { Key: "name", @@ -402,7 +403,9 @@ func TestListRuns_TotalSizeWithFilter(t *testing.T) { }, }, }, - }) + } + newFilter, _ := filter.New(filterProto) + opts, _ := list.NewOptions(&model.Run{}, 4, "", newFilter) runs, total_size, _, err := runStore.ListRuns(&model.FilterContext{}, opts) assert.Nil(t, err) assert.Equal(t, 2, len(runs))