From 870da5a58ee9749ff7e923bafe2217da191bfd56 Mon Sep 17 00:00:00 2001 From: Jiahui Feng Date: Fri, 7 Oct 2022 15:02:47 -0700 Subject: [PATCH 1/3] move CEL package to apiserver package. only anything that does not require Structural Kubernetes-commit: 0dd316a5c11261c0e5fc7928d8697754b16ad461 --- go.mod | 24 +- go.sum | 13 +- pkg/cel/errors.go | 47 ++ pkg/cel/escaping.go | 170 ++++ pkg/cel/escaping_test.go | 206 +++++ pkg/cel/library/cost.go | 268 ++++++ pkg/cel/library/cost_test.go | 363 +++++++++ pkg/cel/library/libraries.go | 34 + pkg/cel/library/library_compatibility_test.go | 58 ++ pkg/cel/library/lists.go | 312 +++++++ pkg/cel/library/regex.go | 187 +++++ pkg/cel/library/urls.go | 236 ++++++ pkg/cel/limits.go | 49 ++ pkg/cel/metrics/metrics.go | 72 ++ pkg/cel/metrics/metrics_test.go | 68 ++ pkg/cel/registry.go | 79 ++ pkg/cel/types.go | 552 +++++++++++++ pkg/cel/types_test.go | 79 ++ pkg/cel/url.go | 80 ++ pkg/cel/value.go | 769 ++++++++++++++++++ pkg/cel/value_test.go | 362 +++++++++ 21 files changed, 4005 insertions(+), 23 deletions(-) create mode 100644 pkg/cel/errors.go create mode 100644 pkg/cel/escaping.go create mode 100644 pkg/cel/escaping_test.go create mode 100644 pkg/cel/library/cost.go create mode 100644 pkg/cel/library/cost_test.go create mode 100644 pkg/cel/library/libraries.go create mode 100644 pkg/cel/library/library_compatibility_test.go create mode 100644 pkg/cel/library/lists.go create mode 100644 pkg/cel/library/regex.go create mode 100644 pkg/cel/library/urls.go create mode 100644 pkg/cel/limits.go create mode 100644 pkg/cel/metrics/metrics.go create mode 100644 pkg/cel/metrics/metrics_test.go create mode 100644 pkg/cel/registry.go create mode 100644 pkg/cel/types.go create mode 100644 pkg/cel/types_test.go create mode 100644 pkg/cel/url.go create mode 100644 pkg/cel/value.go create mode 100644 pkg/cel/value_test.go diff --git a/go.mod b/go.mod index 90c1dfb88..0fea974b1 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/google/gnostic v0.5.7-v3refs github.com/google/go-cmp v0.5.9 github.com/google/gofuzz v1.1.0 + github.com/google/cel-go v0.12.5 github.com/google/uuid v1.1.2 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 @@ -39,12 +40,12 @@ require ( google.golang.org/grpc v1.49.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/square/go-jose.v2 v2.2.2 - k8s.io/api v0.0.0-20221012035047-0f8110492ea0 - k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419 - k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853 - k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282 + k8s.io/api v0.0.0 + k8s.io/apimachinery v0.0.0 + k8s.io/client-go v0.0.0 + k8s.io/component-base v0.0.0 k8s.io/klog/v2 v2.80.1 - k8s.io/kms v0.0.0-20221012040222-bf322548c086 + k8s.io/kms v0.0.0 k8s.io/kube-openapi v0.0.0-20220803162953-67bda5d908f1 k8s.io/utils v0.0.0-20220922133306-665eaaec4324 sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.0.33 @@ -109,7 +110,7 @@ require ( go.uber.org/multierr v1.6.0 // indirect golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect - golang.org/x/text v0.3.8 // indirect + golang.org/x/text v0.3.7 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect google.golang.org/protobuf v1.28.1 // indirect @@ -119,9 +120,10 @@ require ( ) replace ( - k8s.io/api => k8s.io/api v0.0.0-20221012035047-0f8110492ea0 - k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419 - k8s.io/client-go => k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853 - k8s.io/component-base => k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282 - k8s.io/kms => k8s.io/kms v0.0.0-20221012040222-bf322548c086 + k8s.io/api => ../api + k8s.io/apimachinery => ../apimachinery + k8s.io/apiserver => ../apiserver + k8s.io/client-go => ../client-go + k8s.io/component-base => ../component-base + k8s.io/kms => ../kms ) diff --git a/go.sum b/go.sum index 6b2a8dc98..0b87ddd36 100644 --- a/go.sum +++ b/go.sum @@ -726,9 +726,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -977,18 +976,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -k8s.io/api v0.0.0-20221012035047-0f8110492ea0 h1:zsKtQXVrOo62IwClRKcF/RbgeFI2FQFZ99YhdfypWn8= -k8s.io/api v0.0.0-20221012035047-0f8110492ea0/go.mod h1:Q2jmki3lhHeeAyYtbIiEe/2JoGg2Ge1cJoMgpFkjtzg= -k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419 h1:6islJrEgy0CM8YsBcMFdlxm1lXtVC9X5A2AVRK7JEpc= -k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419/go.mod h1:1b4APDhID8eky0PLpgeoWtdLUptwwJD8Jk5uFOd22CE= -k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853 h1:MA0fOvlQt7fzsEIujzUH4ssakSCKB6+iWRclNvRYAnk= -k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853/go.mod h1:UVW05h29JMuginXCTvlNPTFVS/AyKlU3Q4XoA6UidtU= -k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282 h1:JqSl7JGZ8FwUFz9Isq0YIJNLx5oK5x/g0Vl2VtVntI4= -k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282/go.mod h1:t3mMmVABnDZIpNgn7xDM2AdNUr2jvWA8pleZlSDjAug= k8s.io/klog/v2 v2.80.1 h1:atnLQ121W371wYYFawwYx1aEY2eUfs4l3J72wtgAwV4= k8s.io/klog/v2 v2.80.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0= -k8s.io/kms v0.0.0-20221012040222-bf322548c086 h1:BCrccAO6WPPwlzfgds9ERms+pPGvNJ4QpakEuqEQ5rk= -k8s.io/kms v0.0.0-20221012040222-bf322548c086/go.mod h1:Ef2bN4e3pWtPy0jFds29OObACUDyWrPpK4wuLxW/pQM= k8s.io/kube-openapi v0.0.0-20220803162953-67bda5d908f1 h1:MQ8BAZPZlWk3S9K4a9NCkIFQtZShWqoha7snGixVgEA= k8s.io/kube-openapi v0.0.0-20220803162953-67bda5d908f1/go.mod h1:C/N6wCaBHeBHkHUesQOQy2/MZqGgMAFPqGsGQLdbZBU= k8s.io/utils v0.0.0-20220922133306-665eaaec4324 h1:i+xdFemcSNuJvIfBlaYuXgRondKxK4z4prVPKzEaelI= diff --git a/pkg/cel/errors.go b/pkg/cel/errors.go new file mode 100644 index 000000000..907ca6ec8 --- /dev/null +++ b/pkg/cel/errors.go @@ -0,0 +1,47 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +// Error is an implementation of the 'error' interface, which represents a +// XValidation error. +type Error struct { + Type ErrorType + Detail string +} + +var _ error = &Error{} + +// Error implements the error interface. +func (v *Error) Error() string { + return v.Detail +} + +// ErrorType is a machine readable value providing more detail about why +// a XValidation is invalid. +type ErrorType string + +const ( + // ErrorTypeRequired is used to report withNullable values that are not + // provided (e.g. empty strings, null values, or empty arrays). See + // Required(). + ErrorTypeRequired ErrorType = "RuleRequired" + // ErrorTypeInvalid is used to report malformed values + ErrorTypeInvalid ErrorType = "RuleInvalid" + // ErrorTypeInternal is used to report other errors that are not related + // to user input. See InternalError(). + ErrorTypeInternal ErrorType = "InternalError" +) diff --git a/pkg/cel/escaping.go b/pkg/cel/escaping.go new file mode 100644 index 000000000..705c353a2 --- /dev/null +++ b/pkg/cel/escaping.go @@ -0,0 +1,170 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "regexp" + + "k8s.io/apimachinery/pkg/util/sets" +) + +// celReservedSymbols is a list of RESERVED symbols defined in the CEL lexer. +// No identifiers are allowed to collide with these symbols. +// https://github.com/google/cel-spec/blob/master/doc/langdef.md#syntax +var celReservedSymbols = sets.NewString( + "true", "false", "null", "in", + "as", "break", "const", "continue", "else", + "for", "function", "if", "import", "let", + "loop", "package", "namespace", "return", // !! 'namespace' is used heavily in Kubernetes + "var", "void", "while", +) + +// expandMatcher matches the escape sequence, characters that are escaped, and characters that are unsupported +var expandMatcher = regexp.MustCompile(`(__|[-./]|[^a-zA-Z0-9-./_])`) + +// newCharacterFilter returns a boolean array to indicate the allowed characters +func newCharacterFilter(characters string) []bool { + maxChar := 0 + for _, c := range characters { + if maxChar < int(c) { + maxChar = int(c) + } + } + filter := make([]bool, maxChar+1) + + for _, c := range characters { + filter[int(c)] = true + } + + return filter +} + +type escapeCheck struct { + canSkipRegex bool + invalidCharFound bool +} + +// skipRegexCheck checks if escape would be skipped. +// if invalidCharFound is true, it must have invalid character; if invalidCharFound is false, not sure if it has invalid character or not +func skipRegexCheck(ident string) escapeCheck { + escapeCheck := escapeCheck{canSkipRegex: true, invalidCharFound: false} + // skip escape if possible + previous_underscore := false + for _, c := range ident { + if c == '/' || c == '-' || c == '.' { + escapeCheck.canSkipRegex = false + return escapeCheck + } + intc := int(c) + if intc < 0 || intc >= len(validCharacterFilter) || !validCharacterFilter[intc] { + escapeCheck.invalidCharFound = true + return escapeCheck + } + if c == '_' && previous_underscore { + escapeCheck.canSkipRegex = false + return escapeCheck + } + + previous_underscore = c == '_' + } + return escapeCheck +} + +// validCharacterFilter indicates the allowed characters. +var validCharacterFilter = newCharacterFilter("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_") + +// Escape escapes ident and returns a CEL identifier (of the form '[a-zA-Z_][a-zA-Z0-9_]*'), or returns +// false if the ident does not match the supported input format of `[a-zA-Z_.-/][a-zA-Z0-9_.-/]*`. +// Escaping Rules: +// - '__' escapes to '__underscores__' +// - '.' escapes to '__dot__' +// - '-' escapes to '__dash__' +// - '/' escapes to '__slash__' +// - Identifiers that exactly match a CEL RESERVED keyword escape to '__{keyword}__'. The keywords are: "true", "false", +// "null", "in", "as", "break", "const", "continue", "else", "for", "function", "if", "import", "let", loop", "package", +// "namespace", "return". +func Escape(ident string) (string, bool) { + if len(ident) == 0 || ('0' <= ident[0] && ident[0] <= '9') { + return "", false + } + if celReservedSymbols.Has(ident) { + return "__" + ident + "__", true + } + + escapeCheck := skipRegexCheck(ident) + if escapeCheck.invalidCharFound { + return "", false + } + if escapeCheck.canSkipRegex { + return ident, true + } + + ok := true + ident = expandMatcher.ReplaceAllStringFunc(ident, func(s string) string { + switch s { + case "__": + return "__underscores__" + case ".": + return "__dot__" + case "-": + return "__dash__" + case "/": + return "__slash__" + default: // matched a unsupported supported + ok = false + return "" + } + }) + if !ok { + return "", false + } + return ident, true +} + +var unexpandMatcher = regexp.MustCompile(`(_{2}[^_]+_{2})`) + +// Unescape unescapes an CEL identifier containing the escape sequences described in Escape, or return false if the +// string contains invalid escape sequences. The escaped input is expected to be a valid CEL identifier, but is +// not checked. +func Unescape(escaped string) (string, bool) { + ok := true + escaped = unexpandMatcher.ReplaceAllStringFunc(escaped, func(s string) string { + contents := s[2 : len(s)-2] + switch contents { + case "underscores": + return "__" + case "dot": + return "." + case "dash": + return "-" + case "slash": + return "/" + } + if celReservedSymbols.Has(contents) { + if len(s) != len(escaped) { + ok = false + } + return contents + } + ok = false + return "" + }) + if !ok { + return "", false + } + return escaped, true +} diff --git a/pkg/cel/escaping_test.go b/pkg/cel/escaping_test.go new file mode 100644 index 000000000..e4b2aa906 --- /dev/null +++ b/pkg/cel/escaping_test.go @@ -0,0 +1,206 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "fmt" + "regexp" + "testing" + + fuzz "github.com/google/gofuzz" +) + +// TestEscaping tests that property names are escaped as expected. +func TestEscaping(t *testing.T) { + cases := []struct { + unescaped string + escaped string + unescapable bool + }{ + // '.', '-', '/' and '__' are escaped since + // CEL only allows identifiers of the form: [a-zA-Z_][a-zA-Z0-9_]* + {unescaped: "a.a", escaped: "a__dot__a"}, + {unescaped: "a-a", escaped: "a__dash__a"}, + {unescaped: "a__a", escaped: "a__underscores__a"}, + {unescaped: "a.-/__a", escaped: "a__dot____dash____slash____underscores__a"}, + {unescaped: "a._a", escaped: "a__dot___a"}, + {unescaped: "a__.__a", escaped: "a__underscores____dot____underscores__a"}, + {unescaped: "a___a", escaped: "a__underscores___a"}, + {unescaped: "a____a", escaped: "a__underscores____underscores__a"}, + {unescaped: "a__dot__a", escaped: "a__underscores__dot__underscores__a"}, + {unescaped: "a__underscores__a", escaped: "a__underscores__underscores__underscores__a"}, + // CEL lexer RESERVED keywords must be escaped + {unescaped: "true", escaped: "__true__"}, + {unescaped: "false", escaped: "__false__"}, + {unescaped: "null", escaped: "__null__"}, + {unescaped: "in", escaped: "__in__"}, + {unescaped: "as", escaped: "__as__"}, + {unescaped: "break", escaped: "__break__"}, + {unescaped: "const", escaped: "__const__"}, + {unescaped: "continue", escaped: "__continue__"}, + {unescaped: "else", escaped: "__else__"}, + {unescaped: "for", escaped: "__for__"}, + {unescaped: "function", escaped: "__function__"}, + {unescaped: "if", escaped: "__if__"}, + {unescaped: "import", escaped: "__import__"}, + {unescaped: "let", escaped: "__let__"}, + {unescaped: "loop", escaped: "__loop__"}, + {unescaped: "package", escaped: "__package__"}, + {unescaped: "namespace", escaped: "__namespace__"}, + {unescaped: "return", escaped: "__return__"}, + {unescaped: "var", escaped: "__var__"}, + {unescaped: "void", escaped: "__void__"}, + {unescaped: "while", escaped: "__while__"}, + // Not all property names are escapable + {unescaped: "@", unescapable: true}, + {unescaped: "1up", unescapable: true}, + {unescaped: "👑", unescapable: true}, + // CEL macro and function names do not need to be escaped because the parser keeps identifiers in a + // different namespace than function and macro names. + {unescaped: "has", escaped: "has"}, + {unescaped: "all", escaped: "all"}, + {unescaped: "exists", escaped: "exists"}, + {unescaped: "exists_one", escaped: "exists_one"}, + {unescaped: "filter", escaped: "filter"}, + {unescaped: "size", escaped: "size"}, + {unescaped: "contains", escaped: "contains"}, + {unescaped: "startsWith", escaped: "startsWith"}, + {unescaped: "endsWith", escaped: "endsWith"}, + {unescaped: "matches", escaped: "matches"}, + {unescaped: "duration", escaped: "duration"}, + {unescaped: "timestamp", escaped: "timestamp"}, + {unescaped: "getDate", escaped: "getDate"}, + {unescaped: "getDayOfMonth", escaped: "getDayOfMonth"}, + {unescaped: "getDayOfWeek", escaped: "getDayOfWeek"}, + {unescaped: "getFullYear", escaped: "getFullYear"}, + {unescaped: "getHours", escaped: "getHours"}, + {unescaped: "getMilliseconds", escaped: "getMilliseconds"}, + {unescaped: "getMinutes", escaped: "getMinutes"}, + {unescaped: "getMonth", escaped: "getMonth"}, + {unescaped: "getSeconds", escaped: "getSeconds"}, + // we don't escape a single _ + {unescaped: "_if", escaped: "_if"}, + {unescaped: "_has", escaped: "_has"}, + {unescaped: "_int", escaped: "_int"}, + {unescaped: "_anything", escaped: "_anything"}, + } + + for _, tc := range cases { + t.Run(tc.unescaped, func(t *testing.T) { + e, escapable := Escape(tc.unescaped) + if tc.unescapable { + if escapable { + t.Errorf("Expected escapable=false, but got %t", escapable) + } + return + } + if !escapable { + t.Fatalf("Expected escapable=true, but got %t", escapable) + } + if tc.escaped != e { + t.Errorf("Expected %s to escape to %s, but got %s", tc.unescaped, tc.escaped, e) + } + + if !validCelIdent.MatchString(e) { + t.Errorf("Expected %s to escape to a valid CEL identifier, but got %s", tc.unescaped, e) + } + + u, ok := Unescape(tc.escaped) + if !ok { + t.Fatalf("Expected %s to be escapable, but it was not", tc.escaped) + } + if tc.unescaped != u { + t.Errorf("Expected %s to unescape to %s, but got %s", tc.escaped, tc.unescaped, u) + } + }) + } +} + +func TestUnescapeMalformed(t *testing.T) { + for _, s := range []string{"__int__extra", "__illegal__"} { + t.Run(s, func(t *testing.T) { + e, ok := Unescape(s) + if ok { + t.Fatalf("Expected %s to be unescapable, but it escaped to: %s", s, e) + } + }) + } +} + +func TestEscapingFuzz(t *testing.T) { + fuzzer := fuzz.New() + for i := 0; i < 1000; i++ { + var unescaped string + fuzzer.Fuzz(&unescaped) + t.Run(fmt.Sprintf("%d - '%s'", i, unescaped), func(t *testing.T) { + if len(unescaped) == 0 { + return + } + escaped, ok := Escape(unescaped) + if !ok { + return + } + + if !validCelIdent.MatchString(escaped) { + t.Errorf("Expected %s to escape to a valid CEL identifier, but got %s", unescaped, escaped) + } + u, ok := Unescape(escaped) + if !ok { + t.Fatalf("Expected %s to be unescapable, but it was not", escaped) + } + if unescaped != u { + t.Errorf("Expected %s to unescape to %s, but got %s", escaped, unescaped, u) + } + }) + } +} + +var validCelIdent = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +func TestCanSkipRegex(t *testing.T) { + cases := []struct { + unescaped string + canSkip bool + invalidCharFound bool + }{ + {unescaped: "a.a", canSkip: false}, + {unescaped: "a-a", canSkip: false}, + {unescaped: "a__a", canSkip: false}, + {unescaped: "a.-/__a", canSkip: false}, + {unescaped: "a_a", canSkip: true}, + {unescaped: "a_a_a", canSkip: true}, + {unescaped: "@", invalidCharFound: true}, + {unescaped: "👑", invalidCharFound: true}, + // if escape keyword is detected before invalid character, invalidCharFound + {unescaped: "/👑", canSkip: false}, + } + + for _, tc := range cases { + t.Run(tc.unescaped, func(t *testing.T) { + escapeCheck := skipRegexCheck(tc.unescaped) + if escapeCheck.invalidCharFound { + if escapeCheck.invalidCharFound != tc.invalidCharFound { + t.Errorf("Expected input validation to be %v, but got %t", tc.invalidCharFound, escapeCheck.invalidCharFound) + } + } else { + if escapeCheck.canSkipRegex != tc.canSkip { + t.Errorf("Expected %v, but got %t", tc.canSkip, escapeCheck.canSkipRegex) + } + } + }) + } +} diff --git a/pkg/cel/library/cost.go b/pkg/cel/library/cost.go new file mode 100644 index 000000000..39098e3f6 --- /dev/null +++ b/pkg/cel/library/cost.go @@ -0,0 +1,268 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "math" + + "github.com/google/cel-go/checker" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. +type CostEstimator struct { + // SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation + // calculations to if the size is not well known (i.e. a constant). + SizeEstimator checker.CostEstimator +} + +func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 { + switch function { + case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf": + var cost uint64 + if len(args) > 0 { + cost += traversalCost(args[0]) // these O(n) operations all cost roughly the cost of a single traversal + } + return &cost + case "url", "lowerAscii", "upperAscii", "substring", "trim": + if len(args) >= 1 { + cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) + return &cost + } + case "replace", "split": + if len(args) >= 1 { + // cost is the traversal plus the construction of the result + cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor)) + return &cost + } + case "join": + if len(args) >= 1 { + cost := uint64(math.Ceil(float64(actualSize(result)) * 2 * common.StringTraversalCostFactor)) + return &cost + } + case "find", "findAll": + if len(args) >= 2 { + strCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor)) + // We don't know how many expressions are in the regex, just the string length (a huge + // improvement here would be to somehow get a count the number of expressions in the regex or + // how many states are in the regex state machine and use that to measure regex cost). + // For now, we're making a guess that each expression in a regex is typically at least 4 chars + // in length. + regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor)) + cost := strCost * regexCost + return &cost + } + } + return nil +} + +func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + // WARNING: Any changes to this code impact API compatibility! The estimated cost is used to determine which CEL rules may be written to a + // CRD and any change (cost increases and cost decreases) are breaking. + switch function { + case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf": + if target != nil { + // Charge 1 cost for comparing each element in the list + elCost := checker.CostEstimate{Min: 1, Max: 1} + // If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way + // of estimating the additional comparison cost. + if elNode := l.listElementNode(*target); elNode != nil { + t := elNode.Type().GetPrimitive() + if t == exprpb.Type_STRING || t == exprpb.Type_BYTES { + sz := l.sizeEstimate(elNode) + elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor)) + } + return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCost(elCost)} + } else { // the target is a string, which is supported by indexOf and lastIndexOf + return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)} + } + + } + case "url": + if len(args) == 1 { + sz := l.sizeEstimate(args[0]) + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} + } + case "lowerAscii", "upperAscii", "substring", "trim": + if target != nil { + sz := l.sizeEstimate(*target) + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} + } + case "replace": + if target != nil && len(args) >= 2 { + sz := l.sizeEstimate(*target) + toReplaceSz := l.sizeEstimate(args[0]) + replaceWithSz := l.sizeEstimate(args[1]) + // smallest possible result: smallest input size composed of the largest possible substrings being replaced by smallest possible replacement + minSz := uint64(math.Ceil(float64(sz.Min)/float64(toReplaceSz.Max))) * replaceWithSz.Min + // largest possible result: largest input size composed of the smallest possible substrings being replaced by largest possible replacement + maxSz := uint64(math.Ceil(float64(sz.Max)/float64(toReplaceSz.Min))) * replaceWithSz.Max + + // cost is the traversal plus the construction of the result + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: minSz, Max: maxSz}} + } + case "split": + if target != nil { + sz := l.sizeEstimate(*target) + + // Worst case size is where is that a separator of "" is used, and each char is returned as a list element. + max := sz.Max + if len(args) > 1 { + if c := args[1].Expr().GetConstExpr(); c != nil { + max = uint64(c.GetInt64Value()) + } + } + // Cost is the traversal plus the construction of the result. + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: 0, Max: max}} + } + case "join": + if target != nil { + var sz checker.SizeEstimate + listSize := l.sizeEstimate(*target) + if elNode := l.listElementNode(*target); elNode != nil { + elemSize := l.sizeEstimate(elNode) + sz = listSize.Multiply(elemSize) + } + + if len(args) > 0 { + sepSize := l.sizeEstimate(args[0]) + minSeparators := uint64(0) + maxSeparators := uint64(0) + if listSize.Min > 0 { + minSeparators = listSize.Min - 1 + } + if listSize.Max > 0 { + maxSeparators = listSize.Max - 1 + } + sz = sz.Add(sepSize.Multiply(checker.SizeEstimate{Min: minSeparators, Max: maxSeparators})) + } + + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} + } + case "find", "findAll": + if target != nil && len(args) >= 1 { + sz := l.sizeEstimate(*target) + // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 + // in case where string is empty but regex is still expensive. + strCost := sz.Add(checker.SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor) + // We don't know how many expressions are in the regex, just the string length (a huge + // improvement here would be to somehow get a count the number of expressions in the regex or + // how many states are in the regex state machine and use that to measure regex cost). + // For now, we're making a guess that each expression in a regex is typically at least 4 chars + // in length. + regexCost := l.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor) + // worst case size of result is that every char is returned as separate find result. + return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}} + } + } + return nil +} + +func actualSize(value ref.Val) uint64 { + if sz, ok := value.(traits.Sizer); ok { + return uint64(sz.Size().(types.Int)) + } + return 1 +} + +func (l *CostEstimator) sizeEstimate(t checker.AstNode) checker.SizeEstimate { + if sz := t.ComputedSize(); sz != nil { + return *sz + } + if sz := l.EstimateSize(t); sz != nil { + return *sz + } + return checker.SizeEstimate{Min: 0, Max: math.MaxUint64} +} + +func (l *CostEstimator) listElementNode(list checker.AstNode) checker.AstNode { + if lt := list.Type().GetListType(); lt != nil { + nodePath := list.Path() + if nodePath != nil { + // Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists + // for this node. + path := make([]string, len(nodePath)+1) + copy(path, nodePath) + path[len(nodePath)] = "@items" + return &itemsNode{path: path, t: lt.GetElemType(), expr: nil} + } else { + // Provide just the type if no path is available so that worst case size can be looked up based on type. + return &itemsNode{t: lt.GetElemType(), expr: nil} + } + } + return nil +} + +func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + if l.SizeEstimator != nil { + return l.SizeEstimator.EstimateSize(element) + } + return nil +} + +type itemsNode struct { + path []string + t *exprpb.Type + expr *exprpb.Expr +} + +func (i *itemsNode) Path() []string { + return i.path +} + +func (i *itemsNode) Type() *exprpb.Type { + return i.t +} + +func (i *itemsNode) Expr() *exprpb.Expr { + return i.expr +} + +func (i *itemsNode) ComputedSize() *checker.SizeEstimate { + return nil +} + +// traversalCost computes the cost of traversing a ref.Val as a data tree. +func traversalCost(v ref.Val) uint64 { + // TODO: This could potentially be optimized by sampling maps and lists instead of traversing. + switch vt := v.(type) { + case types.String: + return uint64(float64(len(string(vt))) * common.StringTraversalCostFactor) + case types.Bytes: + return uint64(float64(len([]byte(vt))) * common.StringTraversalCostFactor) + case traits.Lister: + cost := uint64(0) + for it := vt.Iterator(); it.HasNext() == types.True; { + i := it.Next() + cost += traversalCost(i) + } + return cost + case traits.Mapper: // maps and objects + cost := uint64(0) + for it := vt.Iterator(); it.HasNext() == types.True; { + k := it.Next() + cost += traversalCost(k) + traversalCost(vt.Get(k)) + } + return cost + default: + return 1 + } +} diff --git a/pkg/cel/library/cost_test.go b/pkg/cel/library/cost_test.go new file mode 100644 index 000000000..0b1e0020c --- /dev/null +++ b/pkg/cel/library/cost_test.go @@ -0,0 +1,363 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "fmt" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" + "github.com/google/cel-go/ext" + expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +const ( + intListLiteral = "[1, 2, 3, 4, 5]" + uintListLiteral = "[uint(1), uint(2), uint(3), uint(4), uint(5)]" + doubleListLiteral = "[1.0, 2.0, 3.0, 4.0, 5.0]" + boolListLiteral = "[false, true, false, true, false]" + stringListLiteral = "['012345678901', '012345678901', '012345678901', '012345678901', '012345678901']" + bytesListLiteral = "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')]" + durationListLiteral = "[duration('1s'), duration('2s'), duration('3s'), duration('4s'), duration('5s')]" + timestampListLiteral = "[timestamp('2011-01-01T00:00:00.000+01:00'), timestamp('2011-01-02T00:00:00.000+01:00'), " + + "timestamp('2011-01-03T00:00:00.000+01:00'), timestamp('2011-01-04T00:00:00.000+01:00'), " + + "timestamp('2011-01-05T00:00:00.000+01:00')]" + stringLiteral = "'01234567890123456789012345678901234567890123456789'" +) + +type comparableCost struct { + comparableLiteral string + expectedEstimatedCost checker.CostEstimate + expectedRuntimeCost uint64 + + param string +} + +func TestListsCost(t *testing.T) { + cases := []struct { + opts []string + costs []comparableCost + }{ + { + opts: []string{".sum()"}, + // 10 cost for the list declaration, the rest is the due to the function call + costs: []comparableCost{ + { + comparableLiteral: intListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: uintListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts + }, + { + comparableLiteral: doubleListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: durationListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts + }, + }, + }, + { + opts: []string{".isSorted()", ".max()", ".min()"}, + // 10 cost for the list declaration, the rest is the due to the function call + costs: []comparableCost{ + { + comparableLiteral: intListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: uintListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts + }, + { + comparableLiteral: doubleListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: boolListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: stringListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons + }, + { + comparableLiteral: bytesListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 25, Max: 35}, expectedRuntimeCost: 25, // +10 for casts from string to byte, +5 for byte comparisons + }, + { + comparableLiteral: durationListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts + }, + { + comparableLiteral: timestampListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts + }, + }, + }, + } + for _, tc := range cases { + for _, op := range tc.opts { + for _, typ := range tc.costs { + t.Run(typ.comparableLiteral+op, func(t *testing.T) { + e := typ.comparableLiteral + op + testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost) + }) + } + } + } +} + +func TestIndexOfCost(t *testing.T) { + cases := []struct { + opts []string + costs []comparableCost + }{ + { + opts: []string{".indexOf(%s)", ".lastIndexOf(%s)"}, + // 10 cost for the list declaration, the rest is the due to the function call + costs: []comparableCost{ + { + comparableLiteral: intListLiteral, param: "3", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: uintListLiteral, param: "uint(3)", + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +5 for numeric casts + }, + { + comparableLiteral: doubleListLiteral, param: "3.0", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: boolListLiteral, param: "true", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: stringListLiteral, param: "'x'", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons + }, + { + comparableLiteral: bytesListLiteral, param: "bytes('x')", + expectedEstimatedCost: checker.CostEstimate{Min: 26, Max: 36}, expectedRuntimeCost: 26, // +11 for casts from string to byte, +5 for byte comparisons + }, + { + comparableLiteral: durationListLiteral, param: "duration('3s')", + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from duration to byte + }, + { + comparableLiteral: timestampListLiteral, param: "timestamp('2011-01-03T00:00:00.000+01:00')", + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from timestamp to byte + }, + + // index of operations are also defined for strings + { + comparableLiteral: stringLiteral, param: "'123'", + expectedEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, expectedRuntimeCost: 5, + }, + }, + }, + } + for _, tc := range cases { + for _, op := range tc.opts { + for _, typ := range tc.costs { + opWithParam := fmt.Sprintf(op, typ.param) + t.Run(typ.comparableLiteral+opWithParam, func(t *testing.T) { + e := typ.comparableLiteral + opWithParam + testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost) + }) + } + } + } +} + +func TestURLsCost(t *testing.T) { + cases := []struct { + ops []string + expectEsimatedCost checker.CostEstimate + expectRuntimeCost uint64 + }{ + { + ops: []string{".getScheme()", ".getHostname()", ".getHost()", ".getPort()", ".getEscapedPath()", ".getQuery()"}, + expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4}, + expectRuntimeCost: 4, + }, + } + + for _, tc := range cases { + for _, op := range tc.ops { + t.Run("url."+op, func(t *testing.T) { + testCost(t, "url('https:://kubernetes.io/')"+op, tc.expectEsimatedCost, tc.expectRuntimeCost) + }) + } + } +} + +func TestStringLibrary(t *testing.T) { + cases := []struct { + name string + expr string + expectEsimatedCost checker.CostEstimate + expectRuntimeCost uint64 + }{ + { + name: "lowerAscii", + expr: "'ABCDEFGHIJ abcdefghij'.lowerAscii()", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "upperAscii", + expr: "'ABCDEFGHIJ abcdefghij'.upperAscii()", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "replace", + expr: "'abc 123 def 123'.replace('123', '456')", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "replace with limit", + expr: "'abc 123 def 123'.replace('123', '456', 1)", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "split", + expr: "'abc 123 def 123'.split(' ')", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "split with limit", + expr: "'abc 123 def 123'.split(' ', 1)", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "substring", + expr: "'abc 123 def 123'.substring(5)", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "substring with end", + expr: "'abc 123 def 123'.substring(5, 8)", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "trim", + expr: "' abc 123 def 123 '.trim()", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "join with separator", + expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join(' ')", + expectEsimatedCost: checker.CostEstimate{Min: 11, Max: 23}, + expectRuntimeCost: 15, + }, + { + name: "join", + expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join()", + expectEsimatedCost: checker.CostEstimate{Min: 10, Max: 22}, + expectRuntimeCost: 13, + }, + { + name: "find", + expr: "'abc 123 def 123'.find('123')", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "findAll", + expr: "'abc 123 def 123'.findAll('123')", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "findAll with limit", + expr: "'abc 123 def 123'.findAll('123', 1)", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + testCost(t, tc.expr, tc.expectEsimatedCost, tc.expectRuntimeCost) + }) + } +} + +func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) { + est := &CostEstimator{SizeEstimator: &testCostEstimator{}} + env, err := cel.NewEnv(append(k8sExtensionLibs, ext.Strings())...) + if err != nil { + t.Fatalf("%v", err) + } + compiled, issues := env.Compile(expr) + if len(issues.Errors()) > 0 { + t.Fatalf("%v", issues.Errors()) + } + estCost, err := env.EstimateCost(compiled, est) + if err != nil { + t.Fatalf("%v", err) + } + if estCost.Min != expectEsimatedCost.Min || estCost.Max != expectEsimatedCost.Max { + t.Errorf("Expected estimated cost of %d..%d but got %d..%d", expectEsimatedCost.Min, expectEsimatedCost.Max, estCost.Min, estCost.Max) + } + prog, err := env.Program(compiled, cel.CostTracking(est)) + if err != nil { + t.Fatalf("%v", err) + } + _, details, err := prog.Eval(map[string]interface{}{}) + if err != nil { + t.Fatalf("%v", err) + } + cost := details.ActualCost() + if *cost != expectRuntimeCost { + t.Errorf("Expected cost of %d but got %d", expectRuntimeCost, *cost) + } +} + +type testCostEstimator struct { +} + +func (t *testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + switch t := element.Type().TypeKind.(type) { + case *expr.Type_Primitive: + switch t.Primitive { + case expr.Type_STRING: + return &checker.SizeEstimate{Min: 0, Max: 12} + case expr.Type_BYTES: + return &checker.SizeEstimate{Min: 0, Max: 12} + } + } + return nil +} + +func (t *testCostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return nil +} diff --git a/pkg/cel/library/libraries.go b/pkg/cel/library/libraries.go new file mode 100644 index 000000000..18f6d7a7c --- /dev/null +++ b/pkg/cel/library/libraries.go @@ -0,0 +1,34 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "github.com/google/cel-go/cel" + "github.com/google/cel-go/ext" + "github.com/google/cel-go/interpreter" +) + +// ExtensionLibs declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes. +var ExtensionLibs = append(k8sExtensionLibs, ext.Strings()) + +var k8sExtensionLibs = []cel.EnvOption{ + URLs(), + Regex(), + Lists(), +} + +var ExtensionLibRegexOptimizations = []*interpreter.RegexOptimization{FindRegexOptimization, FindAllRegexOptimization} diff --git a/pkg/cel/library/library_compatibility_test.go b/pkg/cel/library/library_compatibility_test.go new file mode 100644 index 000000000..65473ff09 --- /dev/null +++ b/pkg/cel/library/library_compatibility_test.go @@ -0,0 +1,58 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "testing" + + "github.com/google/cel-go/cel" +) + +func TestLibraryCompatibility(t *testing.T) { + functionNames := map[string]struct{}{} + + decls := map[cel.Library]map[string][]cel.FunctionOpt{ + urlsLib: urlLibraryDecls, + listsLib: listsLibraryDecls, + regexLib: regexLibraryDecls, + } + if len(k8sExtensionLibs) != len(decls) { + t.Errorf("Expected the same number of libraries in the ExtensionLibs as are tested for compatibility") + } + for _, decl := range decls { + for name := range decl { + functionNames[name] = struct{}{} + } + } + + // WARN: All library changes must follow + // https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates + // and must track the functions here along with which Kubernetes version introduced them. + knownFunctions := []string{ + // Kubernetes 1.24: + "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf", "find", "findAll", "url", "getScheme", "getHost", "getHostname", + "getPort", "getEscapedPath", "getQuery", "isURL", + // Kubernetes <1.??>: + } + for _, fn := range knownFunctions { + delete(functionNames, fn) + } + + if len(functionNames) != 0 { + t.Errorf("Expected all functions in the libraries to be assigned to a kubernetes release, but found the unassigned function names: %v", functionNames) + } +} diff --git a/pkg/cel/library/lists.go b/pkg/cel/library/lists.go new file mode 100644 index 000000000..fe51dc87f --- /dev/null +++ b/pkg/cel/library/lists.go @@ -0,0 +1,312 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "fmt" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter/functions" +) + +// Lists provides a CEL function library extension of list utility functions. +// +// isSorted +// +// Returns true if the provided list of comparable elements is sorted, else returns false. +// +// >.isSorted() , T must be a comparable type +// +// Examples: +// +// [1, 2, 3].isSorted() // return true +// ['a', 'b', 'b', 'c'].isSorted() // return true +// [2.0, 1.0].isSorted() // return false +// [1].isSorted() // return true +// [].isSorted() // return true +// +// sum +// +// Returns the sum of the elements of the provided list. Supports CEL number (int, uint, double) and duration types. +// +// >.sum() , T must be a numeric type or a duration +// +// Examples: +// +// [1, 3].sum() // returns 4 +// [1.0, 3.0].sum() // returns 4.0 +// ['1m', '1s'].sum() // returns '1m1s' +// emptyIntList.sum() // returns 0 +// emptyDoubleList.sum() // returns 0.0 +// [].sum() // returns 0 +// +// min / max +// +// Returns the minimum/maximum valued element of the provided list. Supports all comparable types. +// If the list is empty, an error is returned. +// +// >.min() , T must be a comparable type +// >.max() , T must be a comparable type +// +// Examples: +// +// [1, 3].min() // returns 1 +// [1, 3].max() // returns 3 +// [].min() // error +// [1].min() // returns 1 +// ([0] + emptyList).min() // returns 0 +// +// indexOf / lastIndexOf +// +// Returns either the first or last positional index of the provided element in the list. +// If the element is not found, -1 is returned. Supports all equatable types. +// +// >.indexOf() , T must be an equatable type +// >.lastIndexOf() , T must be an equatable type +// +// Examples: +// +// [1, 2, 2, 3].indexOf(2) // returns 1 +// ['a', 'b', 'b', 'c'].lastIndexOf('b') // returns 2 +// [1.0].indexOf(1.1) // returns -1 +// [].indexOf('string') // returns -1 +func Lists() cel.EnvOption { + return cel.Lib(listsLib) +} + +var listsLib = &lists{} + +type lists struct{} + +var paramA = cel.TypeParamType("A") + +// CEL typeParams can be used to constraint to a specific trait (e.g. traits.ComparableType) if the 1st operand is the type to constrain. +// But the functions we need to constrain are >, not just . +// Make sure the order of overload set is deterministic +type namedCELType struct { + typeName string + celType *cel.Type +} + +var summableTypes = []namedCELType{ + {typeName: "int", celType: cel.IntType}, + {typeName: "uint", celType: cel.UintType}, + {typeName: "double", celType: cel.DoubleType}, + {typeName: "duration", celType: cel.DurationType}, +} + +var zeroValuesOfSummableTypes = map[string]ref.Val{ + "int": types.Int(0), + "uint": types.Uint(0), + "double": types.Double(0.0), + "duration": types.Duration{Duration: 0}, +} +var comparableTypes = []namedCELType{ + {typeName: "int", celType: cel.IntType}, + {typeName: "uint", celType: cel.UintType}, + {typeName: "double", celType: cel.DoubleType}, + {typeName: "bool", celType: cel.BoolType}, + {typeName: "duration", celType: cel.DurationType}, + {typeName: "timestamp", celType: cel.TimestampType}, + {typeName: "string", celType: cel.StringType}, + {typeName: "bytes", celType: cel.BytesType}, +} + +// WARNING: All library additions or modifications must follow +// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates +var listsLibraryDecls = map[string][]cel.FunctionOpt{ + "isSorted": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt { + return cel.MemberOverload(fmt.Sprintf("list_%s_is_sorted_bool", name), + []*cel.Type{cel.ListType(paramType)}, cel.BoolType, cel.UnaryBinding(isSorted)) + }), + "sum": templatedOverloads(summableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt { + return cel.MemberOverload(fmt.Sprintf("list_%s_sum_%s", name, name), + []*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(func(list ref.Val) ref.Val { + return sum( + func() ref.Val { + return zeroValuesOfSummableTypes[name] + })(list) + })) + }), + "max": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt { + return cel.MemberOverload(fmt.Sprintf("list_%s_max_%s", name, name), + []*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(max())) + }), + "min": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt { + return cel.MemberOverload(fmt.Sprintf("list_%s_min_%s", name, name), + []*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(min())) + }), + "indexOf": { + cel.MemberOverload("list_a_index_of_int", []*cel.Type{cel.ListType(paramA), paramA}, cel.IntType, + cel.BinaryBinding(indexOf)), + }, + "lastIndexOf": { + cel.MemberOverload("list_a_last_index_of_int", []*cel.Type{cel.ListType(paramA), paramA}, cel.IntType, + cel.BinaryBinding(lastIndexOf)), + }, +} + +func (*lists) CompileOptions() []cel.EnvOption { + options := []cel.EnvOption{} + for name, overloads := range listsLibraryDecls { + options = append(options, cel.Function(name, overloads...)) + } + return options +} + +func (*lists) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +func isSorted(val ref.Val) ref.Val { + var prev traits.Comparer + iterable, ok := val.(traits.Iterable) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + for it := iterable.Iterator(); it.HasNext() == types.True; { + next := it.Next() + nextCmp, ok := next.(traits.Comparer) + if !ok { + return types.MaybeNoSuchOverloadErr(next) + } + if prev != nil { + cmp := prev.Compare(next) + if cmp == types.IntOne { + return types.False + } + } + prev = nextCmp + } + return types.True +} + +func sum(init func() ref.Val) functions.UnaryOp { + return func(val ref.Val) ref.Val { + i := init() + acc, ok := i.(traits.Adder) + if !ok { + // Should never happen since all passed in init values are valid + return types.MaybeNoSuchOverloadErr(i) + } + iterable, ok := val.(traits.Iterable) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + for it := iterable.Iterator(); it.HasNext() == types.True; { + next := it.Next() + nextAdder, ok := next.(traits.Adder) + if !ok { + // Should never happen for type checked CEL programs + return types.MaybeNoSuchOverloadErr(next) + } + if acc != nil { + s := acc.Add(next) + sum, ok := s.(traits.Adder) + if !ok { + // Should never happen for type checked CEL programs + return types.MaybeNoSuchOverloadErr(s) + } + acc = sum + } else { + acc = nextAdder + } + } + return acc.(ref.Val) + } +} + +func min() functions.UnaryOp { + return cmp("min", types.IntOne) +} + +func max() functions.UnaryOp { + return cmp("max", types.IntNegOne) +} + +func cmp(opName string, opPreferCmpResult ref.Val) functions.UnaryOp { + return func(val ref.Val) ref.Val { + var result traits.Comparer + iterable, ok := val.(traits.Iterable) + if !ok { + return types.MaybeNoSuchOverloadErr(val) + } + for it := iterable.Iterator(); it.HasNext() == types.True; { + next := it.Next() + nextCmp, ok := next.(traits.Comparer) + if !ok { + // Should never happen for type checked CEL programs + return types.MaybeNoSuchOverloadErr(next) + } + if result == nil { + result = nextCmp + } else { + cmp := result.Compare(next) + if cmp == opPreferCmpResult { + result = nextCmp + } + } + } + if result == nil { + return types.NewErr("%s called on empty list", opName) + } + return result.(ref.Val) + } +} + +func indexOf(list ref.Val, item ref.Val) ref.Val { + lister, ok := list.(traits.Lister) + if !ok { + return types.MaybeNoSuchOverloadErr(list) + } + sz := lister.Size().(types.Int) + for i := types.Int(0); i < sz; i++ { + if lister.Get(types.Int(i)).Equal(item) == types.True { + return types.Int(i) + } + } + return types.Int(-1) +} + +func lastIndexOf(list ref.Val, item ref.Val) ref.Val { + lister, ok := list.(traits.Lister) + if !ok { + return types.MaybeNoSuchOverloadErr(list) + } + sz := lister.Size().(types.Int) + for i := sz - 1; i >= 0; i-- { + if lister.Get(types.Int(i)).Equal(item) == types.True { + return types.Int(i) + } + } + return types.Int(-1) +} + +// templatedOverloads returns overloads for each of the provided types. The template function is called with each type +// name (map key) and type to construct the overloads. +func templatedOverloads(types []namedCELType, template func(name string, t *cel.Type) cel.FunctionOpt) []cel.FunctionOpt { + overloads := make([]cel.FunctionOpt, len(types)) + i := 0 + for _, t := range types { + overloads[i] = template(t.typeName, t.celType) + i++ + } + return overloads +} diff --git a/pkg/cel/library/regex.go b/pkg/cel/library/regex.go new file mode 100644 index 000000000..6db5ef195 --- /dev/null +++ b/pkg/cel/library/regex.go @@ -0,0 +1,187 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "regexp" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/interpreter" +) + +// Regex provides a CEL function library extension of regex utility functions. +// +// find / findAll +// +// Returns substrings that match the provided regular expression. find returns the first match. findAll may optionally +// be provided a limit. If the limit is set and >= 0, no more than the limit number of matches are returned. +// +// .find() +// .findAll() > +// .findAll(, ) > +// +// Examples: +// +// "abc 123".find('[0-9]*') // returns '123' +// "abc 123".find('xyz') // returns '' +// "123 abc 456".findAll('[0-9]*') // returns ['123', '456'] +// "123 abc 456".findAll('[0-9]*', 1) // returns ['123'] +// "123 abc 456".findAll('xyz') // returns [] +func Regex() cel.EnvOption { + return cel.Lib(regexLib) +} + +var regexLib = ®ex{} + +type regex struct{} + +var regexLibraryDecls = map[string][]cel.FunctionOpt{ + "find": { + cel.MemberOverload("string_find_string", []*cel.Type{cel.StringType, cel.StringType}, cel.StringType, + cel.BinaryBinding(find))}, + "findAll": { + cel.MemberOverload("string_find_all_string", []*cel.Type{cel.StringType, cel.StringType}, + cel.ListType(cel.StringType), + cel.BinaryBinding(func(str, regex ref.Val) ref.Val { + return findAll(str, regex, types.Int(-1)) + })), + cel.MemberOverload("string_find_all_string_int", + []*cel.Type{cel.StringType, cel.StringType, cel.IntType}, + cel.ListType(cel.StringType), + cel.FunctionBinding(findAll)), + }, +} + +func (*regex) CompileOptions() []cel.EnvOption { + options := []cel.EnvOption{} + for name, overloads := range regexLibraryDecls { + options = append(options, cel.Function(name, overloads...)) + } + return options +} + +func (*regex) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +func find(strVal ref.Val, regexVal ref.Val) ref.Val { + str, ok := strVal.Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(strVal) + } + regex, ok := regexVal.Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(regexVal) + } + re, err := regexp.Compile(regex) + if err != nil { + return types.NewErr("Illegal regex: %v", err.Error()) + } + result := re.FindString(str) + return types.String(result) +} + +func findAll(args ...ref.Val) ref.Val { + argn := len(args) + if argn < 2 || argn > 3 { + return types.NoSuchOverloadErr() + } + str, ok := args[0].Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(args[0]) + } + regex, ok := args[1].Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(args[1]) + } + n := int64(-1) + if argn == 3 { + n, ok = args[2].Value().(int64) + if !ok { + return types.MaybeNoSuchOverloadErr(args[2]) + } + } + + re, err := regexp.Compile(regex) + if err != nil { + return types.NewErr("Illegal regex: %v", err.Error()) + } + + result := re.FindAllString(str, int(n)) + + return types.NewStringList(types.DefaultTypeAdapter, result) +} + +// FindRegexOptimization optimizes the 'find' function by compiling the regex pattern and +// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function +// call invocations. +var FindRegexOptimization = &interpreter.RegexOptimization{ + Function: "find", + RegexIndex: 1, + Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) { + compiledRegex, err := regexp.Compile(regexPattern) + if err != nil { + return nil, err + } + return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val { + if len(args) != 2 { + return types.NoSuchOverloadErr() + } + in, ok := args[0].Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(args[0]) + } + return types.String(compiledRegex.FindString(in)) + }), nil + }, +} + +// FindAllRegexOptimization optimizes the 'findAll' function by compiling the regex pattern and +// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function +// call invocations. +var FindAllRegexOptimization = &interpreter.RegexOptimization{ + Function: "findAll", + RegexIndex: 1, + Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) { + compiledRegex, err := regexp.Compile(regexPattern) + if err != nil { + return nil, err + } + return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val { + argn := len(args) + if argn < 2 || argn > 3 { + return types.NoSuchOverloadErr() + } + str, ok := args[0].Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(args[0]) + } + n := int64(-1) + if argn == 3 { + n, ok = args[2].Value().(int64) + if !ok { + return types.MaybeNoSuchOverloadErr(args[2]) + } + } + + result := compiledRegex.FindAllString(str, int(n)) + return types.NewStringList(types.DefaultTypeAdapter, result) + }), nil + }, +} diff --git a/pkg/cel/library/urls.go b/pkg/cel/library/urls.go new file mode 100644 index 000000000..afe80f493 --- /dev/null +++ b/pkg/cel/library/urls.go @@ -0,0 +1,236 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "net/url" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + + apiservercel "k8s.io/apiserver/pkg/cel" +) + +// URLs provides a CEL function library extension of URL parsing functions. +// +// url +// +// Converts a string to a URL or results in an error if the string is not a valid URL. The URL must be an absolute URI +// or an absolute path. +// +// url() +// +// Examples: +// +// url('https://user:pass@example.com:80/path?query=val#fragment') // returns a URL +// url('/absolute-path') // returns a URL +// url('https://a:b:c/') // error +// url('../relative-path') // error +// +// isURL +// +// Returns true if a string is a valid URL. The URL must be an absolute URI or an absolute path. +// +// isURL( ) +// +// Examples: +// +// isURL('https://user:pass@example.com:80/path?query=val#fragment') // returns true +// isURL('/absolute-path') // returns true +// isURL('https://a:b:c/') // returns false +// isURL('../relative-path') // returns false +// +// getScheme / getHost / getHostname / getPort / getEscapedPath / getQuery +// +// Return the parsed components of a URL. +// +// - getScheme: If absent in the URL, returns an empty string. +// +// - getHostname: IPv6 addresses are returned with braces, e.g. "[::1]". If absent in the URL, returns an empty string. +// +// - getHost: IPv6 addresses are returned without braces, e.g. "::1". If absent in the URL, returns an empty string. +// +// - getEscapedPath: The string returned by getEscapedPath is URL escaped, e.g. "with space" becomes "with%20space". +// If absent in the URL, returns an empty string. +// +// - getPort: If absent in the URL, returns an empty string. +// +// - getQuery: Returns the query parameters in "matrix" form where a repeated query key is interpreted to +// mean that there are multiple values for that key. The keys and values are returned unescaped. +// If absent in the URL, returns an empty map. +// +// .getScheme() +// .getHost() +// .getHostname() +// .getPort() +// .getEscapedPath() +// .getQuery() , > +// +// Examples: +// +// url('/path').getScheme() // returns '' +// url('https://example.com/').getScheme() // returns 'https' +// url('https://example.com:80/').getHost() // returns 'example.com:80' +// url('https://example.com/').getHost() // returns 'example.com' +// url('https://[::1]:80/').getHost() // returns '[::1]:80' +// url('https://[::1]/').getHost() // returns '[::1]' +// url('/path').getHost() // returns '' +// url('https://example.com:80/').getHostname() // returns 'example.com' +// url('https://127.0.0.1:80/').getHostname() // returns '127.0.0.1' +// url('https://[::1]:80/').getHostname() // returns '::1' +// url('/path').getHostname() // returns '' +// url('https://example.com:80/').getPort() // returns '80' +// url('https://example.com/').getPort() // returns '' +// url('/path').getPort() // returns '' +// url('https://example.com/path').getEscapedPath() // returns '/path' +// url('https://example.com/path with spaces/').getEscapedPath() // returns '/path%20with%20spaces/' +// url('https://example.com').getEscapedPath() // returns '' +// url('https://example.com/path?k1=a&k2=b&k2=c').getQuery() // returns { 'k1': ['a'], 'k2': ['b', 'c']} +// url('https://example.com/path?key with spaces=value with spaces').getQuery() // returns { 'key with spaces': ['value with spaces']} +// url('https://example.com/path?').getQuery() // returns {} +// url('https://example.com/path').getQuery() // returns {} +func URLs() cel.EnvOption { + return cel.Lib(urlsLib) +} + +var urlsLib = &urls{} + +type urls struct{} + +var urlLibraryDecls = map[string][]cel.FunctionOpt{ + "url": { + cel.Overload("string_to_url", []*cel.Type{cel.StringType}, apiservercel.URLType, + cel.UnaryBinding(stringToUrl))}, + "getScheme": { + cel.MemberOverload("url_get_scheme", []*cel.Type{apiservercel.URLType}, cel.StringType, + cel.UnaryBinding(getScheme))}, + "getHost": { + cel.MemberOverload("url_get_host", []*cel.Type{apiservercel.URLType}, cel.StringType, + cel.UnaryBinding(getHost))}, + "getHostname": { + cel.MemberOverload("url_get_hostname", []*cel.Type{apiservercel.URLType}, cel.StringType, + cel.UnaryBinding(getHostname))}, + "getPort": { + cel.MemberOverload("url_get_port", []*cel.Type{apiservercel.URLType}, cel.StringType, + cel.UnaryBinding(getPort))}, + "getEscapedPath": { + cel.MemberOverload("url_get_escaped_path", []*cel.Type{apiservercel.URLType}, cel.StringType, + cel.UnaryBinding(getEscapedPath))}, + "getQuery": { + cel.MemberOverload("url_get_query", []*cel.Type{apiservercel.URLType}, + cel.MapType(cel.StringType, cel.ListType(cel.StringType)), + cel.UnaryBinding(getQuery))}, + "isURL": { + cel.Overload("is_url_string", []*cel.Type{cel.StringType}, cel.BoolType, + cel.UnaryBinding(isURL))}, +} + +func (*urls) CompileOptions() []cel.EnvOption { + options := []cel.EnvOption{} + for name, overloads := range urlLibraryDecls { + options = append(options, cel.Function(name, overloads...)) + } + return options +} + +func (*urls) ProgramOptions() []cel.ProgramOption { + return []cel.ProgramOption{} +} + +func stringToUrl(arg ref.Val) ref.Val { + s, ok := arg.Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + // Use ParseRequestURI to check the URL before conversion. + // ParseRequestURI requires absolute URLs and is used by the OpenAPIv3 'uri' format. + _, err := url.ParseRequestURI(s) + if err != nil { + return types.NewErr("URL parse error during conversion from string: %v", err) + } + // We must parse again with Parse since ParseRequestURI incorrectly parses URLs that contain a fragment + // part and will incorrectly append the fragment to either the path or the query, depending on which it was adjacent to. + u, err := url.Parse(s) + if err != nil { + // Errors are not expected here since Parse is a more lenient parser than ParseRequestURI. + return types.NewErr("URL parse error during conversion from string: %v", err) + } + return apiservercel.URL{URL: u} +} + +func getScheme(arg ref.Val) ref.Val { + u, ok := arg.Value().(*url.URL) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + return types.String(u.Scheme) +} + +func getHost(arg ref.Val) ref.Val { + u, ok := arg.Value().(*url.URL) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + return types.String(u.Host) +} + +func getHostname(arg ref.Val) ref.Val { + u, ok := arg.Value().(*url.URL) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + return types.String(u.Hostname()) +} + +func getPort(arg ref.Val) ref.Val { + u, ok := arg.Value().(*url.URL) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + return types.String(u.Port()) +} + +func getEscapedPath(arg ref.Val) ref.Val { + u, ok := arg.Value().(*url.URL) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + return types.String(u.EscapedPath()) +} + +func getQuery(arg ref.Val) ref.Val { + u, ok := arg.Value().(*url.URL) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + + result := map[ref.Val]ref.Val{} + for k, v := range u.Query() { + result[types.String(k)] = types.NewStringList(types.DefaultTypeAdapter, v) + } + return types.NewRefValMap(types.DefaultTypeAdapter, result) +} + +func isURL(arg ref.Val) ref.Val { + s, ok := arg.Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(arg) + } + _, err := url.ParseRequestURI(s) + return types.Bool(err == nil) +} diff --git a/pkg/cel/limits.go b/pkg/cel/limits.go new file mode 100644 index 000000000..c38a47cea --- /dev/null +++ b/pkg/cel/limits.go @@ -0,0 +1,49 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +const ( + // MaxRequestSizeBytes is the largest request that will be accepted is 3MB + // TODO(DangerOnTheRanger): wire in MaxRequestBodyBytes from apiserver/pkg/server/options/server_run_options.go to make this configurable + MaxRequestSizeBytes = int64(3 * 1024 * 1024) + + // MaxDurationSizeJSON + // OpenAPI duration strings follow RFC 3339, section 5.6 - see the comment on maxDatetimeSizeJSON + MaxDurationSizeJSON = 32 + // MaxDatetimeSizeJSON + // OpenAPI datetime strings follow RFC 3339, section 5.6, and the longest possible + // such string is 9999-12-31T23:59:59.999999999Z, which has length 30 - we add 2 + // to allow for quotation marks + MaxDatetimeSizeJSON = 32 + // MinDurationSizeJSON + // Golang allows a string of 0 to be parsed as a duration, so that plus 2 to account for + // quotation marks makes 3 + MinDurationSizeJSON = 3 + // JSONDateSize is the size of a date serialized as part of a JSON object + // RFC 3339 dates require YYYY-MM-DD, and then we add 2 to allow for quotation marks + JSONDateSize = 12 + // MinDatetimeSizeJSON is the minimal length of a datetime formatted as RFC 3339 + // RFC 3339 datetimes require a full date (YYYY-MM-DD) and full time (HH:MM:SS), and we add 3 for + // quotation marks like always in addition to the capital T that separates the date and time + MinDatetimeSizeJSON = 21 + // MinStringSize is the size of literal "" + MinStringSize = 2 + // MinBoolSize is the length of literal true + MinBoolSize = 4 + // MinNumberSize is the length of literal 0 + MinNumberSize = 1 +) diff --git a/pkg/cel/metrics/metrics.go b/pkg/cel/metrics/metrics.go new file mode 100644 index 000000000..3ddb76cdf --- /dev/null +++ b/pkg/cel/metrics/metrics.go @@ -0,0 +1,72 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package metrics + +import ( + "time" + + "k8s.io/component-base/metrics" + "k8s.io/component-base/metrics/legacyregistry" +) + +// TODO(jiahuif) CEL is to be used in multiple components, revise naming when that happens. +const ( + namespace = "apiserver" + subsystem = "cel" +) + +// Metrics provides access to CEL metrics. +var Metrics = newCelMetrics() + +type CelMetrics struct { + compilationTime *metrics.Histogram + evaluationTime *metrics.Histogram +} + +func newCelMetrics() *CelMetrics { + m := &CelMetrics{ + compilationTime: metrics.NewHistogram(&metrics.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "compilation_duration_seconds", + StabilityLevel: metrics.ALPHA, + }), + evaluationTime: metrics.NewHistogram(&metrics.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "evaluation_duration_seconds", + StabilityLevel: metrics.ALPHA, + }), + } + + legacyregistry.MustRegister(m.compilationTime) + legacyregistry.MustRegister(m.evaluationTime) + + return m +} + +// ObserveCompilation records a CEL compilation with the time the compilation took. +func (m *CelMetrics) ObserveCompilation(elapsed time.Duration) { + seconds := elapsed.Seconds() + m.compilationTime.Observe(seconds) +} + +// ObserveEvaluation records a CEL evaluation with the time the evaluation took. +func (m *CelMetrics) ObserveEvaluation(elapsed time.Duration) { + seconds := elapsed.Seconds() + m.evaluationTime.Observe(seconds) +} diff --git a/pkg/cel/metrics/metrics_test.go b/pkg/cel/metrics/metrics_test.go new file mode 100644 index 000000000..11fb23563 --- /dev/null +++ b/pkg/cel/metrics/metrics_test.go @@ -0,0 +1,68 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package metrics + +import ( + "math" + "testing" + "time" + + "k8s.io/component-base/metrics/legacyregistry" +) + +func TestObserveCompilation(t *testing.T) { + defer legacyregistry.Reset() + Metrics.ObserveCompilation(2 * time.Second) + c, s := gatherHistogram(t, "apiserver_cel_compilation_duration_seconds") + if c != 1 { + t.Errorf("unexpected count: %v", c) + } + if math.Abs(s-2.0) > 1e-7 { + t.Fatalf("incorrect sum: %v", s) + } +} + +func TestObserveEvaluation(t *testing.T) { + defer legacyregistry.Reset() + Metrics.ObserveEvaluation(2 * time.Second) + c, s := gatherHistogram(t, "apiserver_cel_evaluation_duration_seconds") + if c != 1 { + t.Errorf("unexpected count: %v", c) + } + if math.Abs(s-2.0) > 1e-7 { + t.Fatalf("incorrect sum: %v", s) + } +} + +func gatherHistogram(t *testing.T, name string) (count uint64, sum float64) { + metrics, err := legacyregistry.DefaultGatherer.Gather() + if err != nil { + t.Fatalf("Failed to gather metrics: %s", err) + } + for _, mf := range metrics { + if mf.GetName() == name { + for _, m := range mf.GetMetric() { + h := m.GetHistogram() + count += h.GetSampleCount() + sum += h.GetSampleSum() + } + return + } + } + t.Fatalf("metric not found: %v", name) + return 0, 0 +} diff --git a/pkg/cel/registry.go b/pkg/cel/registry.go new file mode 100644 index 000000000..1aee3a127 --- /dev/null +++ b/pkg/cel/registry.go @@ -0,0 +1,79 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "sync" + + "github.com/google/cel-go/cel" +) + +// Resolver declares methods to find policy templates and related configuration objects. +type Resolver interface { + // FindType returns a DeclType instance corresponding to the given fully-qualified name, if + // present. + FindType(name string) (*DeclType, bool) +} + +// NewRegistry create a registry for keeping track of environments and types +// from a base cel.Env expression environment. +func NewRegistry(stdExprEnv *cel.Env) *Registry { + return &Registry{ + exprEnvs: map[string]*cel.Env{"": stdExprEnv}, + types: map[string]*DeclType{ + BoolType.TypeName(): BoolType, + BytesType.TypeName(): BytesType, + DoubleType.TypeName(): DoubleType, + DurationType.TypeName(): DurationType, + IntType.TypeName(): IntType, + NullType.TypeName(): NullType, + StringType.TypeName(): StringType, + TimestampType.TypeName(): TimestampType, + UintType.TypeName(): UintType, + ListType.TypeName(): ListType, + MapType.TypeName(): MapType, + }, + } +} + +// Registry defines a repository of environment, schema, template, and type definitions. +// +// Registry instances are concurrency-safe. +type Registry struct { + rwMux sync.RWMutex + exprEnvs map[string]*cel.Env + types map[string]*DeclType +} + +// FindType implements the Resolver interface method. +func (r *Registry) FindType(name string) (*DeclType, bool) { + r.rwMux.RLock() + defer r.rwMux.RUnlock() + typ, found := r.types[name] + if found { + return typ, true + } + return typ, found +} + +// SetType registers a DeclType descriptor by its fully qualified name. +func (r *Registry) SetType(name string, declType *DeclType) error { + r.rwMux.Lock() + defer r.rwMux.Unlock() + r.types[name] = declType + return nil +} diff --git a/pkg/cel/types.go b/pkg/cel/types.go new file mode 100644 index 000000000..13171ad21 --- /dev/null +++ b/pkg/cel/types.go @@ -0,0 +1,552 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "fmt" + "math" + "time" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "google.golang.org/protobuf/proto" +) + +const ( + noMaxLength = math.MaxInt +) + +// NewListType returns a parameterized list type with a specified element type. +func NewListType(elem *DeclType, maxItems int64) *DeclType { + return &DeclType{ + name: "list", + ElemType: elem, + MaxElements: maxItems, + celType: cel.ListType(elem.CelType()), + defaultValue: NewListValue(), + // a list can always be represented as [] in JSON, so hardcode the min size + // to 2 + MinSerializedSize: 2, + } +} + +// NewMapType returns a parameterized map type with the given key and element types. +func NewMapType(key, elem *DeclType, maxProperties int64) *DeclType { + return &DeclType{ + name: "map", + KeyType: key, + ElemType: elem, + MaxElements: maxProperties, + celType: cel.MapType(key.CelType(), elem.CelType()), + defaultValue: NewMapValue(), + // a map can always be represented as {} in JSON, so hardcode the min size + // to 2 + MinSerializedSize: 2, + } +} + +// NewObjectType creates an object type with a qualified name and a set of field declarations. +func NewObjectType(name string, fields map[string]*DeclField) *DeclType { + t := &DeclType{ + name: name, + Fields: fields, + celType: cel.ObjectType(name), + traitMask: traits.FieldTesterType | traits.IndexerType, + // an object could potentially be larger than the min size we default to here ({}), + // but we rely upon the caller to change MinSerializedSize accordingly if they add + // properties to the object + MinSerializedSize: 2, + } + t.defaultValue = NewObjectValue(t) + return t +} + +func NewSimpleTypeWithMinSize(name string, celType *cel.Type, zeroVal ref.Val, minSize int64) *DeclType { + return &DeclType{ + name: name, + celType: celType, + defaultValue: zeroVal, + MinSerializedSize: minSize, + } +} + +// DeclType represents the universal type descriptor for OpenAPIv3 types. +type DeclType struct { + fmt.Stringer + + name string + // Fields contains a map of escaped CEL identifier field names to field declarations. + Fields map[string]*DeclField + KeyType *DeclType + ElemType *DeclType + TypeParam bool + Metadata map[string]string + MaxElements int64 + // MinSerializedSize represents the smallest possible size in bytes that + // the DeclType could be serialized to in JSON. + MinSerializedSize int64 + + celType *cel.Type + traitMask int + defaultValue ref.Val +} + +// MaybeAssignTypeName attempts to set the DeclType name to a fully qualified name, if the type +// is of `object` type. +// +// The DeclType must return true for `IsObject` or this assignment will error. +func (t *DeclType) MaybeAssignTypeName(name string) *DeclType { + if t.IsObject() { + objUpdated := false + if t.name != "object" { + name = t.name + } else { + objUpdated = true + } + fieldMap := make(map[string]*DeclField, len(t.Fields)) + for fieldName, field := range t.Fields { + fieldType := field.Type + fieldTypeName := fmt.Sprintf("%s.%s", name, fieldName) + updated := fieldType.MaybeAssignTypeName(fieldTypeName) + if updated == fieldType { + fieldMap[fieldName] = field + continue + } + objUpdated = true + fieldMap[fieldName] = &DeclField{ + Name: fieldName, + Type: updated, + Required: field.Required, + enumValues: field.enumValues, + defaultValue: field.defaultValue, + } + } + if !objUpdated { + return t + } + return &DeclType{ + name: name, + Fields: fieldMap, + KeyType: t.KeyType, + ElemType: t.ElemType, + TypeParam: t.TypeParam, + Metadata: t.Metadata, + celType: cel.ObjectType(name), + traitMask: t.traitMask, + defaultValue: t.defaultValue, + MinSerializedSize: t.MinSerializedSize, + } + } + if t.IsMap() { + elemTypeName := fmt.Sprintf("%s.@elem", name) + updated := t.ElemType.MaybeAssignTypeName(elemTypeName) + if updated == t.ElemType { + return t + } + return NewMapType(t.KeyType, updated, t.MaxElements) + } + if t.IsList() { + elemTypeName := fmt.Sprintf("%s.@idx", name) + updated := t.ElemType.MaybeAssignTypeName(elemTypeName) + if updated == t.ElemType { + return t + } + return NewListType(updated, t.MaxElements) + } + return t +} + +// ExprType returns the CEL expression type of this declaration. +func (t *DeclType) ExprType() (*exprpb.Type, error) { + return cel.TypeToExprType(t.celType) +} + +// CelType returns the CEL type of this declaration. +func (t *DeclType) CelType() *cel.Type { + return t.celType +} + +// FindField returns the DeclField with the given name if present. +func (t *DeclType) FindField(name string) (*DeclField, bool) { + f, found := t.Fields[name] + return f, found +} + +// HasTrait implements the CEL ref.Type interface making this type declaration suitable for use +// within the CEL evaluator. +func (t *DeclType) HasTrait(trait int) bool { + if t.traitMask&trait == trait { + return true + } + if t.defaultValue == nil { + return false + } + _, isDecl := t.defaultValue.Type().(*DeclType) + if isDecl { + return false + } + return t.defaultValue.Type().HasTrait(trait) +} + +// IsList returns whether the declaration is a `list` type which defines a parameterized element +// type, but not a parameterized key type or fields. +func (t *DeclType) IsList() bool { + return t.KeyType == nil && t.ElemType != nil && t.Fields == nil +} + +// IsMap returns whether the declaration is a 'map' type which defines parameterized key and +// element types, but not fields. +func (t *DeclType) IsMap() bool { + return t.KeyType != nil && t.ElemType != nil && t.Fields == nil +} + +// IsObject returns whether the declartion is an 'object' type which defined a set of typed fields. +func (t *DeclType) IsObject() bool { + return t.KeyType == nil && t.ElemType == nil && t.Fields != nil +} + +// String implements the fmt.Stringer interface method. +func (t *DeclType) String() string { + return t.name +} + +// TypeName returns the fully qualified type name for the DeclType. +func (t *DeclType) TypeName() string { + return t.name +} + +// DefaultValue returns the CEL ref.Val representing the default value for this object type, +// if one exists. +func (t *DeclType) DefaultValue() ref.Val { + return t.defaultValue +} + +// FieldTypeMap constructs a map of the field and object types nested within a given type. +func FieldTypeMap(path string, t *DeclType) map[string]*DeclType { + if t.IsObject() && t.TypeName() != "object" { + path = t.TypeName() + } + types := make(map[string]*DeclType) + buildDeclTypes(path, t, types) + return types +} + +func buildDeclTypes(path string, t *DeclType, types map[string]*DeclType) { + // Ensure object types are properly named according to where they appear in the schema. + if t.IsObject() { + // Hack to ensure that names are uniquely qualified and work well with the type + // resolution steps which require fully qualified type names for field resolution + // to function properly. + types[t.TypeName()] = t + for name, field := range t.Fields { + fieldPath := fmt.Sprintf("%s.%s", path, name) + buildDeclTypes(fieldPath, field.Type, types) + } + } + // Map element properties to type names if needed. + if t.IsMap() { + mapElemPath := fmt.Sprintf("%s.@elem", path) + buildDeclTypes(mapElemPath, t.ElemType, types) + types[path] = t + } + // List element properties. + if t.IsList() { + listIdxPath := fmt.Sprintf("%s.@idx", path) + buildDeclTypes(listIdxPath, t.ElemType, types) + types[path] = t + } +} + +// DeclField describes the name, ordinal, and optionality of a field declaration within a type. +type DeclField struct { + Name string + Type *DeclType + Required bool + enumValues []interface{} + defaultValue interface{} +} + +func NewDeclField(name string, declType *DeclType, required bool, enumValues []interface{}, defaultValue interface{}) *DeclField { + return &DeclField{ + Name: name, + Type: declType, + Required: required, + enumValues: enumValues, + defaultValue: defaultValue, + } +} + +// TypeName returns the string type name of the field. +func (f *DeclField) TypeName() string { + return f.Type.TypeName() +} + +// DefaultValue returns the zero value associated with the field. +func (f *DeclField) DefaultValue() ref.Val { + if f.defaultValue != nil { + return types.DefaultTypeAdapter.NativeToValue(f.defaultValue) + } + return f.Type.DefaultValue() +} + +// EnumValues returns the set of values that this field may take. +func (f *DeclField) EnumValues() []ref.Val { + if f.enumValues == nil || len(f.enumValues) == 0 { + return []ref.Val{} + } + ev := make([]ref.Val, len(f.enumValues)) + for i, e := range f.enumValues { + ev[i] = types.DefaultTypeAdapter.NativeToValue(e) + } + return ev +} + +// NewRuleTypes returns an Open API Schema-based type-system which is CEL compatible. +func NewRuleTypes(kind string, + declType *DeclType, + res Resolver) (*RuleTypes, error) { + // Note, if the schema indicates that it's actually based on another proto + // then prefer the proto definition. For expressions in the proto, a new field + // annotation will be needed to indicate the expected environment and type of + // the expression. + schemaTypes, err := newSchemaTypeProvider(kind, declType) + if err != nil { + return nil, err + } + if schemaTypes == nil { + return nil, nil + } + return &RuleTypes{ + ruleSchemaDeclTypes: schemaTypes, + resolver: res, + }, nil +} + +// RuleTypes extends the CEL ref.TypeProvider interface and provides an Open API Schema-based +// type-system. +type RuleTypes struct { + ref.TypeProvider + ruleSchemaDeclTypes *schemaTypeProvider + typeAdapter ref.TypeAdapter + resolver Resolver +} + +// EnvOptions returns a set of cel.EnvOption values which includes the declaration set +// as well as a custom ref.TypeProvider. +// +// Note, the standard declaration set includes 'rule' which is defined as the top-level rule-schema +// type if one is configured. +// +// If the RuleTypes value is nil, an empty []cel.EnvOption set is returned. +func (rt *RuleTypes) EnvOptions(tp ref.TypeProvider) ([]cel.EnvOption, error) { + if rt == nil { + return []cel.EnvOption{}, nil + } + var ta ref.TypeAdapter = types.DefaultTypeAdapter + tpa, ok := tp.(ref.TypeAdapter) + if ok { + ta = tpa + } + rtWithTypes := &RuleTypes{ + TypeProvider: tp, + typeAdapter: ta, + ruleSchemaDeclTypes: rt.ruleSchemaDeclTypes, + resolver: rt.resolver, + } + for name, declType := range rt.ruleSchemaDeclTypes.types { + tpType, found := tp.FindType(name) + expT, err := declType.ExprType() + if err != nil { + return nil, fmt.Errorf("fail to get cel type: %s", err) + } + if found && !proto.Equal(tpType, expT) { + return nil, fmt.Errorf( + "type %s definition differs between CEL environment and rule", name) + } + } + return []cel.EnvOption{ + cel.CustomTypeProvider(rtWithTypes), + cel.CustomTypeAdapter(rtWithTypes), + cel.Variable("rule", rt.ruleSchemaDeclTypes.root.CelType()), + }, nil +} + +// FindType attempts to resolve the typeName provided from the rule's rule-schema, or if not +// from the embedded ref.TypeProvider. +// +// FindType overrides the default type-finding behavior of the embedded TypeProvider. +// +// Note, when the type name is based on the Open API Schema, the name will reflect the object path +// where the type definition appears. +func (rt *RuleTypes) FindType(typeName string) (*exprpb.Type, bool) { + if rt == nil { + return nil, false + } + declType, found := rt.findDeclType(typeName) + if found { + expT, err := declType.ExprType() + if err != nil { + return expT, false + } + return expT, found + } + return rt.TypeProvider.FindType(typeName) +} + +// FindDeclType returns the CPT type description which can be mapped to a CEL type. +func (rt *RuleTypes) FindDeclType(typeName string) (*DeclType, bool) { + if rt == nil { + return nil, false + } + return rt.findDeclType(typeName) +} + +// FindFieldType returns a field type given a type name and field name, if found. +// +// Note, the type name for an Open API Schema type is likely to be its qualified object path. +// If, in the future an object instance rather than a type name were provided, the field +// resolution might more accurately reflect the expected type model. However, in this case +// concessions were made to align with the existing CEL interfaces. +func (rt *RuleTypes) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) { + st, found := rt.findDeclType(typeName) + if !found { + return rt.TypeProvider.FindFieldType(typeName, fieldName) + } + + f, found := st.Fields[fieldName] + if found { + ft := f.Type + expT, err := ft.ExprType() + if err != nil { + return nil, false + } + return &ref.FieldType{ + Type: expT, + }, true + } + // This could be a dynamic map. + if st.IsMap() { + et := st.ElemType + expT, err := et.ExprType() + if err != nil { + return nil, false + } + return &ref.FieldType{ + Type: expT, + }, true + } + return nil, false +} + +// NativeToValue is an implementation of the ref.TypeAdapater interface which supports conversion +// of rule values to CEL ref.Val instances. +func (rt *RuleTypes) NativeToValue(val interface{}) ref.Val { + return rt.typeAdapter.NativeToValue(val) +} + +// TypeNames returns the list of type names declared within the RuleTypes object. +func (rt *RuleTypes) TypeNames() []string { + typeNames := make([]string, len(rt.ruleSchemaDeclTypes.types)) + i := 0 + for name := range rt.ruleSchemaDeclTypes.types { + typeNames[i] = name + i++ + } + return typeNames +} + +func (rt *RuleTypes) findDeclType(typeName string) (*DeclType, bool) { + declType, found := rt.ruleSchemaDeclTypes.types[typeName] + if found { + return declType, true + } + declType, found = rt.resolver.FindType(typeName) + if found { + return declType, true + } + return nil, false +} + +func newSchemaTypeProvider(kind string, declType *DeclType) (*schemaTypeProvider, error) { + if declType == nil { + return nil, nil + } + root := declType.MaybeAssignTypeName(kind) + types := FieldTypeMap(kind, root) + return &schemaTypeProvider{ + root: root, + types: types, + }, nil +} + +type schemaTypeProvider struct { + root *DeclType + types map[string]*DeclType +} + +var ( + // AnyType is equivalent to the CEL 'protobuf.Any' type in that the value may have any of the + // types supported. + AnyType = NewSimpleTypeWithMinSize("any", cel.AnyType, nil, 1) + + // BoolType is equivalent to the CEL 'bool' type. + BoolType = NewSimpleTypeWithMinSize("bool", cel.BoolType, types.False, MinBoolSize) + + // BytesType is equivalent to the CEL 'bytes' type. + BytesType = NewSimpleTypeWithMinSize("bytes", cel.BytesType, types.Bytes([]byte{}), MinStringSize) + + // DoubleType is equivalent to the CEL 'double' type which is a 64-bit floating point value. + DoubleType = NewSimpleTypeWithMinSize("double", cel.DoubleType, types.Double(0), MinNumberSize) + + // DurationType is equivalent to the CEL 'duration' type. + DurationType = NewSimpleTypeWithMinSize("duration", cel.DurationType, types.Duration{Duration: time.Duration(0)}, MinDurationSizeJSON) + + // DateType is equivalent to the CEL 'date' type. + DateType = NewSimpleTypeWithMinSize("date", cel.TimestampType, types.Timestamp{Time: time.Time{}}, JSONDateSize) + + // DynType is the equivalent of the CEL 'dyn' concept which indicates that the type will be + // determined at runtime rather than compile time. + DynType = NewSimpleTypeWithMinSize("dyn", cel.DynType, nil, 1) + + // IntType is equivalent to the CEL 'int' type which is a 64-bit signed int. + IntType = NewSimpleTypeWithMinSize("int", cel.IntType, types.IntZero, MinNumberSize) + + // NullType is equivalent to the CEL 'null_type'. + NullType = NewSimpleTypeWithMinSize("null_type", cel.NullType, types.NullValue, 4) + + // StringType is equivalent to the CEL 'string' type which is expected to be a UTF-8 string. + // StringType values may either be string literals or expression strings. + StringType = NewSimpleTypeWithMinSize("string", cel.StringType, types.String(""), MinStringSize) + + // TimestampType corresponds to the well-known protobuf.Timestamp type supported within CEL. + // Note that both the OpenAPI date and date-time types map onto TimestampType, so not all types + // labeled as Timestamp will necessarily have the same MinSerializedSize. + TimestampType = NewSimpleTypeWithMinSize("timestamp", cel.TimestampType, types.Timestamp{Time: time.Time{}}, JSONDateSize) + + // UintType is equivalent to the CEL 'uint' type. + UintType = NewSimpleTypeWithMinSize("uint", cel.UintType, types.Uint(0), 1) + + // ListType is equivalent to the CEL 'list' type. + ListType = NewListType(AnyType, noMaxLength) + + // MapType is equivalent to the CEL 'map' type. + MapType = NewMapType(AnyType, AnyType, noMaxLength) +) diff --git a/pkg/cel/types_test.go b/pkg/cel/types_test.go new file mode 100644 index 000000000..fef500b53 --- /dev/null +++ b/pkg/cel/types_test.go @@ -0,0 +1,79 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "testing" +) + +func TestTypes_ListType(t *testing.T) { + list := NewListType(StringType, -1) + if !list.IsList() { + t.Error("list type not identifiable as list") + } + if list.TypeName() != "list" { + t.Errorf("got %s, wanted list", list.TypeName()) + } + if list.DefaultValue() == nil { + t.Error("got nil zero value for list type") + } + if list.ElemType.TypeName() != "string" { + t.Errorf("got %s, wanted elem type of string", list.ElemType.TypeName()) + } + expT, err := list.ExprType() + if err != nil { + t.Errorf("fail to get cel type: %s", err) + } + if expT.GetListType() == nil { + t.Errorf("got %v, wanted CEL list type", expT) + } +} + +func TestTypes_MapType(t *testing.T) { + mp := NewMapType(StringType, IntType, -1) + if !mp.IsMap() { + t.Error("map type not identifiable as map") + } + if mp.TypeName() != "map" { + t.Errorf("got %s, wanted map", mp.TypeName()) + } + if mp.DefaultValue() == nil { + t.Error("got nil zero value for map type") + } + if mp.KeyType.TypeName() != "string" { + t.Errorf("got %s, wanted key type of string", mp.KeyType.TypeName()) + } + if mp.ElemType.TypeName() != "int" { + t.Errorf("got %s, wanted elem type of int", mp.ElemType.TypeName()) + } + expT, err := mp.ExprType() + if err != nil { + t.Errorf("fail to get cel type: %s", err) + } + if expT.GetMapType() == nil { + t.Errorf("got %v, wanted CEL map type", expT) + } +} + +func testValue(t *testing.T, id int64, val interface{}) *DynValue { + t.Helper() + dv, err := NewDynValue(id, val) + if err != nil { + t.Fatalf("NewDynValue(%d, %v) failed: %v", id, val, err) + } + return dv +} diff --git a/pkg/cel/url.go b/pkg/cel/url.go new file mode 100644 index 000000000..6800205c9 --- /dev/null +++ b/pkg/cel/url.go @@ -0,0 +1,80 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "fmt" + "net/url" + "reflect" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// URL provides a CEL representation of a URL. +type URL struct { + *url.URL +} + +var ( + URLObject = decls.NewObjectType("kubernetes.URL") + typeValue = types.NewTypeValue("kubernetes.URL") + URLType = cel.ObjectType("kubernetes.URL") +) + +// ConvertToNative implements ref.Val.ConvertToNative. +func (d URL) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + if reflect.TypeOf(d.URL).AssignableTo(typeDesc) { + return d.URL, nil + } + if reflect.TypeOf("").AssignableTo(typeDesc) { + return d.URL.String(), nil + } + return nil, fmt.Errorf("type conversion error from 'URL' to '%v'", typeDesc) +} + +// ConvertToType implements ref.Val.ConvertToType. +func (d URL) ConvertToType(typeVal ref.Type) ref.Val { + switch typeVal { + case typeValue: + return d + case types.TypeType: + return typeValue + } + return types.NewErr("type conversion error from '%s' to '%s'", typeValue, typeVal) +} + +// Equal implements ref.Val.Equal. +func (d URL) Equal(other ref.Val) ref.Val { + otherDur, ok := other.(URL) + if !ok { + return types.MaybeNoSuchOverloadErr(other) + } + return types.Bool(d.URL.String() == otherDur.URL.String()) +} + +// Type implements ref.Val.Type. +func (d URL) Type() ref.Type { + return typeValue +} + +// Value implements ref.Val.Value. +func (d URL) Value() interface{} { + return d.URL +} diff --git a/pkg/cel/value.go b/pkg/cel/value.go new file mode 100644 index 000000000..01c7f20ac --- /dev/null +++ b/pkg/cel/value.go @@ -0,0 +1,769 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "fmt" + "reflect" + "sync" + "time" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +// EncodeStyle is a hint for string encoding of parsed values. +type EncodeStyle int + +const ( + // BlockValueStyle is the default string encoding which preserves whitespace and newlines. + BlockValueStyle EncodeStyle = iota + + // FlowValueStyle indicates that the string is an inline representation of complex types. + FlowValueStyle + + // FoldedValueStyle is a multiline string with whitespace and newlines trimmed to a single + // a whitespace. Repeated newlines are replaced with a single newline rather than a single + // whitespace. + FoldedValueStyle + + // LiteralStyle is a multiline string that preserves newlines, but trims all other whitespace + // to a single character. + LiteralStyle +) + +// NewEmptyDynValue returns the zero-valued DynValue. +func NewEmptyDynValue() *DynValue { + // note: 0 is not a valid parse node identifier. + dv, _ := NewDynValue(0, nil) + return dv +} + +// NewDynValue returns a DynValue that corresponds to a parse node id and value. +func NewDynValue(id int64, val interface{}) (*DynValue, error) { + dv := &DynValue{ID: id} + err := dv.SetValue(val) + return dv, err +} + +// DynValue is a dynamically typed value used to describe unstructured content. +// Whether the value has the desired type is determined by where it is used within the Instance or +// Template, and whether there are schemas which might enforce a more rigid type definition. +type DynValue struct { + ID int64 + EncodeStyle EncodeStyle + value interface{} + exprValue ref.Val + declType *DeclType +} + +// DeclType returns the policy model type of the dyn value. +func (dv *DynValue) DeclType() *DeclType { + return dv.declType +} + +// ConvertToNative is an implementation of the CEL ref.Val method used to adapt between CEL types +// and Go-native types. +// +// The default behavior of this method is to first convert to a CEL type which has a well-defined +// set of conversion behaviors and proxy to the CEL ConvertToNative method for the type. +func (dv *DynValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + ev := dv.ExprValue() + if types.IsError(ev) { + return nil, ev.(*types.Err) + } + return ev.ConvertToNative(typeDesc) +} + +// Equal returns whether the dyn value is equal to a given CEL value. +func (dv *DynValue) Equal(other ref.Val) ref.Val { + dvType := dv.Type() + otherType := other.Type() + // Preserve CEL's homogeneous equality constraint. + if dvType.TypeName() != otherType.TypeName() { + return types.MaybeNoSuchOverloadErr(other) + } + switch v := dv.value.(type) { + case ref.Val: + return v.Equal(other) + case PlainTextValue: + return celBool(string(v) == other.Value().(string)) + case *MultilineStringValue: + return celBool(v.Value == other.Value().(string)) + case time.Duration: + otherDuration := other.Value().(time.Duration) + return celBool(v == otherDuration) + case time.Time: + otherTimestamp := other.Value().(time.Time) + return celBool(v.Equal(otherTimestamp)) + default: + return celBool(reflect.DeepEqual(v, other.Value())) + } +} + +// ExprValue converts the DynValue into a CEL value. +func (dv *DynValue) ExprValue() ref.Val { + return dv.exprValue +} + +// Value returns the underlying value held by this reference. +func (dv *DynValue) Value() interface{} { + return dv.value +} + +// SetValue updates the underlying value held by this reference. +func (dv *DynValue) SetValue(value interface{}) error { + dv.value = value + var err error + dv.exprValue, dv.declType, err = exprValue(value) + return err +} + +// Type returns the CEL type for the given value. +func (dv *DynValue) Type() ref.Type { + return dv.ExprValue().Type() +} + +func exprValue(value interface{}) (ref.Val, *DeclType, error) { + switch v := value.(type) { + case bool: + return types.Bool(v), BoolType, nil + case []byte: + return types.Bytes(v), BytesType, nil + case float64: + return types.Double(v), DoubleType, nil + case int64: + return types.Int(v), IntType, nil + case string: + return types.String(v), StringType, nil + case uint64: + return types.Uint(v), UintType, nil + case time.Duration: + return types.Duration{Duration: v}, DurationType, nil + case time.Time: + return types.Timestamp{Time: v}, TimestampType, nil + case types.Null: + return v, NullType, nil + case *ListValue: + return v, ListType, nil + case *MapValue: + return v, MapType, nil + case *ObjectValue: + return v, v.objectType, nil + default: + return nil, unknownType, fmt.Errorf("unsupported type: (%T)%v", v, v) + } +} + +// PlainTextValue is a text string literal which must not be treated as an expression. +type PlainTextValue string + +// MultilineStringValue is a multiline string value which has been parsed in a way which omits +// whitespace as well as a raw form which preserves whitespace. +type MultilineStringValue struct { + Value string + Raw string +} + +func newStructValue() *structValue { + return &structValue{ + Fields: []*Field{}, + fieldMap: map[string]*Field{}, + } +} + +type structValue struct { + Fields []*Field + fieldMap map[string]*Field +} + +// AddField appends a MapField to the MapValue and indexes the field by name. +func (sv *structValue) AddField(field *Field) { + sv.Fields = append(sv.Fields, field) + sv.fieldMap[field.Name] = field +} + +// ConvertToNative converts the MapValue type to a native go types. +func (sv *structValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + if typeDesc.Kind() != reflect.Map && + typeDesc.Kind() != reflect.Struct && + typeDesc.Kind() != reflect.Pointer && + typeDesc.Kind() != reflect.Interface { + return nil, fmt.Errorf("type conversion error from object to '%v'", typeDesc) + } + + // Unwrap pointers, but track their use. + isPtr := false + if typeDesc.Kind() == reflect.Pointer { + tk := typeDesc + typeDesc = typeDesc.Elem() + if typeDesc.Kind() == reflect.Pointer { + return nil, fmt.Errorf("unsupported type conversion to '%v'", tk) + } + isPtr = true + } + + if typeDesc.Kind() == reflect.Map { + keyType := typeDesc.Key() + if keyType.Kind() != reflect.String && keyType.Kind() != reflect.Interface { + return nil, fmt.Errorf("object fields cannot be converted to type '%v'", keyType) + } + elemType := typeDesc.Elem() + sz := len(sv.fieldMap) + ntvMap := reflect.MakeMapWithSize(typeDesc, sz) + for name, val := range sv.fieldMap { + refVal, err := val.Ref.ConvertToNative(elemType) + if err != nil { + return nil, err + } + ntvMap.SetMapIndex(reflect.ValueOf(name), reflect.ValueOf(refVal)) + } + return ntvMap.Interface(), nil + } + + if typeDesc.Kind() == reflect.Struct { + ntvObjPtr := reflect.New(typeDesc) + ntvObj := ntvObjPtr.Elem() + for name, val := range sv.fieldMap { + f := ntvObj.FieldByName(name) + if !f.IsValid() { + return nil, fmt.Errorf("type conversion error, no such field %s in type %v", + name, typeDesc) + } + fv, err := val.Ref.ConvertToNative(f.Type()) + if err != nil { + return nil, err + } + f.Set(reflect.ValueOf(fv)) + } + if isPtr { + return ntvObjPtr.Interface(), nil + } + return ntvObj.Interface(), nil + } + return nil, fmt.Errorf("type conversion error from object to '%v'", typeDesc) +} + +// GetField returns a MapField by name if one exists. +func (sv *structValue) GetField(name string) (*Field, bool) { + field, found := sv.fieldMap[name] + return field, found +} + +// IsSet returns whether the given field, which is defined, has also been set. +func (sv *structValue) IsSet(key ref.Val) ref.Val { + k, ok := key.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(key) + } + name := string(k) + _, found := sv.fieldMap[name] + return celBool(found) +} + +// NewObjectValue creates a struct value with a schema type and returns the empty ObjectValue. +func NewObjectValue(sType *DeclType) *ObjectValue { + return &ObjectValue{ + structValue: newStructValue(), + objectType: sType, + } +} + +// ObjectValue is a struct with a custom schema type which indicates the fields and types +// associated with the structure. +type ObjectValue struct { + *structValue + objectType *DeclType +} + +// ConvertToType is an implementation of the CEL ref.Val interface method. +func (o *ObjectValue) ConvertToType(t ref.Type) ref.Val { + if t == types.TypeType { + return types.NewObjectTypeValue(o.objectType.TypeName()) + } + if t.TypeName() == o.objectType.TypeName() { + return o + } + return types.NewErr("type conversion error from '%s' to '%s'", o.Type(), t) +} + +// Equal returns true if the two object types are equal and their field values are equal. +func (o *ObjectValue) Equal(other ref.Val) ref.Val { + // Preserve CEL's homogeneous equality semantics. + if o.objectType.TypeName() != other.Type().TypeName() { + return types.MaybeNoSuchOverloadErr(other) + } + o2 := other.(traits.Indexer) + for name := range o.objectType.Fields { + k := types.String(name) + v := o.Get(k) + ov := o2.Get(k) + vEq := v.Equal(ov) + if vEq != types.True { + return vEq + } + } + return types.True +} + +// Get returns the value of the specified field. +// +// If the field is set, its value is returned. If the field is not set, the default value for the +// field is returned thus allowing for safe-traversal and preserving proto-like field traversal +// semantics for Open API Schema backed types. +func (o *ObjectValue) Get(name ref.Val) ref.Val { + n, ok := name.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(n) + } + nameStr := string(n) + field, found := o.fieldMap[nameStr] + if found { + return field.Ref.ExprValue() + } + fieldDef, found := o.objectType.Fields[nameStr] + if !found { + return types.NewErr("no such field: %s", nameStr) + } + defValue := fieldDef.DefaultValue() + if defValue != nil { + return defValue + } + return types.NewErr("no default for type: %s", fieldDef.TypeName()) +} + +// Type returns the CEL type value of the object. +func (o *ObjectValue) Type() ref.Type { + return o.objectType +} + +// Value returns the Go-native representation of the object. +func (o *ObjectValue) Value() interface{} { + return o +} + +// NewMapValue returns an empty MapValue. +func NewMapValue() *MapValue { + return &MapValue{ + structValue: newStructValue(), + } +} + +// MapValue declares an object with a set of named fields whose values are dynamically typed. +type MapValue struct { + *structValue +} + +// ConvertToObject produces an ObjectValue from the MapValue with the associated schema type. +// +// The conversion is shallow and the memory shared between the Object and Map as all references +// to the map are expected to be replaced with the Object reference. +func (m *MapValue) ConvertToObject(declType *DeclType) *ObjectValue { + return &ObjectValue{ + structValue: m.structValue, + objectType: declType, + } +} + +// Contains returns whether the given key is contained in the MapValue. +func (m *MapValue) Contains(key ref.Val) ref.Val { + v, found := m.Find(key) + if v != nil && types.IsUnknownOrError(v) { + return v + } + return celBool(found) +} + +// ConvertToType converts the MapValue to another CEL type, if possible. +func (m *MapValue) ConvertToType(t ref.Type) ref.Val { + switch t { + case types.MapType: + return m + case types.TypeType: + return types.MapType + } + return types.NewErr("type conversion error from '%s' to '%s'", m.Type(), t) +} + +// Equal returns true if the maps are of the same size, have the same keys, and the key-values +// from each map are equal. +func (m *MapValue) Equal(other ref.Val) ref.Val { + oMap, isMap := other.(traits.Mapper) + if !isMap { + return types.MaybeNoSuchOverloadErr(other) + } + if m.Size() != oMap.Size() { + return types.False + } + for name, field := range m.fieldMap { + k := types.String(name) + ov, found := oMap.Find(k) + if !found { + return types.False + } + v := field.Ref.ExprValue() + vEq := v.Equal(ov) + if vEq != types.True { + return vEq + } + } + return types.True +} + +// Find returns the value for the key in the map, if found. +func (m *MapValue) Find(name ref.Val) (ref.Val, bool) { + // Currently only maps with string keys are supported as this is best aligned with JSON, + // and also much simpler to support. + n, ok := name.(types.String) + if !ok { + return types.MaybeNoSuchOverloadErr(n), true + } + nameStr := string(n) + field, found := m.fieldMap[nameStr] + if found { + return field.Ref.ExprValue(), true + } + return nil, false +} + +// Get returns the value for the key in the map, or error if not found. +func (m *MapValue) Get(key ref.Val) ref.Val { + v, found := m.Find(key) + if found { + return v + } + return types.ValOrErr(key, "no such key: %v", key) +} + +// Iterator produces a traits.Iterator which walks over the map keys. +// +// The Iterator is frequently used within comprehensions. +func (m *MapValue) Iterator() traits.Iterator { + keys := make([]ref.Val, len(m.fieldMap)) + i := 0 + for k := range m.fieldMap { + keys[i] = types.String(k) + i++ + } + return &baseMapIterator{ + baseVal: &baseVal{}, + keys: keys, + } +} + +// Size returns the number of keys in the map. +func (m *MapValue) Size() ref.Val { + return types.Int(len(m.Fields)) +} + +// Type returns the CEL ref.Type for the map. +func (m *MapValue) Type() ref.Type { + return types.MapType +} + +// Value returns the Go-native representation of the MapValue. +func (m *MapValue) Value() interface{} { + return m +} + +type baseMapIterator struct { + *baseVal + keys []ref.Val + idx int +} + +// HasNext implements the traits.Iterator interface method. +func (it *baseMapIterator) HasNext() ref.Val { + if it.idx < len(it.keys) { + return types.True + } + return types.False +} + +// Next implements the traits.Iterator interface method. +func (it *baseMapIterator) Next() ref.Val { + key := it.keys[it.idx] + it.idx++ + return key +} + +// Type implements the CEL ref.Val interface metohd. +func (it *baseMapIterator) Type() ref.Type { + return types.IteratorType +} + +// NewField returns a MapField instance with an empty DynValue that refers to the +// specified parse node id and field name. +func NewField(id int64, name string) *Field { + return &Field{ + ID: id, + Name: name, + Ref: NewEmptyDynValue(), + } +} + +// Field specifies a field name and a reference to a dynamic value. +type Field struct { + ID int64 + Name string + Ref *DynValue +} + +// NewListValue returns an empty ListValue instance. +func NewListValue() *ListValue { + return &ListValue{ + Entries: []*DynValue{}, + } +} + +// ListValue contains a list of dynamically typed entries. +type ListValue struct { + Entries []*DynValue + initValueSet sync.Once + valueSet map[ref.Val]struct{} +} + +// Add concatenates two lists together to produce a new CEL list value. +func (lv *ListValue) Add(other ref.Val) ref.Val { + oArr, isArr := other.(traits.Lister) + if !isArr { + return types.MaybeNoSuchOverloadErr(other) + } + szRight := len(lv.Entries) + szLeft := int(oArr.Size().(types.Int)) + sz := szRight + szLeft + combo := make([]ref.Val, sz) + for i := 0; i < szRight; i++ { + combo[i] = lv.Entries[i].ExprValue() + } + for i := 0; i < szLeft; i++ { + combo[i+szRight] = oArr.Get(types.Int(i)) + } + return types.DefaultTypeAdapter.NativeToValue(combo) +} + +// Append adds another entry into the ListValue. +func (lv *ListValue) Append(entry *DynValue) { + lv.Entries = append(lv.Entries, entry) + // The append resets all previously built indices. + lv.initValueSet = sync.Once{} +} + +// Contains returns whether the input `val` is equal to an element in the list. +// +// If any pair-wise comparison between the input value and the list element is an error, the +// operation will return an error. +func (lv *ListValue) Contains(val ref.Val) ref.Val { + if types.IsUnknownOrError(val) { + return val + } + lv.initValueSet.Do(lv.finalizeValueSet) + if lv.valueSet != nil { + _, found := lv.valueSet[val] + if found { + return types.True + } + // Instead of returning false, ensure that CEL's heterogeneous equality constraint + // is satisfied by allowing pair-wise equality behavior to determine the outcome. + } + var err ref.Val + sz := len(lv.Entries) + for i := 0; i < sz; i++ { + elem := lv.Entries[i] + cmp := elem.Equal(val) + b, ok := cmp.(types.Bool) + if !ok && err == nil { + err = types.MaybeNoSuchOverloadErr(cmp) + } + if b == types.True { + return types.True + } + } + if err != nil { + return err + } + return types.False +} + +// ConvertToNative is an implementation of the CEL ref.Val method used to adapt between CEL types +// and Go-native array-like types. +func (lv *ListValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + // Non-list conversion. + if typeDesc.Kind() != reflect.Slice && + typeDesc.Kind() != reflect.Array && + typeDesc.Kind() != reflect.Interface { + return nil, fmt.Errorf("type conversion error from list to '%v'", typeDesc) + } + + // If the list is already assignable to the desired type return it. + if reflect.TypeOf(lv).AssignableTo(typeDesc) { + return lv, nil + } + + // List conversion. + otherElem := typeDesc.Elem() + + // Allow the element ConvertToNative() function to determine whether conversion is possible. + sz := len(lv.Entries) + nativeList := reflect.MakeSlice(typeDesc, int(sz), int(sz)) + for i := 0; i < sz; i++ { + elem := lv.Entries[i] + nativeElemVal, err := elem.ConvertToNative(otherElem) + if err != nil { + return nil, err + } + nativeList.Index(int(i)).Set(reflect.ValueOf(nativeElemVal)) + } + return nativeList.Interface(), nil +} + +// ConvertToType converts the ListValue to another CEL type. +func (lv *ListValue) ConvertToType(t ref.Type) ref.Val { + switch t { + case types.ListType: + return lv + case types.TypeType: + return types.ListType + } + return types.NewErr("type conversion error from '%s' to '%s'", ListType, t) +} + +// Equal returns true if two lists are of the same size, and the values at each index are also +// equal. +func (lv *ListValue) Equal(other ref.Val) ref.Val { + oArr, isArr := other.(traits.Lister) + if !isArr { + return types.MaybeNoSuchOverloadErr(other) + } + sz := types.Int(len(lv.Entries)) + if sz != oArr.Size() { + return types.False + } + for i := types.Int(0); i < sz; i++ { + cmp := lv.Get(i).Equal(oArr.Get(i)) + if cmp != types.True { + return cmp + } + } + return types.True +} + +// Get returns the value at the given index. +// +// If the index is negative or greater than the size of the list, an error is returned. +func (lv *ListValue) Get(idx ref.Val) ref.Val { + iv, isInt := idx.(types.Int) + if !isInt { + return types.ValOrErr(idx, "unsupported index: %v", idx) + } + i := int(iv) + if i < 0 || i >= len(lv.Entries) { + return types.NewErr("index out of bounds: %v", idx) + } + return lv.Entries[i].ExprValue() +} + +// Iterator produces a traits.Iterator suitable for use in CEL comprehension macros. +func (lv *ListValue) Iterator() traits.Iterator { + return &baseListIterator{ + getter: lv.Get, + sz: len(lv.Entries), + } +} + +// Size returns the number of elements in the list. +func (lv *ListValue) Size() ref.Val { + return types.Int(len(lv.Entries)) +} + +// Type returns the CEL ref.Type for the list. +func (lv *ListValue) Type() ref.Type { + return types.ListType +} + +// Value returns the Go-native value. +func (lv *ListValue) Value() interface{} { + return lv +} + +// finalizeValueSet inspects the ListValue entries in order to make internal optimizations once all list +// entries are known. +func (lv *ListValue) finalizeValueSet() { + valueSet := make(map[ref.Val]struct{}) + for _, e := range lv.Entries { + switch e.value.(type) { + case bool, float64, int64, string, uint64, types.Null, PlainTextValue: + valueSet[e.ExprValue()] = struct{}{} + default: + lv.valueSet = nil + return + } + } + lv.valueSet = valueSet +} + +type baseVal struct{} + +func (*baseVal) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + return nil, fmt.Errorf("unsupported native conversion to: %v", typeDesc) +} + +func (*baseVal) ConvertToType(t ref.Type) ref.Val { + return types.NewErr("unsupported type conversion to: %v", t) +} + +func (*baseVal) Equal(other ref.Val) ref.Val { + return types.NewErr("unsupported equality test between instances") +} + +func (v *baseVal) Value() interface{} { + return nil +} + +type baseListIterator struct { + *baseVal + getter func(idx ref.Val) ref.Val + sz int + idx int +} + +func (it *baseListIterator) HasNext() ref.Val { + if it.idx < it.sz { + return types.True + } + return types.False +} + +func (it *baseListIterator) Next() ref.Val { + v := it.getter(types.Int(it.idx)) + it.idx++ + return v +} + +func (it *baseListIterator) Type() ref.Type { + return types.IteratorType +} + +func celBool(pred bool) ref.Val { + if pred { + return types.True + } + return types.False +} + +var unknownType = &DeclType{name: "unknown", MinSerializedSize: 1} diff --git a/pkg/cel/value_test.go b/pkg/cel/value_test.go new file mode 100644 index 000000000..84d83bcee --- /dev/null +++ b/pkg/cel/value_test.go @@ -0,0 +1,362 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cel + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +func TestConvertToType(t *testing.T) { + objType := NewObjectType("TestObject", map[string]*DeclField{}) + tests := []struct { + val interface{} + typ ref.Type + }{ + {true, types.BoolType}, + {float64(1.2), types.DoubleType}, + {int64(-42), types.IntType}, + {uint64(63), types.UintType}, + {time.Duration(300), types.DurationType}, + {time.Now().UTC(), types.TimestampType}, + {types.NullValue, types.NullType}, + {NewListValue(), types.ListType}, + {NewMapValue(), types.MapType}, + {[]byte("bytes"), types.BytesType}, + {NewObjectValue(objType), objType}, + } + for i, tc := range tests { + idx := i + tst := tc + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + dv := testValue(t, int64(idx), tst.val) + ev := dv.ExprValue() + if ev.ConvertToType(types.TypeType).(ref.Type).TypeName() != tst.typ.TypeName() { + t.Errorf("got %v, wanted %v type", ev.ConvertToType(types.TypeType), tst.typ) + } + if ev.ConvertToType(tst.typ).Equal(ev) != types.True { + t.Errorf("got %v, wanted input value %v", ev.ConvertToType(tst.typ), ev) + } + }) + } +} + +func TestEqual(t *testing.T) { + vals := []interface{}{ + true, []byte("bytes"), float64(1.2), int64(-42), uint64(63), time.Duration(300), + time.Now().UTC(), types.NullValue, NewListValue(), NewMapValue(), + NewObjectValue(NewObjectType("TestObject", map[string]*DeclField{})), + } + for i, v := range vals { + dv := testValue(t, int64(i), v) + if dv.Equal(dv.ExprValue()) != types.True { + t.Errorf("got %v, wanted dyn value %v equal to itself", dv.Equal(dv.ExprValue()), dv.ExprValue()) + } + } +} + +func TestListValueAdd(t *testing.T) { + lv := NewListValue() + lv.Append(testValue(t, 1, "first")) + ov := NewListValue() + ov.Append(testValue(t, 2, "second")) + ov.Append(testValue(t, 3, "third")) + llv := NewListValue() + llv.Append(testValue(t, 4, lv)) + lov := NewListValue() + lov.Append(testValue(t, 5, ov)) + var v traits.Lister = llv.Add(lov).(traits.Lister) + if v.Size() != types.Int(2) { + t.Errorf("got list size %d, wanted 2", v.Size()) + } + complex, err := v.ConvertToNative(reflect.TypeOf([][]string{})) + complexList := complex.([][]string) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(complexList, [][]string{{"first"}, {"second", "third"}}) { + t.Errorf("got %v, wanted [['first'], ['second', 'third']]", complexList) + } +} + +func TestListValueContains(t *testing.T) { + lv := NewListValue() + lv.Append(testValue(t, 1, "first")) + lv.Append(testValue(t, 2, "second")) + lv.Append(testValue(t, 3, "third")) + for i := types.Int(0); i < lv.Size().(types.Int); i++ { + e := lv.Get(i) + contained := lv.Contains(e) + if contained != types.True { + t.Errorf("got %v, wanted list contains elem[%v] %v == true", contained, i, e) + } + } + if lv.Contains(types.String("fourth")) != types.False { + t.Errorf("got %v, wanted false 'fourth'", lv.Contains(types.String("fourth"))) + } + if !types.IsError(lv.Contains(types.Int(-1))) { + t.Errorf("got %v, wanted error for invalid type", lv.Contains(types.Int(-1))) + } +} + +func TestListValueContainsNestedList(t *testing.T) { + lvA := NewListValue() + lvA.Append(testValue(t, 1, int64(1))) + lvA.Append(testValue(t, 2, int64(2))) + + lvB := NewListValue() + lvB.Append(testValue(t, 3, int64(3))) + + elemA, elemB := testValue(t, 4, lvA), testValue(t, 5, lvB) + lv := NewListValue() + lv.Append(elemA) + lv.Append(elemB) + + contained := lv.Contains(elemA.ExprValue()) + if contained != types.True { + t.Errorf("got %v, wanted elemA contained in list value", contained) + } + contained = lv.Contains(elemB.ExprValue()) + if contained != types.True { + t.Errorf("got %v, wanted elemB contained in list value", contained) + } + contained = lv.Contains(types.DefaultTypeAdapter.NativeToValue([]int32{4})) + if contained != types.False { + t.Errorf("got %v, wanted empty list not contained", contained) + } +} + +func TestListValueConvertToNative(t *testing.T) { + lv := NewListValue() + none, err := lv.ConvertToNative(reflect.TypeOf([]interface{}{})) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(none, []interface{}{}) { + t.Errorf("got %v, wanted empty list", none) + } + lv.Append(testValue(t, 1, "first")) + one, err := lv.ConvertToNative(reflect.TypeOf([]string{})) + oneList := one.([]string) + if err != nil { + t.Fatal(err) + } + if len(oneList) != 1 { + t.Errorf("got len(one) == %d, wanted 1", len(oneList)) + } + if !reflect.DeepEqual(oneList, []string{"first"}) { + t.Errorf("got %v, wanted string list", oneList) + } + ov := NewListValue() + ov.Append(testValue(t, 2, "second")) + ov.Append(testValue(t, 3, "third")) + if ov.Size() != types.Int(2) { + t.Errorf("got list size %d, wanted 2", ov.Size()) + } + llv := NewListValue() + llv.Append(testValue(t, 4, lv)) + llv.Append(testValue(t, 5, ov)) + if llv.Size() != types.Int(2) { + t.Errorf("got list size %d, wanted 2", llv.Size()) + } + complex, err := llv.ConvertToNative(reflect.TypeOf([][]string{})) + complexList := complex.([][]string) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(complexList, [][]string{{"first"}, {"second", "third"}}) { + t.Errorf("got %v, wanted [['first'], ['second', 'third']]", complexList) + } +} + +func TestListValueIterator(t *testing.T) { + lv := NewListValue() + lv.Append(testValue(t, 1, "first")) + lv.Append(testValue(t, 2, "second")) + lv.Append(testValue(t, 3, "third")) + it := lv.Iterator() + if it.Type() != types.IteratorType { + t.Errorf("got type %v for iterator, wanted IteratorType", it.Type()) + } + i := types.Int(0) + for it.HasNext() == types.True { + v := it.Next() + if v.Equal(lv.Get(i)) != types.True { + t.Errorf("iterator value %v and value %v at index %d not equal", v, lv.Get(i), i) + } + i++ + } +} + +func TestMapValueConvertToNative(t *testing.T) { + mv := NewMapValue() + none, err := mv.ConvertToNative(reflect.TypeOf(map[string]interface{}{})) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(none, map[string]interface{}{}) { + t.Errorf("got %v, wanted empty map", none) + } + none, err = mv.ConvertToNative(reflect.TypeOf(map[interface{}]interface{}{})) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(none, map[interface{}]interface{}{}) { + t.Errorf("got %v, wanted empty map", none) + } + mv.AddField(NewField(1, "Test")) + tst, _ := mv.GetField("Test") + tst.Ref = testValue(t, 2, uint64(12)) + mv.AddField(NewField(3, "Check")) + chk, _ := mv.GetField("Check") + chk.Ref = testValue(t, 4, uint64(34)) + if mv.Size() != types.Int(2) { + t.Errorf("got size %d, wanted 2", mv.Size()) + } + if mv.Contains(types.String("Test")) != types.True { + t.Error("key 'Test' not found") + } + if mv.Contains(types.String("Check")) != types.True { + t.Error("key 'Check' not found") + } + if mv.Contains(types.String("Checked")) != types.False { + t.Error("key 'Checked' found, wanted not found") + } + it := mv.Iterator() + for it.HasNext() == types.True { + k := it.Next() + v := mv.Get(k) + if k == types.String("Test") && v != types.Uint(12) { + t.Errorf("key 'Test' not equal to 12u") + } + if k == types.String("Check") && v != types.Uint(34) { + t.Errorf("key 'Check' not equal to 34u") + } + } + mpStrUint, err := mv.ConvertToNative(reflect.TypeOf(map[string]uint64{})) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(mpStrUint, map[string]uint64{ + "Test": uint64(12), + "Check": uint64(34), + }) { + t.Errorf("got %v, wanted {'Test': 12u, 'Check': 34u}", mpStrUint) + } + tstStr, err := mv.ConvertToNative(reflect.TypeOf(&tstStruct{})) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tstStr, &tstStruct{ + Test: uint64(12), + Check: uint64(34), + }) { + t.Errorf("got %v, wanted tstStruct{Test: 12u, Check: 34u}", tstStr) + } +} + +func TestMapValueEqual(t *testing.T) { + mv := NewMapValue() + name := NewField(1, "name") + name.Ref = testValue(t, 2, "alert") + priority := NewField(3, "priority") + priority.Ref = testValue(t, 4, int64(4)) + mv.AddField(name) + mv.AddField(priority) + if mv.Equal(mv) != types.True { + t.Fatalf("map.Equal(map) failed: %v", mv.Equal(mv)) + } +} + +func TestMapValueNotEqual(t *testing.T) { + mv := NewMapValue() + name := NewField(1, "name") + name.Ref = testValue(t, 2, "alert") + priority := NewField(3, "priority") + priority.Ref = testValue(t, 4, int64(4)) + mv.AddField(name) + mv.AddField(priority) + + mv2 := NewMapValue() + mv2.AddField(name) + if mv.Equal(mv2) != types.False { + t.Fatalf("mv.Equal(mv2) failed: %v", mv.Equal(mv2)) + } + + priority2 := NewField(5, "priority") + priority2.Ref = testValue(t, 6, int64(3)) + mv2.AddField(priority2) + if mv.Equal(mv2) != types.False { + t.Fatalf("mv.Equal(mv2) failed: %v", mv.Equal(mv2)) + } +} + +func TestMapValueIsSet(t *testing.T) { + mv := NewMapValue() + if mv.IsSet(types.String("name")) != types.False { + t.Error("map.IsSet('name') returned true for unset key") + } + mv.AddField(NewField(1, "name")) + if mv.IsSet(types.String("name")) != types.True { + t.Error("map.IsSet('name') returned false for a set key") + } +} + +func TestObjectValueEqual(t *testing.T) { + objType := NewObjectType("Notice", map[string]*DeclField{ + "name": {Name: "name", Type: StringType}, + "priority": {Name: "priority", Type: IntType}, + "message": {Name: "message", Type: StringType, defaultValue: ""}, + }) + name := NewField(1, "name") + name.Ref = testValue(t, 2, "alert") + priority := NewField(3, "priority") + priority.Ref = testValue(t, 4, int64(4)) + message := NewField(5, "message") + message.Ref = testValue(t, 6, "call immediately") + + mv1 := NewMapValue() + mv1.AddField(name) + mv1.AddField(priority) + obj1 := mv1.ConvertToObject(objType) + if obj1.Equal(obj1) != types.True { + t.Errorf("obj1.Equal(obj1) failed, got: %v", obj1.Equal(obj1)) + } + + mv2 := NewMapValue() + mv2.AddField(name) + mv2.AddField(priority) + mv2.AddField(message) + obj2 := mv2.ConvertToObject(objType) + if obj1.Equal(obj2) == types.True { + t.Error("obj1.Equal(obj2) returned true, wanted false") + } + if obj2.Equal(obj1) == types.True { + t.Error("obj2.Equal(obj1) returned true, wanted false") + } +} + +type tstStruct struct { + Test uint64 + Check uint64 +} From 56d541647490f5ce38d5af3df78c76788c314bdd Mon Sep 17 00:00:00 2001 From: Jiahui Feng Date: Fri, 7 Oct 2022 15:36:19 -0700 Subject: [PATCH 2/3] generated: ./hack/update-vendor.sh Kubernetes-commit: 5b8a5b37d5b2031f5733396613d781f8967c25ed --- go.mod | 8 +++++--- go.sum | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 0fea974b1..1e21118ba 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,10 @@ require ( github.com/evanphx/json-patch v4.12.0+incompatible github.com/fsnotify/fsnotify v1.5.4 github.com/gogo/protobuf v1.3.2 + github.com/google/cel-go v0.12.5 github.com/google/gnostic v0.5.7-v3refs github.com/google/go-cmp v0.5.9 github.com/google/gofuzz v1.1.0 - github.com/google/cel-go v0.12.5 github.com/google/uuid v1.1.2 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 @@ -37,7 +37,9 @@ require ( golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 + google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 google.golang.org/grpc v1.49.0 + google.golang.org/protobuf v1.28.1 gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/square/go-jose.v2 v2.2.2 k8s.io/api v0.0.0 @@ -57,6 +59,7 @@ require ( require ( cloud.google.com/go v0.97.0 // indirect github.com/NYTimes/gziphandler v1.1.1 // indirect + github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.1.3 // indirect @@ -96,6 +99,7 @@ require ( github.com/sirupsen/logrus v1.8.1 // indirect github.com/soheilhy/cmux v0.1.5 // indirect github.com/spf13/cobra v1.5.0 // indirect + github.com/stoewer/go-strcase v1.2.0 // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect go.etcd.io/bbolt v1.3.6 // indirect @@ -112,8 +116,6 @@ require ( golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect golang.org/x/text v0.3.7 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect - google.golang.org/protobuf v1.28.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 0b87ddd36..d0e91dd29 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 h1:yL7+Jz0jTC6yykIK/Wh74gnTJnrGr5AyrNMXuA0gves= +github.com/antlr/antlr4/runtime/Go/antlr v1.4.10/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= @@ -220,6 +222,8 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/google/cel-go v0.12.5 h1:DmzaiSgoaqGCjtpPQWl26/gND+yRpim56H1jCVev6d8= +github.com/google/cel-go v0.12.5/go.mod h1:Jk7ljRzLBhkmiAwBoUxB1sZSCVBAzkqPF25olK/iRDw= github.com/google/gnostic v0.5.7-v3refs h1:FhTMOKj2VhjpouxvWJAV1TL304uMlb9zcDqkl6cEI54= github.com/google/gnostic v0.5.7-v3refs/go.mod h1:73MKFl6jIHelAJNaBGFzt3SPtZULs9dYrGFt8OiIsHQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -439,6 +443,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= +github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= From ac0ce38abe87104bc16a6b4a3b519a8b3277491f Mon Sep 17 00:00:00 2001 From: Jiahui Feng Date: Mon, 10 Oct 2022 14:42:24 -0700 Subject: [PATCH 3/3] use DefaultMaxRequestSizeBytes for maxRequestSizeBytes. Kubernetes-commit: 755f41a185e828d9c64ae3ac37ce829e60592ad1 --- pkg/cel/limits.go | 5 ++--- pkg/server/config.go | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/cel/limits.go b/pkg/cel/limits.go index c38a47cea..7bdb958d0 100644 --- a/pkg/cel/limits.go +++ b/pkg/cel/limits.go @@ -17,9 +17,8 @@ limitations under the License. package cel const ( - // MaxRequestSizeBytes is the largest request that will be accepted is 3MB - // TODO(DangerOnTheRanger): wire in MaxRequestBodyBytes from apiserver/pkg/server/options/server_run_options.go to make this configurable - MaxRequestSizeBytes = int64(3 * 1024 * 1024) + // DefaultMaxRequestSizeBytes is the size of the largest request that will be accepted + DefaultMaxRequestSizeBytes = int64(3 * 1024 * 1024) // MaxDurationSizeJSON // OpenAPI duration strings follow RFC 3339, section 5.6 - see the comment on maxDatetimeSizeJSON diff --git a/pkg/server/config.go b/pkg/server/config.go index 0248fdd0a..1d0753ea5 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -366,7 +366,7 @@ func NewConfig(codecs serializer.CodecFactory) *Config { // A request body might be encoded in json, and is converted to // proto when persisted in etcd, so we allow 2x as the largest request // body size to be accepted and decoded in a write request. - // If this constant is changed, maxRequestSizeBytes in apiextensions-apiserver/pkg/apiserver/schema/cel/model/schemas.go + // If this constant is changed, DefaultMaxRequestSizeBytes in k8s.io/apiserver/pkg/cel/limits.go // should be changed to reflect the new value, if the two haven't // been wired together already somehow. MaxRequestBodyBytes: int64(3 * 1024 * 1024),