Merge pull request #112926 from jiahuif-forks/refactor/cel-out-of-apiextensions
split and move CEL package Kubernetes-commit: 61ca612cbb85efa13444a6d8ae517cd5e9c151a4
This commit is contained in:
commit
db8c02bd35
15
go.mod
15
go.mod
|
|
@ -12,6 +12,7 @@ require (
|
|||
github.com/evanphx/json-patch v4.12.0+incompatible
|
||||
github.com/fsnotify/fsnotify v1.5.4
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/google/cel-go v0.12.5
|
||||
github.com/google/gnostic v0.5.7-v3refs
|
||||
github.com/google/go-cmp v0.5.9
|
||||
github.com/google/gofuzz v1.1.0
|
||||
|
|
@ -36,13 +37,15 @@ require (
|
|||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8
|
||||
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21
|
||||
google.golang.org/grpc v1.49.0
|
||||
google.golang.org/protobuf v1.28.1
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
gopkg.in/square/go-jose.v2 v2.2.2
|
||||
k8s.io/api v0.0.0-20221012035047-0f8110492ea0
|
||||
k8s.io/api v0.0.0-20221012115127-0184bd884c5e
|
||||
k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419
|
||||
k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853
|
||||
k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282
|
||||
k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e
|
||||
k8s.io/klog/v2 v2.80.1
|
||||
k8s.io/kms v0.0.0-20221012040222-bf322548c086
|
||||
k8s.io/kube-openapi v0.0.0-20220803162953-67bda5d908f1
|
||||
|
|
@ -56,6 +59,7 @@ require (
|
|||
require (
|
||||
cloud.google.com/go v0.97.0 // indirect
|
||||
github.com/NYTimes/gziphandler v1.1.1 // indirect
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/blang/semver/v4 v4.0.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
|
||||
|
|
@ -95,6 +99,7 @@ require (
|
|||
github.com/sirupsen/logrus v1.8.1 // indirect
|
||||
github.com/soheilhy/cmux v0.1.5 // indirect
|
||||
github.com/spf13/cobra v1.5.0 // indirect
|
||||
github.com/stoewer/go-strcase v1.2.0 // indirect
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
|
||||
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
|
||||
go.etcd.io/bbolt v1.3.6 // indirect
|
||||
|
|
@ -111,17 +116,15 @@ require (
|
|||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
|
||||
golang.org/x/text v0.3.8 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
|
||||
google.golang.org/protobuf v1.28.1 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace (
|
||||
k8s.io/api => k8s.io/api v0.0.0-20221012035047-0f8110492ea0
|
||||
k8s.io/api => k8s.io/api v0.0.0-20221012115127-0184bd884c5e
|
||||
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419
|
||||
k8s.io/client-go => k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853
|
||||
k8s.io/component-base => k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282
|
||||
k8s.io/component-base => k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e
|
||||
k8s.io/kms => k8s.io/kms v0.0.0-20221012040222-bf322548c086
|
||||
)
|
||||
|
|
|
|||
13
go.sum
13
go.sum
|
|
@ -57,6 +57,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF
|
|||
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
|
||||
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 h1:yL7+Jz0jTC6yykIK/Wh74gnTJnrGr5AyrNMXuA0gves=
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY=
|
||||
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
|
||||
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
|
||||
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
|
||||
|
|
@ -220,6 +222,8 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z
|
|||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/google/cel-go v0.12.5 h1:DmzaiSgoaqGCjtpPQWl26/gND+yRpim56H1jCVev6d8=
|
||||
github.com/google/cel-go v0.12.5/go.mod h1:Jk7ljRzLBhkmiAwBoUxB1sZSCVBAzkqPF25olK/iRDw=
|
||||
github.com/google/gnostic v0.5.7-v3refs h1:FhTMOKj2VhjpouxvWJAV1TL304uMlb9zcDqkl6cEI54=
|
||||
github.com/google/gnostic v0.5.7-v3refs/go.mod h1:73MKFl6jIHelAJNaBGFzt3SPtZULs9dYrGFt8OiIsHQ=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
|
|
@ -439,6 +443,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn
|
|||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
|
||||
github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU=
|
||||
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
|
@ -977,14 +982,14 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
|
|||
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
||||
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
||||
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
|
||||
k8s.io/api v0.0.0-20221012035047-0f8110492ea0 h1:zsKtQXVrOo62IwClRKcF/RbgeFI2FQFZ99YhdfypWn8=
|
||||
k8s.io/api v0.0.0-20221012035047-0f8110492ea0/go.mod h1:Q2jmki3lhHeeAyYtbIiEe/2JoGg2Ge1cJoMgpFkjtzg=
|
||||
k8s.io/api v0.0.0-20221012115127-0184bd884c5e h1:N2+121lkGMNfLeos4/AlwuujPJd5xJfFkyi15BUzob8=
|
||||
k8s.io/api v0.0.0-20221012115127-0184bd884c5e/go.mod h1:Q2jmki3lhHeeAyYtbIiEe/2JoGg2Ge1cJoMgpFkjtzg=
|
||||
k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419 h1:6islJrEgy0CM8YsBcMFdlxm1lXtVC9X5A2AVRK7JEpc=
|
||||
k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419/go.mod h1:1b4APDhID8eky0PLpgeoWtdLUptwwJD8Jk5uFOd22CE=
|
||||
k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853 h1:MA0fOvlQt7fzsEIujzUH4ssakSCKB6+iWRclNvRYAnk=
|
||||
k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853/go.mod h1:UVW05h29JMuginXCTvlNPTFVS/AyKlU3Q4XoA6UidtU=
|
||||
k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282 h1:JqSl7JGZ8FwUFz9Isq0YIJNLx5oK5x/g0Vl2VtVntI4=
|
||||
k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282/go.mod h1:t3mMmVABnDZIpNgn7xDM2AdNUr2jvWA8pleZlSDjAug=
|
||||
k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e h1:Gg6uxVzmcw4M5KKG4jy6mE+CahWF6VvA39YUsXTaJcI=
|
||||
k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e/go.mod h1:nOngF+6Y18fCD8T0L9mxX2FDBo3vzC1t6hsFv6DE3lI=
|
||||
k8s.io/klog/v2 v2.80.1 h1:atnLQ121W371wYYFawwYx1aEY2eUfs4l3J72wtgAwV4=
|
||||
k8s.io/klog/v2 v2.80.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0=
|
||||
k8s.io/kms v0.0.0-20221012040222-bf322548c086 h1:BCrccAO6WPPwlzfgds9ERms+pPGvNJ4QpakEuqEQ5rk=
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
Copyright 2021 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
// Error is an implementation of the 'error' interface, which represents a
|
||||
// XValidation error.
|
||||
type Error struct {
|
||||
Type ErrorType
|
||||
Detail string
|
||||
}
|
||||
|
||||
var _ error = &Error{}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (v *Error) Error() string {
|
||||
return v.Detail
|
||||
}
|
||||
|
||||
// ErrorType is a machine readable value providing more detail about why
|
||||
// a XValidation is invalid.
|
||||
type ErrorType string
|
||||
|
||||
const (
|
||||
// ErrorTypeRequired is used to report withNullable values that are not
|
||||
// provided (e.g. empty strings, null values, or empty arrays). See
|
||||
// Required().
|
||||
ErrorTypeRequired ErrorType = "RuleRequired"
|
||||
// ErrorTypeInvalid is used to report malformed values
|
||||
ErrorTypeInvalid ErrorType = "RuleInvalid"
|
||||
// ErrorTypeInternal is used to report other errors that are not related
|
||||
// to user input. See InternalError().
|
||||
ErrorTypeInternal ErrorType = "InternalError"
|
||||
)
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
/*
|
||||
Copyright 2021 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
)
|
||||
|
||||
// celReservedSymbols is a list of RESERVED symbols defined in the CEL lexer.
|
||||
// No identifiers are allowed to collide with these symbols.
|
||||
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#syntax
|
||||
var celReservedSymbols = sets.NewString(
|
||||
"true", "false", "null", "in",
|
||||
"as", "break", "const", "continue", "else",
|
||||
"for", "function", "if", "import", "let",
|
||||
"loop", "package", "namespace", "return", // !! 'namespace' is used heavily in Kubernetes
|
||||
"var", "void", "while",
|
||||
)
|
||||
|
||||
// expandMatcher matches the escape sequence, characters that are escaped, and characters that are unsupported
|
||||
var expandMatcher = regexp.MustCompile(`(__|[-./]|[^a-zA-Z0-9-./_])`)
|
||||
|
||||
// newCharacterFilter returns a boolean array to indicate the allowed characters
|
||||
func newCharacterFilter(characters string) []bool {
|
||||
maxChar := 0
|
||||
for _, c := range characters {
|
||||
if maxChar < int(c) {
|
||||
maxChar = int(c)
|
||||
}
|
||||
}
|
||||
filter := make([]bool, maxChar+1)
|
||||
|
||||
for _, c := range characters {
|
||||
filter[int(c)] = true
|
||||
}
|
||||
|
||||
return filter
|
||||
}
|
||||
|
||||
type escapeCheck struct {
|
||||
canSkipRegex bool
|
||||
invalidCharFound bool
|
||||
}
|
||||
|
||||
// skipRegexCheck checks if escape would be skipped.
|
||||
// if invalidCharFound is true, it must have invalid character; if invalidCharFound is false, not sure if it has invalid character or not
|
||||
func skipRegexCheck(ident string) escapeCheck {
|
||||
escapeCheck := escapeCheck{canSkipRegex: true, invalidCharFound: false}
|
||||
// skip escape if possible
|
||||
previous_underscore := false
|
||||
for _, c := range ident {
|
||||
if c == '/' || c == '-' || c == '.' {
|
||||
escapeCheck.canSkipRegex = false
|
||||
return escapeCheck
|
||||
}
|
||||
intc := int(c)
|
||||
if intc < 0 || intc >= len(validCharacterFilter) || !validCharacterFilter[intc] {
|
||||
escapeCheck.invalidCharFound = true
|
||||
return escapeCheck
|
||||
}
|
||||
if c == '_' && previous_underscore {
|
||||
escapeCheck.canSkipRegex = false
|
||||
return escapeCheck
|
||||
}
|
||||
|
||||
previous_underscore = c == '_'
|
||||
}
|
||||
return escapeCheck
|
||||
}
|
||||
|
||||
// validCharacterFilter indicates the allowed characters.
|
||||
var validCharacterFilter = newCharacterFilter("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_")
|
||||
|
||||
// Escape escapes ident and returns a CEL identifier (of the form '[a-zA-Z_][a-zA-Z0-9_]*'), or returns
|
||||
// false if the ident does not match the supported input format of `[a-zA-Z_.-/][a-zA-Z0-9_.-/]*`.
|
||||
// Escaping Rules:
|
||||
// - '__' escapes to '__underscores__'
|
||||
// - '.' escapes to '__dot__'
|
||||
// - '-' escapes to '__dash__'
|
||||
// - '/' escapes to '__slash__'
|
||||
// - Identifiers that exactly match a CEL RESERVED keyword escape to '__{keyword}__'. The keywords are: "true", "false",
|
||||
// "null", "in", "as", "break", "const", "continue", "else", "for", "function", "if", "import", "let", loop", "package",
|
||||
// "namespace", "return".
|
||||
func Escape(ident string) (string, bool) {
|
||||
if len(ident) == 0 || ('0' <= ident[0] && ident[0] <= '9') {
|
||||
return "", false
|
||||
}
|
||||
if celReservedSymbols.Has(ident) {
|
||||
return "__" + ident + "__", true
|
||||
}
|
||||
|
||||
escapeCheck := skipRegexCheck(ident)
|
||||
if escapeCheck.invalidCharFound {
|
||||
return "", false
|
||||
}
|
||||
if escapeCheck.canSkipRegex {
|
||||
return ident, true
|
||||
}
|
||||
|
||||
ok := true
|
||||
ident = expandMatcher.ReplaceAllStringFunc(ident, func(s string) string {
|
||||
switch s {
|
||||
case "__":
|
||||
return "__underscores__"
|
||||
case ".":
|
||||
return "__dot__"
|
||||
case "-":
|
||||
return "__dash__"
|
||||
case "/":
|
||||
return "__slash__"
|
||||
default: // matched a unsupported supported
|
||||
ok = false
|
||||
return ""
|
||||
}
|
||||
})
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return ident, true
|
||||
}
|
||||
|
||||
var unexpandMatcher = regexp.MustCompile(`(_{2}[^_]+_{2})`)
|
||||
|
||||
// Unescape unescapes an CEL identifier containing the escape sequences described in Escape, or return false if the
|
||||
// string contains invalid escape sequences. The escaped input is expected to be a valid CEL identifier, but is
|
||||
// not checked.
|
||||
func Unescape(escaped string) (string, bool) {
|
||||
ok := true
|
||||
escaped = unexpandMatcher.ReplaceAllStringFunc(escaped, func(s string) string {
|
||||
contents := s[2 : len(s)-2]
|
||||
switch contents {
|
||||
case "underscores":
|
||||
return "__"
|
||||
case "dot":
|
||||
return "."
|
||||
case "dash":
|
||||
return "-"
|
||||
case "slash":
|
||||
return "/"
|
||||
}
|
||||
if celReservedSymbols.Has(contents) {
|
||||
if len(s) != len(escaped) {
|
||||
ok = false
|
||||
}
|
||||
return contents
|
||||
}
|
||||
ok = false
|
||||
return ""
|
||||
})
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return escaped, true
|
||||
}
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
/*
|
||||
Copyright 2021 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
fuzz "github.com/google/gofuzz"
|
||||
)
|
||||
|
||||
// TestEscaping tests that property names are escaped as expected.
|
||||
func TestEscaping(t *testing.T) {
|
||||
cases := []struct {
|
||||
unescaped string
|
||||
escaped string
|
||||
unescapable bool
|
||||
}{
|
||||
// '.', '-', '/' and '__' are escaped since
|
||||
// CEL only allows identifiers of the form: [a-zA-Z_][a-zA-Z0-9_]*
|
||||
{unescaped: "a.a", escaped: "a__dot__a"},
|
||||
{unescaped: "a-a", escaped: "a__dash__a"},
|
||||
{unescaped: "a__a", escaped: "a__underscores__a"},
|
||||
{unescaped: "a.-/__a", escaped: "a__dot____dash____slash____underscores__a"},
|
||||
{unescaped: "a._a", escaped: "a__dot___a"},
|
||||
{unescaped: "a__.__a", escaped: "a__underscores____dot____underscores__a"},
|
||||
{unescaped: "a___a", escaped: "a__underscores___a"},
|
||||
{unescaped: "a____a", escaped: "a__underscores____underscores__a"},
|
||||
{unescaped: "a__dot__a", escaped: "a__underscores__dot__underscores__a"},
|
||||
{unescaped: "a__underscores__a", escaped: "a__underscores__underscores__underscores__a"},
|
||||
// CEL lexer RESERVED keywords must be escaped
|
||||
{unescaped: "true", escaped: "__true__"},
|
||||
{unescaped: "false", escaped: "__false__"},
|
||||
{unescaped: "null", escaped: "__null__"},
|
||||
{unescaped: "in", escaped: "__in__"},
|
||||
{unescaped: "as", escaped: "__as__"},
|
||||
{unescaped: "break", escaped: "__break__"},
|
||||
{unescaped: "const", escaped: "__const__"},
|
||||
{unescaped: "continue", escaped: "__continue__"},
|
||||
{unescaped: "else", escaped: "__else__"},
|
||||
{unescaped: "for", escaped: "__for__"},
|
||||
{unescaped: "function", escaped: "__function__"},
|
||||
{unescaped: "if", escaped: "__if__"},
|
||||
{unescaped: "import", escaped: "__import__"},
|
||||
{unescaped: "let", escaped: "__let__"},
|
||||
{unescaped: "loop", escaped: "__loop__"},
|
||||
{unescaped: "package", escaped: "__package__"},
|
||||
{unescaped: "namespace", escaped: "__namespace__"},
|
||||
{unescaped: "return", escaped: "__return__"},
|
||||
{unescaped: "var", escaped: "__var__"},
|
||||
{unescaped: "void", escaped: "__void__"},
|
||||
{unescaped: "while", escaped: "__while__"},
|
||||
// Not all property names are escapable
|
||||
{unescaped: "@", unescapable: true},
|
||||
{unescaped: "1up", unescapable: true},
|
||||
{unescaped: "👑", unescapable: true},
|
||||
// CEL macro and function names do not need to be escaped because the parser keeps identifiers in a
|
||||
// different namespace than function and macro names.
|
||||
{unescaped: "has", escaped: "has"},
|
||||
{unescaped: "all", escaped: "all"},
|
||||
{unescaped: "exists", escaped: "exists"},
|
||||
{unescaped: "exists_one", escaped: "exists_one"},
|
||||
{unescaped: "filter", escaped: "filter"},
|
||||
{unescaped: "size", escaped: "size"},
|
||||
{unescaped: "contains", escaped: "contains"},
|
||||
{unescaped: "startsWith", escaped: "startsWith"},
|
||||
{unescaped: "endsWith", escaped: "endsWith"},
|
||||
{unescaped: "matches", escaped: "matches"},
|
||||
{unescaped: "duration", escaped: "duration"},
|
||||
{unescaped: "timestamp", escaped: "timestamp"},
|
||||
{unescaped: "getDate", escaped: "getDate"},
|
||||
{unescaped: "getDayOfMonth", escaped: "getDayOfMonth"},
|
||||
{unescaped: "getDayOfWeek", escaped: "getDayOfWeek"},
|
||||
{unescaped: "getFullYear", escaped: "getFullYear"},
|
||||
{unescaped: "getHours", escaped: "getHours"},
|
||||
{unescaped: "getMilliseconds", escaped: "getMilliseconds"},
|
||||
{unescaped: "getMinutes", escaped: "getMinutes"},
|
||||
{unescaped: "getMonth", escaped: "getMonth"},
|
||||
{unescaped: "getSeconds", escaped: "getSeconds"},
|
||||
// we don't escape a single _
|
||||
{unescaped: "_if", escaped: "_if"},
|
||||
{unescaped: "_has", escaped: "_has"},
|
||||
{unescaped: "_int", escaped: "_int"},
|
||||
{unescaped: "_anything", escaped: "_anything"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.unescaped, func(t *testing.T) {
|
||||
e, escapable := Escape(tc.unescaped)
|
||||
if tc.unescapable {
|
||||
if escapable {
|
||||
t.Errorf("Expected escapable=false, but got %t", escapable)
|
||||
}
|
||||
return
|
||||
}
|
||||
if !escapable {
|
||||
t.Fatalf("Expected escapable=true, but got %t", escapable)
|
||||
}
|
||||
if tc.escaped != e {
|
||||
t.Errorf("Expected %s to escape to %s, but got %s", tc.unescaped, tc.escaped, e)
|
||||
}
|
||||
|
||||
if !validCelIdent.MatchString(e) {
|
||||
t.Errorf("Expected %s to escape to a valid CEL identifier, but got %s", tc.unescaped, e)
|
||||
}
|
||||
|
||||
u, ok := Unescape(tc.escaped)
|
||||
if !ok {
|
||||
t.Fatalf("Expected %s to be escapable, but it was not", tc.escaped)
|
||||
}
|
||||
if tc.unescaped != u {
|
||||
t.Errorf("Expected %s to unescape to %s, but got %s", tc.escaped, tc.unescaped, u)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnescapeMalformed(t *testing.T) {
|
||||
for _, s := range []string{"__int__extra", "__illegal__"} {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
e, ok := Unescape(s)
|
||||
if ok {
|
||||
t.Fatalf("Expected %s to be unescapable, but it escaped to: %s", s, e)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapingFuzz(t *testing.T) {
|
||||
fuzzer := fuzz.New()
|
||||
for i := 0; i < 1000; i++ {
|
||||
var unescaped string
|
||||
fuzzer.Fuzz(&unescaped)
|
||||
t.Run(fmt.Sprintf("%d - '%s'", i, unescaped), func(t *testing.T) {
|
||||
if len(unescaped) == 0 {
|
||||
return
|
||||
}
|
||||
escaped, ok := Escape(unescaped)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if !validCelIdent.MatchString(escaped) {
|
||||
t.Errorf("Expected %s to escape to a valid CEL identifier, but got %s", unescaped, escaped)
|
||||
}
|
||||
u, ok := Unescape(escaped)
|
||||
if !ok {
|
||||
t.Fatalf("Expected %s to be unescapable, but it was not", escaped)
|
||||
}
|
||||
if unescaped != u {
|
||||
t.Errorf("Expected %s to unescape to %s, but got %s", escaped, unescaped, u)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var validCelIdent = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
func TestCanSkipRegex(t *testing.T) {
|
||||
cases := []struct {
|
||||
unescaped string
|
||||
canSkip bool
|
||||
invalidCharFound bool
|
||||
}{
|
||||
{unescaped: "a.a", canSkip: false},
|
||||
{unescaped: "a-a", canSkip: false},
|
||||
{unescaped: "a__a", canSkip: false},
|
||||
{unescaped: "a.-/__a", canSkip: false},
|
||||
{unescaped: "a_a", canSkip: true},
|
||||
{unescaped: "a_a_a", canSkip: true},
|
||||
{unescaped: "@", invalidCharFound: true},
|
||||
{unescaped: "👑", invalidCharFound: true},
|
||||
// if escape keyword is detected before invalid character, invalidCharFound
|
||||
{unescaped: "/👑", canSkip: false},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.unescaped, func(t *testing.T) {
|
||||
escapeCheck := skipRegexCheck(tc.unescaped)
|
||||
if escapeCheck.invalidCharFound {
|
||||
if escapeCheck.invalidCharFound != tc.invalidCharFound {
|
||||
t.Errorf("Expected input validation to be %v, but got %t", tc.invalidCharFound, escapeCheck.invalidCharFound)
|
||||
}
|
||||
} else {
|
||||
if escapeCheck.canSkipRegex != tc.canSkip {
|
||||
t.Errorf("Expected %v, but got %t", tc.canSkip, escapeCheck.canSkipRegex)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,268 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/google/cel-go/checker"
|
||||
"github.com/google/cel-go/common"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
|
||||
type CostEstimator struct {
|
||||
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
|
||||
// calculations to if the size is not well known (i.e. a constant).
|
||||
SizeEstimator checker.CostEstimator
|
||||
}
|
||||
|
||||
func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 {
|
||||
switch function {
|
||||
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
|
||||
var cost uint64
|
||||
if len(args) > 0 {
|
||||
cost += traversalCost(args[0]) // these O(n) operations all cost roughly the cost of a single traversal
|
||||
}
|
||||
return &cost
|
||||
case "url", "lowerAscii", "upperAscii", "substring", "trim":
|
||||
if len(args) >= 1 {
|
||||
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
|
||||
return &cost
|
||||
}
|
||||
case "replace", "split":
|
||||
if len(args) >= 1 {
|
||||
// cost is the traversal plus the construction of the result
|
||||
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
|
||||
return &cost
|
||||
}
|
||||
case "join":
|
||||
if len(args) >= 1 {
|
||||
cost := uint64(math.Ceil(float64(actualSize(result)) * 2 * common.StringTraversalCostFactor))
|
||||
return &cost
|
||||
}
|
||||
case "find", "findAll":
|
||||
if len(args) >= 2 {
|
||||
strCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor))
|
||||
// We don't know how many expressions are in the regex, just the string length (a huge
|
||||
// improvement here would be to somehow get a count the number of expressions in the regex or
|
||||
// how many states are in the regex state machine and use that to measure regex cost).
|
||||
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
|
||||
// in length.
|
||||
regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor))
|
||||
cost := strCost * regexCost
|
||||
return &cost
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
|
||||
// WARNING: Any changes to this code impact API compatibility! The estimated cost is used to determine which CEL rules may be written to a
|
||||
// CRD and any change (cost increases and cost decreases) are breaking.
|
||||
switch function {
|
||||
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
|
||||
if target != nil {
|
||||
// Charge 1 cost for comparing each element in the list
|
||||
elCost := checker.CostEstimate{Min: 1, Max: 1}
|
||||
// If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way
|
||||
// of estimating the additional comparison cost.
|
||||
if elNode := l.listElementNode(*target); elNode != nil {
|
||||
t := elNode.Type().GetPrimitive()
|
||||
if t == exprpb.Type_STRING || t == exprpb.Type_BYTES {
|
||||
sz := l.sizeEstimate(elNode)
|
||||
elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
|
||||
}
|
||||
return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCost(elCost)}
|
||||
} else { // the target is a string, which is supported by indexOf and lastIndexOf
|
||||
return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)}
|
||||
}
|
||||
|
||||
}
|
||||
case "url":
|
||||
if len(args) == 1 {
|
||||
sz := l.sizeEstimate(args[0])
|
||||
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
|
||||
}
|
||||
case "lowerAscii", "upperAscii", "substring", "trim":
|
||||
if target != nil {
|
||||
sz := l.sizeEstimate(*target)
|
||||
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
|
||||
}
|
||||
case "replace":
|
||||
if target != nil && len(args) >= 2 {
|
||||
sz := l.sizeEstimate(*target)
|
||||
toReplaceSz := l.sizeEstimate(args[0])
|
||||
replaceWithSz := l.sizeEstimate(args[1])
|
||||
// smallest possible result: smallest input size composed of the largest possible substrings being replaced by smallest possible replacement
|
||||
minSz := uint64(math.Ceil(float64(sz.Min)/float64(toReplaceSz.Max))) * replaceWithSz.Min
|
||||
// largest possible result: largest input size composed of the smallest possible substrings being replaced by largest possible replacement
|
||||
maxSz := uint64(math.Ceil(float64(sz.Max)/float64(toReplaceSz.Min))) * replaceWithSz.Max
|
||||
|
||||
// cost is the traversal plus the construction of the result
|
||||
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: minSz, Max: maxSz}}
|
||||
}
|
||||
case "split":
|
||||
if target != nil {
|
||||
sz := l.sizeEstimate(*target)
|
||||
|
||||
// Worst case size is where is that a separator of "" is used, and each char is returned as a list element.
|
||||
max := sz.Max
|
||||
if len(args) > 1 {
|
||||
if c := args[1].Expr().GetConstExpr(); c != nil {
|
||||
max = uint64(c.GetInt64Value())
|
||||
}
|
||||
}
|
||||
// Cost is the traversal plus the construction of the result.
|
||||
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: 0, Max: max}}
|
||||
}
|
||||
case "join":
|
||||
if target != nil {
|
||||
var sz checker.SizeEstimate
|
||||
listSize := l.sizeEstimate(*target)
|
||||
if elNode := l.listElementNode(*target); elNode != nil {
|
||||
elemSize := l.sizeEstimate(elNode)
|
||||
sz = listSize.Multiply(elemSize)
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
sepSize := l.sizeEstimate(args[0])
|
||||
minSeparators := uint64(0)
|
||||
maxSeparators := uint64(0)
|
||||
if listSize.Min > 0 {
|
||||
minSeparators = listSize.Min - 1
|
||||
}
|
||||
if listSize.Max > 0 {
|
||||
maxSeparators = listSize.Max - 1
|
||||
}
|
||||
sz = sz.Add(sepSize.Multiply(checker.SizeEstimate{Min: minSeparators, Max: maxSeparators}))
|
||||
}
|
||||
|
||||
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
|
||||
}
|
||||
case "find", "findAll":
|
||||
if target != nil && len(args) >= 1 {
|
||||
sz := l.sizeEstimate(*target)
|
||||
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
|
||||
// in case where string is empty but regex is still expensive.
|
||||
strCost := sz.Add(checker.SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
|
||||
// We don't know how many expressions are in the regex, just the string length (a huge
|
||||
// improvement here would be to somehow get a count the number of expressions in the regex or
|
||||
// how many states are in the regex state machine and use that to measure regex cost).
|
||||
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
|
||||
// in length.
|
||||
regexCost := l.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
|
||||
// worst case size of result is that every char is returned as separate find result.
|
||||
return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func actualSize(value ref.Val) uint64 {
|
||||
if sz, ok := value.(traits.Sizer); ok {
|
||||
return uint64(sz.Size().(types.Int))
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func (l *CostEstimator) sizeEstimate(t checker.AstNode) checker.SizeEstimate {
|
||||
if sz := t.ComputedSize(); sz != nil {
|
||||
return *sz
|
||||
}
|
||||
if sz := l.EstimateSize(t); sz != nil {
|
||||
return *sz
|
||||
}
|
||||
return checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
|
||||
}
|
||||
|
||||
func (l *CostEstimator) listElementNode(list checker.AstNode) checker.AstNode {
|
||||
if lt := list.Type().GetListType(); lt != nil {
|
||||
nodePath := list.Path()
|
||||
if nodePath != nil {
|
||||
// Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists
|
||||
// for this node.
|
||||
path := make([]string, len(nodePath)+1)
|
||||
copy(path, nodePath)
|
||||
path[len(nodePath)] = "@items"
|
||||
return &itemsNode{path: path, t: lt.GetElemType(), expr: nil}
|
||||
} else {
|
||||
// Provide just the type if no path is available so that worst case size can be looked up based on type.
|
||||
return &itemsNode{t: lt.GetElemType(), expr: nil}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
|
||||
if l.SizeEstimator != nil {
|
||||
return l.SizeEstimator.EstimateSize(element)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type itemsNode struct {
|
||||
path []string
|
||||
t *exprpb.Type
|
||||
expr *exprpb.Expr
|
||||
}
|
||||
|
||||
func (i *itemsNode) Path() []string {
|
||||
return i.path
|
||||
}
|
||||
|
||||
func (i *itemsNode) Type() *exprpb.Type {
|
||||
return i.t
|
||||
}
|
||||
|
||||
func (i *itemsNode) Expr() *exprpb.Expr {
|
||||
return i.expr
|
||||
}
|
||||
|
||||
func (i *itemsNode) ComputedSize() *checker.SizeEstimate {
|
||||
return nil
|
||||
}
|
||||
|
||||
// traversalCost computes the cost of traversing a ref.Val as a data tree.
|
||||
func traversalCost(v ref.Val) uint64 {
|
||||
// TODO: This could potentially be optimized by sampling maps and lists instead of traversing.
|
||||
switch vt := v.(type) {
|
||||
case types.String:
|
||||
return uint64(float64(len(string(vt))) * common.StringTraversalCostFactor)
|
||||
case types.Bytes:
|
||||
return uint64(float64(len([]byte(vt))) * common.StringTraversalCostFactor)
|
||||
case traits.Lister:
|
||||
cost := uint64(0)
|
||||
for it := vt.Iterator(); it.HasNext() == types.True; {
|
||||
i := it.Next()
|
||||
cost += traversalCost(i)
|
||||
}
|
||||
return cost
|
||||
case traits.Mapper: // maps and objects
|
||||
cost := uint64(0)
|
||||
for it := vt.Iterator(); it.HasNext() == types.True; {
|
||||
k := it.Next()
|
||||
cost += traversalCost(k) + traversalCost(vt.Get(k))
|
||||
}
|
||||
return cost
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,363 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/checker"
|
||||
"github.com/google/cel-go/ext"
|
||||
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
const (
|
||||
intListLiteral = "[1, 2, 3, 4, 5]"
|
||||
uintListLiteral = "[uint(1), uint(2), uint(3), uint(4), uint(5)]"
|
||||
doubleListLiteral = "[1.0, 2.0, 3.0, 4.0, 5.0]"
|
||||
boolListLiteral = "[false, true, false, true, false]"
|
||||
stringListLiteral = "['012345678901', '012345678901', '012345678901', '012345678901', '012345678901']"
|
||||
bytesListLiteral = "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')]"
|
||||
durationListLiteral = "[duration('1s'), duration('2s'), duration('3s'), duration('4s'), duration('5s')]"
|
||||
timestampListLiteral = "[timestamp('2011-01-01T00:00:00.000+01:00'), timestamp('2011-01-02T00:00:00.000+01:00'), " +
|
||||
"timestamp('2011-01-03T00:00:00.000+01:00'), timestamp('2011-01-04T00:00:00.000+01:00'), " +
|
||||
"timestamp('2011-01-05T00:00:00.000+01:00')]"
|
||||
stringLiteral = "'01234567890123456789012345678901234567890123456789'"
|
||||
)
|
||||
|
||||
type comparableCost struct {
|
||||
comparableLiteral string
|
||||
expectedEstimatedCost checker.CostEstimate
|
||||
expectedRuntimeCost uint64
|
||||
|
||||
param string
|
||||
}
|
||||
|
||||
func TestListsCost(t *testing.T) {
|
||||
cases := []struct {
|
||||
opts []string
|
||||
costs []comparableCost
|
||||
}{
|
||||
{
|
||||
opts: []string{".sum()"},
|
||||
// 10 cost for the list declaration, the rest is the due to the function call
|
||||
costs: []comparableCost{
|
||||
{
|
||||
comparableLiteral: intListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: uintListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
|
||||
},
|
||||
{
|
||||
comparableLiteral: doubleListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: durationListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
opts: []string{".isSorted()", ".max()", ".min()"},
|
||||
// 10 cost for the list declaration, the rest is the due to the function call
|
||||
costs: []comparableCost{
|
||||
{
|
||||
comparableLiteral: intListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: uintListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts
|
||||
},
|
||||
{
|
||||
comparableLiteral: doubleListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: boolListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: stringListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons
|
||||
},
|
||||
{
|
||||
comparableLiteral: bytesListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 25, Max: 35}, expectedRuntimeCost: 25, // +10 for casts from string to byte, +5 for byte comparisons
|
||||
},
|
||||
{
|
||||
comparableLiteral: durationListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts
|
||||
},
|
||||
{
|
||||
comparableLiteral: timestampListLiteral,
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
for _, op := range tc.opts {
|
||||
for _, typ := range tc.costs {
|
||||
t.Run(typ.comparableLiteral+op, func(t *testing.T) {
|
||||
e := typ.comparableLiteral + op
|
||||
testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexOfCost(t *testing.T) {
|
||||
cases := []struct {
|
||||
opts []string
|
||||
costs []comparableCost
|
||||
}{
|
||||
{
|
||||
opts: []string{".indexOf(%s)", ".lastIndexOf(%s)"},
|
||||
// 10 cost for the list declaration, the rest is the due to the function call
|
||||
costs: []comparableCost{
|
||||
{
|
||||
comparableLiteral: intListLiteral, param: "3",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: uintListLiteral, param: "uint(3)",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +5 for numeric casts
|
||||
},
|
||||
{
|
||||
comparableLiteral: doubleListLiteral, param: "3.0",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: boolListLiteral, param: "true",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
comparableLiteral: stringListLiteral, param: "'x'",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons
|
||||
},
|
||||
{
|
||||
comparableLiteral: bytesListLiteral, param: "bytes('x')",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 26, Max: 36}, expectedRuntimeCost: 26, // +11 for casts from string to byte, +5 for byte comparisons
|
||||
},
|
||||
{
|
||||
comparableLiteral: durationListLiteral, param: "duration('3s')",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from duration to byte
|
||||
},
|
||||
{
|
||||
comparableLiteral: timestampListLiteral, param: "timestamp('2011-01-03T00:00:00.000+01:00')",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from timestamp to byte
|
||||
},
|
||||
|
||||
// index of operations are also defined for strings
|
||||
{
|
||||
comparableLiteral: stringLiteral, param: "'123'",
|
||||
expectedEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, expectedRuntimeCost: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
for _, op := range tc.opts {
|
||||
for _, typ := range tc.costs {
|
||||
opWithParam := fmt.Sprintf(op, typ.param)
|
||||
t.Run(typ.comparableLiteral+opWithParam, func(t *testing.T) {
|
||||
e := typ.comparableLiteral + opWithParam
|
||||
testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLsCost(t *testing.T) {
|
||||
cases := []struct {
|
||||
ops []string
|
||||
expectEsimatedCost checker.CostEstimate
|
||||
expectRuntimeCost uint64
|
||||
}{
|
||||
{
|
||||
ops: []string{".getScheme()", ".getHostname()", ".getHost()", ".getPort()", ".getEscapedPath()", ".getQuery()"},
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4},
|
||||
expectRuntimeCost: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
for _, op := range tc.ops {
|
||||
t.Run("url."+op, func(t *testing.T) {
|
||||
testCost(t, "url('https:://kubernetes.io/')"+op, tc.expectEsimatedCost, tc.expectRuntimeCost)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringLibrary(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
expr string
|
||||
expectEsimatedCost checker.CostEstimate
|
||||
expectRuntimeCost uint64
|
||||
}{
|
||||
{
|
||||
name: "lowerAscii",
|
||||
expr: "'ABCDEFGHIJ abcdefghij'.lowerAscii()",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||
expectRuntimeCost: 3,
|
||||
},
|
||||
{
|
||||
name: "upperAscii",
|
||||
expr: "'ABCDEFGHIJ abcdefghij'.upperAscii()",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||
expectRuntimeCost: 3,
|
||||
},
|
||||
{
|
||||
name: "replace",
|
||||
expr: "'abc 123 def 123'.replace('123', '456')",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||
expectRuntimeCost: 3,
|
||||
},
|
||||
{
|
||||
name: "replace with limit",
|
||||
expr: "'abc 123 def 123'.replace('123', '456', 1)",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||
expectRuntimeCost: 3,
|
||||
},
|
||||
{
|
||||
name: "split",
|
||||
expr: "'abc 123 def 123'.split(' ')",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||
expectRuntimeCost: 3,
|
||||
},
|
||||
{
|
||||
name: "split with limit",
|
||||
expr: "'abc 123 def 123'.split(' ', 1)",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||
expectRuntimeCost: 3,
|
||||
},
|
||||
{
|
||||
name: "substring",
|
||||
expr: "'abc 123 def 123'.substring(5)",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||
expectRuntimeCost: 2,
|
||||
},
|
||||
{
|
||||
name: "substring with end",
|
||||
expr: "'abc 123 def 123'.substring(5, 8)",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||
expectRuntimeCost: 2,
|
||||
},
|
||||
{
|
||||
name: "trim",
|
||||
expr: "' abc 123 def 123 '.trim()",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||
expectRuntimeCost: 2,
|
||||
},
|
||||
{
|
||||
name: "join with separator",
|
||||
expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join(' ')",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 11, Max: 23},
|
||||
expectRuntimeCost: 15,
|
||||
},
|
||||
{
|
||||
name: "join",
|
||||
expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join()",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 10, Max: 22},
|
||||
expectRuntimeCost: 13,
|
||||
},
|
||||
{
|
||||
name: "find",
|
||||
expr: "'abc 123 def 123'.find('123')",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||
expectRuntimeCost: 2,
|
||||
},
|
||||
{
|
||||
name: "findAll",
|
||||
expr: "'abc 123 def 123'.findAll('123')",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||
expectRuntimeCost: 2,
|
||||
},
|
||||
{
|
||||
name: "findAll with limit",
|
||||
expr: "'abc 123 def 123'.findAll('123', 1)",
|
||||
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||
expectRuntimeCost: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testCost(t, tc.expr, tc.expectEsimatedCost, tc.expectRuntimeCost)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) {
|
||||
est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
|
||||
env, err := cel.NewEnv(append(k8sExtensionLibs, ext.Strings())...)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
compiled, issues := env.Compile(expr)
|
||||
if len(issues.Errors()) > 0 {
|
||||
t.Fatalf("%v", issues.Errors())
|
||||
}
|
||||
estCost, err := env.EstimateCost(compiled, est)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
if estCost.Min != expectEsimatedCost.Min || estCost.Max != expectEsimatedCost.Max {
|
||||
t.Errorf("Expected estimated cost of %d..%d but got %d..%d", expectEsimatedCost.Min, expectEsimatedCost.Max, estCost.Min, estCost.Max)
|
||||
}
|
||||
prog, err := env.Program(compiled, cel.CostTracking(est))
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
_, details, err := prog.Eval(map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
cost := details.ActualCost()
|
||||
if *cost != expectRuntimeCost {
|
||||
t.Errorf("Expected cost of %d but got %d", expectRuntimeCost, *cost)
|
||||
}
|
||||
}
|
||||
|
||||
type testCostEstimator struct {
|
||||
}
|
||||
|
||||
func (t *testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
|
||||
switch t := element.Type().TypeKind.(type) {
|
||||
case *expr.Type_Primitive:
|
||||
switch t.Primitive {
|
||||
case expr.Type_STRING:
|
||||
return &checker.SizeEstimate{Min: 0, Max: 12}
|
||||
case expr.Type_BYTES:
|
||||
return &checker.SizeEstimate{Min: 0, Max: 12}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testCostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/ext"
|
||||
"github.com/google/cel-go/interpreter"
|
||||
)
|
||||
|
||||
// ExtensionLibs declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes.
|
||||
var ExtensionLibs = append(k8sExtensionLibs, ext.Strings())
|
||||
|
||||
var k8sExtensionLibs = []cel.EnvOption{
|
||||
URLs(),
|
||||
Regex(),
|
||||
Lists(),
|
||||
}
|
||||
|
||||
var ExtensionLibRegexOptimizations = []*interpreter.RegexOptimization{FindRegexOptimization, FindAllRegexOptimization}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
)
|
||||
|
||||
func TestLibraryCompatibility(t *testing.T) {
|
||||
functionNames := map[string]struct{}{}
|
||||
|
||||
decls := map[cel.Library]map[string][]cel.FunctionOpt{
|
||||
urlsLib: urlLibraryDecls,
|
||||
listsLib: listsLibraryDecls,
|
||||
regexLib: regexLibraryDecls,
|
||||
}
|
||||
if len(k8sExtensionLibs) != len(decls) {
|
||||
t.Errorf("Expected the same number of libraries in the ExtensionLibs as are tested for compatibility")
|
||||
}
|
||||
for _, decl := range decls {
|
||||
for name := range decl {
|
||||
functionNames[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// WARN: All library changes must follow
|
||||
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates
|
||||
// and must track the functions here along with which Kubernetes version introduced them.
|
||||
knownFunctions := []string{
|
||||
// Kubernetes 1.24:
|
||||
"isSorted", "sum", "max", "min", "indexOf", "lastIndexOf", "find", "findAll", "url", "getScheme", "getHost", "getHostname",
|
||||
"getPort", "getEscapedPath", "getQuery", "isURL",
|
||||
// Kubernetes <1.??>:
|
||||
}
|
||||
for _, fn := range knownFunctions {
|
||||
delete(functionNames, fn)
|
||||
}
|
||||
|
||||
if len(functionNames) != 0 {
|
||||
t.Errorf("Expected all functions in the libraries to be assigned to a kubernetes release, but found the unassigned function names: %v", functionNames)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,312 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
"github.com/google/cel-go/interpreter/functions"
|
||||
)
|
||||
|
||||
// Lists provides a CEL function library extension of list utility functions.
|
||||
//
|
||||
// isSorted
|
||||
//
|
||||
// Returns true if the provided list of comparable elements is sorted, else returns false.
|
||||
//
|
||||
// <list<T>>.isSorted() <bool>, T must be a comparable type
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// [1, 2, 3].isSorted() // return true
|
||||
// ['a', 'b', 'b', 'c'].isSorted() // return true
|
||||
// [2.0, 1.0].isSorted() // return false
|
||||
// [1].isSorted() // return true
|
||||
// [].isSorted() // return true
|
||||
//
|
||||
// sum
|
||||
//
|
||||
// Returns the sum of the elements of the provided list. Supports CEL number (int, uint, double) and duration types.
|
||||
//
|
||||
// <list<T>>.sum() <T>, T must be a numeric type or a duration
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// [1, 3].sum() // returns 4
|
||||
// [1.0, 3.0].sum() // returns 4.0
|
||||
// ['1m', '1s'].sum() // returns '1m1s'
|
||||
// emptyIntList.sum() // returns 0
|
||||
// emptyDoubleList.sum() // returns 0.0
|
||||
// [].sum() // returns 0
|
||||
//
|
||||
// min / max
|
||||
//
|
||||
// Returns the minimum/maximum valued element of the provided list. Supports all comparable types.
|
||||
// If the list is empty, an error is returned.
|
||||
//
|
||||
// <list<T>>.min() <T>, T must be a comparable type
|
||||
// <list<T>>.max() <T>, T must be a comparable type
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// [1, 3].min() // returns 1
|
||||
// [1, 3].max() // returns 3
|
||||
// [].min() // error
|
||||
// [1].min() // returns 1
|
||||
// ([0] + emptyList).min() // returns 0
|
||||
//
|
||||
// indexOf / lastIndexOf
|
||||
//
|
||||
// Returns either the first or last positional index of the provided element in the list.
|
||||
// If the element is not found, -1 is returned. Supports all equatable types.
|
||||
//
|
||||
// <list<T>>.indexOf(<T>) <int>, T must be an equatable type
|
||||
// <list<T>>.lastIndexOf(<T>) <int>, T must be an equatable type
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// [1, 2, 2, 3].indexOf(2) // returns 1
|
||||
// ['a', 'b', 'b', 'c'].lastIndexOf('b') // returns 2
|
||||
// [1.0].indexOf(1.1) // returns -1
|
||||
// [].indexOf('string') // returns -1
|
||||
func Lists() cel.EnvOption {
|
||||
return cel.Lib(listsLib)
|
||||
}
|
||||
|
||||
var listsLib = &lists{}
|
||||
|
||||
type lists struct{}
|
||||
|
||||
var paramA = cel.TypeParamType("A")
|
||||
|
||||
// CEL typeParams can be used to constraint to a specific trait (e.g. traits.ComparableType) if the 1st operand is the type to constrain.
|
||||
// But the functions we need to constrain are <list<paramType>>, not just <paramType>.
|
||||
// Make sure the order of overload set is deterministic
|
||||
type namedCELType struct {
|
||||
typeName string
|
||||
celType *cel.Type
|
||||
}
|
||||
|
||||
var summableTypes = []namedCELType{
|
||||
{typeName: "int", celType: cel.IntType},
|
||||
{typeName: "uint", celType: cel.UintType},
|
||||
{typeName: "double", celType: cel.DoubleType},
|
||||
{typeName: "duration", celType: cel.DurationType},
|
||||
}
|
||||
|
||||
var zeroValuesOfSummableTypes = map[string]ref.Val{
|
||||
"int": types.Int(0),
|
||||
"uint": types.Uint(0),
|
||||
"double": types.Double(0.0),
|
||||
"duration": types.Duration{Duration: 0},
|
||||
}
|
||||
var comparableTypes = []namedCELType{
|
||||
{typeName: "int", celType: cel.IntType},
|
||||
{typeName: "uint", celType: cel.UintType},
|
||||
{typeName: "double", celType: cel.DoubleType},
|
||||
{typeName: "bool", celType: cel.BoolType},
|
||||
{typeName: "duration", celType: cel.DurationType},
|
||||
{typeName: "timestamp", celType: cel.TimestampType},
|
||||
{typeName: "string", celType: cel.StringType},
|
||||
{typeName: "bytes", celType: cel.BytesType},
|
||||
}
|
||||
|
||||
// WARNING: All library additions or modifications must follow
|
||||
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates
|
||||
var listsLibraryDecls = map[string][]cel.FunctionOpt{
|
||||
"isSorted": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
|
||||
return cel.MemberOverload(fmt.Sprintf("list_%s_is_sorted_bool", name),
|
||||
[]*cel.Type{cel.ListType(paramType)}, cel.BoolType, cel.UnaryBinding(isSorted))
|
||||
}),
|
||||
"sum": templatedOverloads(summableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
|
||||
return cel.MemberOverload(fmt.Sprintf("list_%s_sum_%s", name, name),
|
||||
[]*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(func(list ref.Val) ref.Val {
|
||||
return sum(
|
||||
func() ref.Val {
|
||||
return zeroValuesOfSummableTypes[name]
|
||||
})(list)
|
||||
}))
|
||||
}),
|
||||
"max": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
|
||||
return cel.MemberOverload(fmt.Sprintf("list_%s_max_%s", name, name),
|
||||
[]*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(max()))
|
||||
}),
|
||||
"min": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
|
||||
return cel.MemberOverload(fmt.Sprintf("list_%s_min_%s", name, name),
|
||||
[]*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(min()))
|
||||
}),
|
||||
"indexOf": {
|
||||
cel.MemberOverload("list_a_index_of_int", []*cel.Type{cel.ListType(paramA), paramA}, cel.IntType,
|
||||
cel.BinaryBinding(indexOf)),
|
||||
},
|
||||
"lastIndexOf": {
|
||||
cel.MemberOverload("list_a_last_index_of_int", []*cel.Type{cel.ListType(paramA), paramA}, cel.IntType,
|
||||
cel.BinaryBinding(lastIndexOf)),
|
||||
},
|
||||
}
|
||||
|
||||
func (*lists) CompileOptions() []cel.EnvOption {
|
||||
options := []cel.EnvOption{}
|
||||
for name, overloads := range listsLibraryDecls {
|
||||
options = append(options, cel.Function(name, overloads...))
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func (*lists) ProgramOptions() []cel.ProgramOption {
|
||||
return []cel.ProgramOption{}
|
||||
}
|
||||
|
||||
func isSorted(val ref.Val) ref.Val {
|
||||
var prev traits.Comparer
|
||||
iterable, ok := val.(traits.Iterable)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(val)
|
||||
}
|
||||
for it := iterable.Iterator(); it.HasNext() == types.True; {
|
||||
next := it.Next()
|
||||
nextCmp, ok := next.(traits.Comparer)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(next)
|
||||
}
|
||||
if prev != nil {
|
||||
cmp := prev.Compare(next)
|
||||
if cmp == types.IntOne {
|
||||
return types.False
|
||||
}
|
||||
}
|
||||
prev = nextCmp
|
||||
}
|
||||
return types.True
|
||||
}
|
||||
|
||||
func sum(init func() ref.Val) functions.UnaryOp {
|
||||
return func(val ref.Val) ref.Val {
|
||||
i := init()
|
||||
acc, ok := i.(traits.Adder)
|
||||
if !ok {
|
||||
// Should never happen since all passed in init values are valid
|
||||
return types.MaybeNoSuchOverloadErr(i)
|
||||
}
|
||||
iterable, ok := val.(traits.Iterable)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(val)
|
||||
}
|
||||
for it := iterable.Iterator(); it.HasNext() == types.True; {
|
||||
next := it.Next()
|
||||
nextAdder, ok := next.(traits.Adder)
|
||||
if !ok {
|
||||
// Should never happen for type checked CEL programs
|
||||
return types.MaybeNoSuchOverloadErr(next)
|
||||
}
|
||||
if acc != nil {
|
||||
s := acc.Add(next)
|
||||
sum, ok := s.(traits.Adder)
|
||||
if !ok {
|
||||
// Should never happen for type checked CEL programs
|
||||
return types.MaybeNoSuchOverloadErr(s)
|
||||
}
|
||||
acc = sum
|
||||
} else {
|
||||
acc = nextAdder
|
||||
}
|
||||
}
|
||||
return acc.(ref.Val)
|
||||
}
|
||||
}
|
||||
|
||||
func min() functions.UnaryOp {
|
||||
return cmp("min", types.IntOne)
|
||||
}
|
||||
|
||||
func max() functions.UnaryOp {
|
||||
return cmp("max", types.IntNegOne)
|
||||
}
|
||||
|
||||
func cmp(opName string, opPreferCmpResult ref.Val) functions.UnaryOp {
|
||||
return func(val ref.Val) ref.Val {
|
||||
var result traits.Comparer
|
||||
iterable, ok := val.(traits.Iterable)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(val)
|
||||
}
|
||||
for it := iterable.Iterator(); it.HasNext() == types.True; {
|
||||
next := it.Next()
|
||||
nextCmp, ok := next.(traits.Comparer)
|
||||
if !ok {
|
||||
// Should never happen for type checked CEL programs
|
||||
return types.MaybeNoSuchOverloadErr(next)
|
||||
}
|
||||
if result == nil {
|
||||
result = nextCmp
|
||||
} else {
|
||||
cmp := result.Compare(next)
|
||||
if cmp == opPreferCmpResult {
|
||||
result = nextCmp
|
||||
}
|
||||
}
|
||||
}
|
||||
if result == nil {
|
||||
return types.NewErr("%s called on empty list", opName)
|
||||
}
|
||||
return result.(ref.Val)
|
||||
}
|
||||
}
|
||||
|
||||
func indexOf(list ref.Val, item ref.Val) ref.Val {
|
||||
lister, ok := list.(traits.Lister)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(list)
|
||||
}
|
||||
sz := lister.Size().(types.Int)
|
||||
for i := types.Int(0); i < sz; i++ {
|
||||
if lister.Get(types.Int(i)).Equal(item) == types.True {
|
||||
return types.Int(i)
|
||||
}
|
||||
}
|
||||
return types.Int(-1)
|
||||
}
|
||||
|
||||
func lastIndexOf(list ref.Val, item ref.Val) ref.Val {
|
||||
lister, ok := list.(traits.Lister)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(list)
|
||||
}
|
||||
sz := lister.Size().(types.Int)
|
||||
for i := sz - 1; i >= 0; i-- {
|
||||
if lister.Get(types.Int(i)).Equal(item) == types.True {
|
||||
return types.Int(i)
|
||||
}
|
||||
}
|
||||
return types.Int(-1)
|
||||
}
|
||||
|
||||
// templatedOverloads returns overloads for each of the provided types. The template function is called with each type
|
||||
// name (map key) and type to construct the overloads.
|
||||
func templatedOverloads(types []namedCELType, template func(name string, t *cel.Type) cel.FunctionOpt) []cel.FunctionOpt {
|
||||
overloads := make([]cel.FunctionOpt, len(types))
|
||||
i := 0
|
||||
for _, t := range types {
|
||||
overloads[i] = template(t.typeName, t.celType)
|
||||
i++
|
||||
}
|
||||
return overloads
|
||||
}
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/interpreter"
|
||||
)
|
||||
|
||||
// Regex provides a CEL function library extension of regex utility functions.
|
||||
//
|
||||
// find / findAll
|
||||
//
|
||||
// Returns substrings that match the provided regular expression. find returns the first match. findAll may optionally
|
||||
// be provided a limit. If the limit is set and >= 0, no more than the limit number of matches are returned.
|
||||
//
|
||||
// <string>.find(<string>) <string>
|
||||
// <string>.findAll(<string>) <list <string>>
|
||||
// <string>.findAll(<string>, <int>) <list <string>>
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// "abc 123".find('[0-9]*') // returns '123'
|
||||
// "abc 123".find('xyz') // returns ''
|
||||
// "123 abc 456".findAll('[0-9]*') // returns ['123', '456']
|
||||
// "123 abc 456".findAll('[0-9]*', 1) // returns ['123']
|
||||
// "123 abc 456".findAll('xyz') // returns []
|
||||
func Regex() cel.EnvOption {
|
||||
return cel.Lib(regexLib)
|
||||
}
|
||||
|
||||
var regexLib = ®ex{}
|
||||
|
||||
type regex struct{}
|
||||
|
||||
var regexLibraryDecls = map[string][]cel.FunctionOpt{
|
||||
"find": {
|
||||
cel.MemberOverload("string_find_string", []*cel.Type{cel.StringType, cel.StringType}, cel.StringType,
|
||||
cel.BinaryBinding(find))},
|
||||
"findAll": {
|
||||
cel.MemberOverload("string_find_all_string", []*cel.Type{cel.StringType, cel.StringType},
|
||||
cel.ListType(cel.StringType),
|
||||
cel.BinaryBinding(func(str, regex ref.Val) ref.Val {
|
||||
return findAll(str, regex, types.Int(-1))
|
||||
})),
|
||||
cel.MemberOverload("string_find_all_string_int",
|
||||
[]*cel.Type{cel.StringType, cel.StringType, cel.IntType},
|
||||
cel.ListType(cel.StringType),
|
||||
cel.FunctionBinding(findAll)),
|
||||
},
|
||||
}
|
||||
|
||||
func (*regex) CompileOptions() []cel.EnvOption {
|
||||
options := []cel.EnvOption{}
|
||||
for name, overloads := range regexLibraryDecls {
|
||||
options = append(options, cel.Function(name, overloads...))
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func (*regex) ProgramOptions() []cel.ProgramOption {
|
||||
return []cel.ProgramOption{}
|
||||
}
|
||||
|
||||
func find(strVal ref.Val, regexVal ref.Val) ref.Val {
|
||||
str, ok := strVal.Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(strVal)
|
||||
}
|
||||
regex, ok := regexVal.Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(regexVal)
|
||||
}
|
||||
re, err := regexp.Compile(regex)
|
||||
if err != nil {
|
||||
return types.NewErr("Illegal regex: %v", err.Error())
|
||||
}
|
||||
result := re.FindString(str)
|
||||
return types.String(result)
|
||||
}
|
||||
|
||||
func findAll(args ...ref.Val) ref.Val {
|
||||
argn := len(args)
|
||||
if argn < 2 || argn > 3 {
|
||||
return types.NoSuchOverloadErr()
|
||||
}
|
||||
str, ok := args[0].Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(args[0])
|
||||
}
|
||||
regex, ok := args[1].Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(args[1])
|
||||
}
|
||||
n := int64(-1)
|
||||
if argn == 3 {
|
||||
n, ok = args[2].Value().(int64)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(args[2])
|
||||
}
|
||||
}
|
||||
|
||||
re, err := regexp.Compile(regex)
|
||||
if err != nil {
|
||||
return types.NewErr("Illegal regex: %v", err.Error())
|
||||
}
|
||||
|
||||
result := re.FindAllString(str, int(n))
|
||||
|
||||
return types.NewStringList(types.DefaultTypeAdapter, result)
|
||||
}
|
||||
|
||||
// FindRegexOptimization optimizes the 'find' function by compiling the regex pattern and
|
||||
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
|
||||
// call invocations.
|
||||
var FindRegexOptimization = &interpreter.RegexOptimization{
|
||||
Function: "find",
|
||||
RegexIndex: 1,
|
||||
Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
|
||||
compiledRegex, err := regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val {
|
||||
if len(args) != 2 {
|
||||
return types.NoSuchOverloadErr()
|
||||
}
|
||||
in, ok := args[0].Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(args[0])
|
||||
}
|
||||
return types.String(compiledRegex.FindString(in))
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
// FindAllRegexOptimization optimizes the 'findAll' function by compiling the regex pattern and
|
||||
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
|
||||
// call invocations.
|
||||
var FindAllRegexOptimization = &interpreter.RegexOptimization{
|
||||
Function: "findAll",
|
||||
RegexIndex: 1,
|
||||
Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
|
||||
compiledRegex, err := regexp.Compile(regexPattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val {
|
||||
argn := len(args)
|
||||
if argn < 2 || argn > 3 {
|
||||
return types.NoSuchOverloadErr()
|
||||
}
|
||||
str, ok := args[0].Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(args[0])
|
||||
}
|
||||
n := int64(-1)
|
||||
if argn == 3 {
|
||||
n, ok = args[2].Value().(int64)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(args[2])
|
||||
}
|
||||
}
|
||||
|
||||
result := compiledRegex.FindAllString(str, int(n))
|
||||
return types.NewStringList(types.DefaultTypeAdapter, result)
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
|
||||
apiservercel "k8s.io/apiserver/pkg/cel"
|
||||
)
|
||||
|
||||
// URLs provides a CEL function library extension of URL parsing functions.
|
||||
//
|
||||
// url
|
||||
//
|
||||
// Converts a string to a URL or results in an error if the string is not a valid URL. The URL must be an absolute URI
|
||||
// or an absolute path.
|
||||
//
|
||||
// url(<string>) <URL>
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// url('https://user:pass@example.com:80/path?query=val#fragment') // returns a URL
|
||||
// url('/absolute-path') // returns a URL
|
||||
// url('https://a:b:c/') // error
|
||||
// url('../relative-path') // error
|
||||
//
|
||||
// isURL
|
||||
//
|
||||
// Returns true if a string is a valid URL. The URL must be an absolute URI or an absolute path.
|
||||
//
|
||||
// isURL( <string>) <bool>
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// isURL('https://user:pass@example.com:80/path?query=val#fragment') // returns true
|
||||
// isURL('/absolute-path') // returns true
|
||||
// isURL('https://a:b:c/') // returns false
|
||||
// isURL('../relative-path') // returns false
|
||||
//
|
||||
// getScheme / getHost / getHostname / getPort / getEscapedPath / getQuery
|
||||
//
|
||||
// Return the parsed components of a URL.
|
||||
//
|
||||
// - getScheme: If absent in the URL, returns an empty string.
|
||||
//
|
||||
// - getHostname: IPv6 addresses are returned with braces, e.g. "[::1]". If absent in the URL, returns an empty string.
|
||||
//
|
||||
// - getHost: IPv6 addresses are returned without braces, e.g. "::1". If absent in the URL, returns an empty string.
|
||||
//
|
||||
// - getEscapedPath: The string returned by getEscapedPath is URL escaped, e.g. "with space" becomes "with%20space".
|
||||
// If absent in the URL, returns an empty string.
|
||||
//
|
||||
// - getPort: If absent in the URL, returns an empty string.
|
||||
//
|
||||
// - getQuery: Returns the query parameters in "matrix" form where a repeated query key is interpreted to
|
||||
// mean that there are multiple values for that key. The keys and values are returned unescaped.
|
||||
// If absent in the URL, returns an empty map.
|
||||
//
|
||||
// <URL>.getScheme() <string>
|
||||
// <URL>.getHost() <string>
|
||||
// <URL>.getHostname() <string>
|
||||
// <URL>.getPort() <string>
|
||||
// <URL>.getEscapedPath() <string>
|
||||
// <URL>.getQuery() <map <string>, <list <string>>
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// url('/path').getScheme() // returns ''
|
||||
// url('https://example.com/').getScheme() // returns 'https'
|
||||
// url('https://example.com:80/').getHost() // returns 'example.com:80'
|
||||
// url('https://example.com/').getHost() // returns 'example.com'
|
||||
// url('https://[::1]:80/').getHost() // returns '[::1]:80'
|
||||
// url('https://[::1]/').getHost() // returns '[::1]'
|
||||
// url('/path').getHost() // returns ''
|
||||
// url('https://example.com:80/').getHostname() // returns 'example.com'
|
||||
// url('https://127.0.0.1:80/').getHostname() // returns '127.0.0.1'
|
||||
// url('https://[::1]:80/').getHostname() // returns '::1'
|
||||
// url('/path').getHostname() // returns ''
|
||||
// url('https://example.com:80/').getPort() // returns '80'
|
||||
// url('https://example.com/').getPort() // returns ''
|
||||
// url('/path').getPort() // returns ''
|
||||
// url('https://example.com/path').getEscapedPath() // returns '/path'
|
||||
// url('https://example.com/path with spaces/').getEscapedPath() // returns '/path%20with%20spaces/'
|
||||
// url('https://example.com').getEscapedPath() // returns ''
|
||||
// url('https://example.com/path?k1=a&k2=b&k2=c').getQuery() // returns { 'k1': ['a'], 'k2': ['b', 'c']}
|
||||
// url('https://example.com/path?key with spaces=value with spaces').getQuery() // returns { 'key with spaces': ['value with spaces']}
|
||||
// url('https://example.com/path?').getQuery() // returns {}
|
||||
// url('https://example.com/path').getQuery() // returns {}
|
||||
func URLs() cel.EnvOption {
|
||||
return cel.Lib(urlsLib)
|
||||
}
|
||||
|
||||
var urlsLib = &urls{}
|
||||
|
||||
type urls struct{}
|
||||
|
||||
var urlLibraryDecls = map[string][]cel.FunctionOpt{
|
||||
"url": {
|
||||
cel.Overload("string_to_url", []*cel.Type{cel.StringType}, apiservercel.URLType,
|
||||
cel.UnaryBinding(stringToUrl))},
|
||||
"getScheme": {
|
||||
cel.MemberOverload("url_get_scheme", []*cel.Type{apiservercel.URLType}, cel.StringType,
|
||||
cel.UnaryBinding(getScheme))},
|
||||
"getHost": {
|
||||
cel.MemberOverload("url_get_host", []*cel.Type{apiservercel.URLType}, cel.StringType,
|
||||
cel.UnaryBinding(getHost))},
|
||||
"getHostname": {
|
||||
cel.MemberOverload("url_get_hostname", []*cel.Type{apiservercel.URLType}, cel.StringType,
|
||||
cel.UnaryBinding(getHostname))},
|
||||
"getPort": {
|
||||
cel.MemberOverload("url_get_port", []*cel.Type{apiservercel.URLType}, cel.StringType,
|
||||
cel.UnaryBinding(getPort))},
|
||||
"getEscapedPath": {
|
||||
cel.MemberOverload("url_get_escaped_path", []*cel.Type{apiservercel.URLType}, cel.StringType,
|
||||
cel.UnaryBinding(getEscapedPath))},
|
||||
"getQuery": {
|
||||
cel.MemberOverload("url_get_query", []*cel.Type{apiservercel.URLType},
|
||||
cel.MapType(cel.StringType, cel.ListType(cel.StringType)),
|
||||
cel.UnaryBinding(getQuery))},
|
||||
"isURL": {
|
||||
cel.Overload("is_url_string", []*cel.Type{cel.StringType}, cel.BoolType,
|
||||
cel.UnaryBinding(isURL))},
|
||||
}
|
||||
|
||||
func (*urls) CompileOptions() []cel.EnvOption {
|
||||
options := []cel.EnvOption{}
|
||||
for name, overloads := range urlLibraryDecls {
|
||||
options = append(options, cel.Function(name, overloads...))
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func (*urls) ProgramOptions() []cel.ProgramOption {
|
||||
return []cel.ProgramOption{}
|
||||
}
|
||||
|
||||
func stringToUrl(arg ref.Val) ref.Val {
|
||||
s, ok := arg.Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
// Use ParseRequestURI to check the URL before conversion.
|
||||
// ParseRequestURI requires absolute URLs and is used by the OpenAPIv3 'uri' format.
|
||||
_, err := url.ParseRequestURI(s)
|
||||
if err != nil {
|
||||
return types.NewErr("URL parse error during conversion from string: %v", err)
|
||||
}
|
||||
// We must parse again with Parse since ParseRequestURI incorrectly parses URLs that contain a fragment
|
||||
// part and will incorrectly append the fragment to either the path or the query, depending on which it was adjacent to.
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
// Errors are not expected here since Parse is a more lenient parser than ParseRequestURI.
|
||||
return types.NewErr("URL parse error during conversion from string: %v", err)
|
||||
}
|
||||
return apiservercel.URL{URL: u}
|
||||
}
|
||||
|
||||
func getScheme(arg ref.Val) ref.Val {
|
||||
u, ok := arg.Value().(*url.URL)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
return types.String(u.Scheme)
|
||||
}
|
||||
|
||||
func getHost(arg ref.Val) ref.Val {
|
||||
u, ok := arg.Value().(*url.URL)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
return types.String(u.Host)
|
||||
}
|
||||
|
||||
func getHostname(arg ref.Val) ref.Val {
|
||||
u, ok := arg.Value().(*url.URL)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
return types.String(u.Hostname())
|
||||
}
|
||||
|
||||
func getPort(arg ref.Val) ref.Val {
|
||||
u, ok := arg.Value().(*url.URL)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
return types.String(u.Port())
|
||||
}
|
||||
|
||||
func getEscapedPath(arg ref.Val) ref.Val {
|
||||
u, ok := arg.Value().(*url.URL)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
return types.String(u.EscapedPath())
|
||||
}
|
||||
|
||||
func getQuery(arg ref.Val) ref.Val {
|
||||
u, ok := arg.Value().(*url.URL)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
|
||||
result := map[ref.Val]ref.Val{}
|
||||
for k, v := range u.Query() {
|
||||
result[types.String(k)] = types.NewStringList(types.DefaultTypeAdapter, v)
|
||||
}
|
||||
return types.NewRefValMap(types.DefaultTypeAdapter, result)
|
||||
}
|
||||
|
||||
func isURL(arg ref.Val) ref.Val {
|
||||
s, ok := arg.Value().(string)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(arg)
|
||||
}
|
||||
_, err := url.ParseRequestURI(s)
|
||||
return types.Bool(err == nil)
|
||||
}
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
const (
|
||||
// DefaultMaxRequestSizeBytes is the size of the largest request that will be accepted
|
||||
DefaultMaxRequestSizeBytes = int64(3 * 1024 * 1024)
|
||||
|
||||
// MaxDurationSizeJSON
|
||||
// OpenAPI duration strings follow RFC 3339, section 5.6 - see the comment on maxDatetimeSizeJSON
|
||||
MaxDurationSizeJSON = 32
|
||||
// MaxDatetimeSizeJSON
|
||||
// OpenAPI datetime strings follow RFC 3339, section 5.6, and the longest possible
|
||||
// such string is 9999-12-31T23:59:59.999999999Z, which has length 30 - we add 2
|
||||
// to allow for quotation marks
|
||||
MaxDatetimeSizeJSON = 32
|
||||
// MinDurationSizeJSON
|
||||
// Golang allows a string of 0 to be parsed as a duration, so that plus 2 to account for
|
||||
// quotation marks makes 3
|
||||
MinDurationSizeJSON = 3
|
||||
// JSONDateSize is the size of a date serialized as part of a JSON object
|
||||
// RFC 3339 dates require YYYY-MM-DD, and then we add 2 to allow for quotation marks
|
||||
JSONDateSize = 12
|
||||
// MinDatetimeSizeJSON is the minimal length of a datetime formatted as RFC 3339
|
||||
// RFC 3339 datetimes require a full date (YYYY-MM-DD) and full time (HH:MM:SS), and we add 3 for
|
||||
// quotation marks like always in addition to the capital T that separates the date and time
|
||||
MinDatetimeSizeJSON = 21
|
||||
// MinStringSize is the size of literal ""
|
||||
MinStringSize = 2
|
||||
// MinBoolSize is the length of literal true
|
||||
MinBoolSize = 4
|
||||
// MinNumberSize is the length of literal 0
|
||||
MinNumberSize = 1
|
||||
)
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"k8s.io/component-base/metrics"
|
||||
"k8s.io/component-base/metrics/legacyregistry"
|
||||
)
|
||||
|
||||
// TODO(jiahuif) CEL is to be used in multiple components, revise naming when that happens.
|
||||
const (
|
||||
namespace = "apiserver"
|
||||
subsystem = "cel"
|
||||
)
|
||||
|
||||
// Metrics provides access to CEL metrics.
|
||||
var Metrics = newCelMetrics()
|
||||
|
||||
type CelMetrics struct {
|
||||
compilationTime *metrics.Histogram
|
||||
evaluationTime *metrics.Histogram
|
||||
}
|
||||
|
||||
func newCelMetrics() *CelMetrics {
|
||||
m := &CelMetrics{
|
||||
compilationTime: metrics.NewHistogram(&metrics.HistogramOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "compilation_duration_seconds",
|
||||
StabilityLevel: metrics.ALPHA,
|
||||
}),
|
||||
evaluationTime: metrics.NewHistogram(&metrics.HistogramOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: subsystem,
|
||||
Name: "evaluation_duration_seconds",
|
||||
StabilityLevel: metrics.ALPHA,
|
||||
}),
|
||||
}
|
||||
|
||||
legacyregistry.MustRegister(m.compilationTime)
|
||||
legacyregistry.MustRegister(m.evaluationTime)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// ObserveCompilation records a CEL compilation with the time the compilation took.
|
||||
func (m *CelMetrics) ObserveCompilation(elapsed time.Duration) {
|
||||
seconds := elapsed.Seconds()
|
||||
m.compilationTime.Observe(seconds)
|
||||
}
|
||||
|
||||
// ObserveEvaluation records a CEL evaluation with the time the evaluation took.
|
||||
func (m *CelMetrics) ObserveEvaluation(elapsed time.Duration) {
|
||||
seconds := elapsed.Seconds()
|
||||
m.evaluationTime.Observe(seconds)
|
||||
}
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"k8s.io/component-base/metrics/legacyregistry"
|
||||
)
|
||||
|
||||
func TestObserveCompilation(t *testing.T) {
|
||||
defer legacyregistry.Reset()
|
||||
Metrics.ObserveCompilation(2 * time.Second)
|
||||
c, s := gatherHistogram(t, "apiserver_cel_compilation_duration_seconds")
|
||||
if c != 1 {
|
||||
t.Errorf("unexpected count: %v", c)
|
||||
}
|
||||
if math.Abs(s-2.0) > 1e-7 {
|
||||
t.Fatalf("incorrect sum: %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestObserveEvaluation(t *testing.T) {
|
||||
defer legacyregistry.Reset()
|
||||
Metrics.ObserveEvaluation(2 * time.Second)
|
||||
c, s := gatherHistogram(t, "apiserver_cel_evaluation_duration_seconds")
|
||||
if c != 1 {
|
||||
t.Errorf("unexpected count: %v", c)
|
||||
}
|
||||
if math.Abs(s-2.0) > 1e-7 {
|
||||
t.Fatalf("incorrect sum: %v", s)
|
||||
}
|
||||
}
|
||||
|
||||
func gatherHistogram(t *testing.T, name string) (count uint64, sum float64) {
|
||||
metrics, err := legacyregistry.DefaultGatherer.Gather()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to gather metrics: %s", err)
|
||||
}
|
||||
for _, mf := range metrics {
|
||||
if mf.GetName() == name {
|
||||
for _, m := range mf.GetMetric() {
|
||||
h := m.GetHistogram()
|
||||
count += h.GetSampleCount()
|
||||
sum += h.GetSampleSum()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("metric not found: %v", name)
|
||||
return 0, 0
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
)
|
||||
|
||||
// Resolver declares methods to find policy templates and related configuration objects.
|
||||
type Resolver interface {
|
||||
// FindType returns a DeclType instance corresponding to the given fully-qualified name, if
|
||||
// present.
|
||||
FindType(name string) (*DeclType, bool)
|
||||
}
|
||||
|
||||
// NewRegistry create a registry for keeping track of environments and types
|
||||
// from a base cel.Env expression environment.
|
||||
func NewRegistry(stdExprEnv *cel.Env) *Registry {
|
||||
return &Registry{
|
||||
exprEnvs: map[string]*cel.Env{"": stdExprEnv},
|
||||
types: map[string]*DeclType{
|
||||
BoolType.TypeName(): BoolType,
|
||||
BytesType.TypeName(): BytesType,
|
||||
DoubleType.TypeName(): DoubleType,
|
||||
DurationType.TypeName(): DurationType,
|
||||
IntType.TypeName(): IntType,
|
||||
NullType.TypeName(): NullType,
|
||||
StringType.TypeName(): StringType,
|
||||
TimestampType.TypeName(): TimestampType,
|
||||
UintType.TypeName(): UintType,
|
||||
ListType.TypeName(): ListType,
|
||||
MapType.TypeName(): MapType,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Registry defines a repository of environment, schema, template, and type definitions.
|
||||
//
|
||||
// Registry instances are concurrency-safe.
|
||||
type Registry struct {
|
||||
rwMux sync.RWMutex
|
||||
exprEnvs map[string]*cel.Env
|
||||
types map[string]*DeclType
|
||||
}
|
||||
|
||||
// FindType implements the Resolver interface method.
|
||||
func (r *Registry) FindType(name string) (*DeclType, bool) {
|
||||
r.rwMux.RLock()
|
||||
defer r.rwMux.RUnlock()
|
||||
typ, found := r.types[name]
|
||||
if found {
|
||||
return typ, true
|
||||
}
|
||||
return typ, found
|
||||
}
|
||||
|
||||
// SetType registers a DeclType descriptor by its fully qualified name.
|
||||
func (r *Registry) SetType(name string, declType *DeclType) error {
|
||||
r.rwMux.Lock()
|
||||
defer r.rwMux.Unlock()
|
||||
r.types[name] = declType
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,552 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
|
||||
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
noMaxLength = math.MaxInt
|
||||
)
|
||||
|
||||
// NewListType returns a parameterized list type with a specified element type.
|
||||
func NewListType(elem *DeclType, maxItems int64) *DeclType {
|
||||
return &DeclType{
|
||||
name: "list",
|
||||
ElemType: elem,
|
||||
MaxElements: maxItems,
|
||||
celType: cel.ListType(elem.CelType()),
|
||||
defaultValue: NewListValue(),
|
||||
// a list can always be represented as [] in JSON, so hardcode the min size
|
||||
// to 2
|
||||
MinSerializedSize: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapType returns a parameterized map type with the given key and element types.
|
||||
func NewMapType(key, elem *DeclType, maxProperties int64) *DeclType {
|
||||
return &DeclType{
|
||||
name: "map",
|
||||
KeyType: key,
|
||||
ElemType: elem,
|
||||
MaxElements: maxProperties,
|
||||
celType: cel.MapType(key.CelType(), elem.CelType()),
|
||||
defaultValue: NewMapValue(),
|
||||
// a map can always be represented as {} in JSON, so hardcode the min size
|
||||
// to 2
|
||||
MinSerializedSize: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// NewObjectType creates an object type with a qualified name and a set of field declarations.
|
||||
func NewObjectType(name string, fields map[string]*DeclField) *DeclType {
|
||||
t := &DeclType{
|
||||
name: name,
|
||||
Fields: fields,
|
||||
celType: cel.ObjectType(name),
|
||||
traitMask: traits.FieldTesterType | traits.IndexerType,
|
||||
// an object could potentially be larger than the min size we default to here ({}),
|
||||
// but we rely upon the caller to change MinSerializedSize accordingly if they add
|
||||
// properties to the object
|
||||
MinSerializedSize: 2,
|
||||
}
|
||||
t.defaultValue = NewObjectValue(t)
|
||||
return t
|
||||
}
|
||||
|
||||
func NewSimpleTypeWithMinSize(name string, celType *cel.Type, zeroVal ref.Val, minSize int64) *DeclType {
|
||||
return &DeclType{
|
||||
name: name,
|
||||
celType: celType,
|
||||
defaultValue: zeroVal,
|
||||
MinSerializedSize: minSize,
|
||||
}
|
||||
}
|
||||
|
||||
// DeclType represents the universal type descriptor for OpenAPIv3 types.
|
||||
type DeclType struct {
|
||||
fmt.Stringer
|
||||
|
||||
name string
|
||||
// Fields contains a map of escaped CEL identifier field names to field declarations.
|
||||
Fields map[string]*DeclField
|
||||
KeyType *DeclType
|
||||
ElemType *DeclType
|
||||
TypeParam bool
|
||||
Metadata map[string]string
|
||||
MaxElements int64
|
||||
// MinSerializedSize represents the smallest possible size in bytes that
|
||||
// the DeclType could be serialized to in JSON.
|
||||
MinSerializedSize int64
|
||||
|
||||
celType *cel.Type
|
||||
traitMask int
|
||||
defaultValue ref.Val
|
||||
}
|
||||
|
||||
// MaybeAssignTypeName attempts to set the DeclType name to a fully qualified name, if the type
|
||||
// is of `object` type.
|
||||
//
|
||||
// The DeclType must return true for `IsObject` or this assignment will error.
|
||||
func (t *DeclType) MaybeAssignTypeName(name string) *DeclType {
|
||||
if t.IsObject() {
|
||||
objUpdated := false
|
||||
if t.name != "object" {
|
||||
name = t.name
|
||||
} else {
|
||||
objUpdated = true
|
||||
}
|
||||
fieldMap := make(map[string]*DeclField, len(t.Fields))
|
||||
for fieldName, field := range t.Fields {
|
||||
fieldType := field.Type
|
||||
fieldTypeName := fmt.Sprintf("%s.%s", name, fieldName)
|
||||
updated := fieldType.MaybeAssignTypeName(fieldTypeName)
|
||||
if updated == fieldType {
|
||||
fieldMap[fieldName] = field
|
||||
continue
|
||||
}
|
||||
objUpdated = true
|
||||
fieldMap[fieldName] = &DeclField{
|
||||
Name: fieldName,
|
||||
Type: updated,
|
||||
Required: field.Required,
|
||||
enumValues: field.enumValues,
|
||||
defaultValue: field.defaultValue,
|
||||
}
|
||||
}
|
||||
if !objUpdated {
|
||||
return t
|
||||
}
|
||||
return &DeclType{
|
||||
name: name,
|
||||
Fields: fieldMap,
|
||||
KeyType: t.KeyType,
|
||||
ElemType: t.ElemType,
|
||||
TypeParam: t.TypeParam,
|
||||
Metadata: t.Metadata,
|
||||
celType: cel.ObjectType(name),
|
||||
traitMask: t.traitMask,
|
||||
defaultValue: t.defaultValue,
|
||||
MinSerializedSize: t.MinSerializedSize,
|
||||
}
|
||||
}
|
||||
if t.IsMap() {
|
||||
elemTypeName := fmt.Sprintf("%s.@elem", name)
|
||||
updated := t.ElemType.MaybeAssignTypeName(elemTypeName)
|
||||
if updated == t.ElemType {
|
||||
return t
|
||||
}
|
||||
return NewMapType(t.KeyType, updated, t.MaxElements)
|
||||
}
|
||||
if t.IsList() {
|
||||
elemTypeName := fmt.Sprintf("%s.@idx", name)
|
||||
updated := t.ElemType.MaybeAssignTypeName(elemTypeName)
|
||||
if updated == t.ElemType {
|
||||
return t
|
||||
}
|
||||
return NewListType(updated, t.MaxElements)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// ExprType returns the CEL expression type of this declaration.
|
||||
func (t *DeclType) ExprType() (*exprpb.Type, error) {
|
||||
return cel.TypeToExprType(t.celType)
|
||||
}
|
||||
|
||||
// CelType returns the CEL type of this declaration.
|
||||
func (t *DeclType) CelType() *cel.Type {
|
||||
return t.celType
|
||||
}
|
||||
|
||||
// FindField returns the DeclField with the given name if present.
|
||||
func (t *DeclType) FindField(name string) (*DeclField, bool) {
|
||||
f, found := t.Fields[name]
|
||||
return f, found
|
||||
}
|
||||
|
||||
// HasTrait implements the CEL ref.Type interface making this type declaration suitable for use
|
||||
// within the CEL evaluator.
|
||||
func (t *DeclType) HasTrait(trait int) bool {
|
||||
if t.traitMask&trait == trait {
|
||||
return true
|
||||
}
|
||||
if t.defaultValue == nil {
|
||||
return false
|
||||
}
|
||||
_, isDecl := t.defaultValue.Type().(*DeclType)
|
||||
if isDecl {
|
||||
return false
|
||||
}
|
||||
return t.defaultValue.Type().HasTrait(trait)
|
||||
}
|
||||
|
||||
// IsList returns whether the declaration is a `list` type which defines a parameterized element
|
||||
// type, but not a parameterized key type or fields.
|
||||
func (t *DeclType) IsList() bool {
|
||||
return t.KeyType == nil && t.ElemType != nil && t.Fields == nil
|
||||
}
|
||||
|
||||
// IsMap returns whether the declaration is a 'map' type which defines parameterized key and
|
||||
// element types, but not fields.
|
||||
func (t *DeclType) IsMap() bool {
|
||||
return t.KeyType != nil && t.ElemType != nil && t.Fields == nil
|
||||
}
|
||||
|
||||
// IsObject returns whether the declartion is an 'object' type which defined a set of typed fields.
|
||||
func (t *DeclType) IsObject() bool {
|
||||
return t.KeyType == nil && t.ElemType == nil && t.Fields != nil
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface method.
|
||||
func (t *DeclType) String() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// TypeName returns the fully qualified type name for the DeclType.
|
||||
func (t *DeclType) TypeName() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// DefaultValue returns the CEL ref.Val representing the default value for this object type,
|
||||
// if one exists.
|
||||
func (t *DeclType) DefaultValue() ref.Val {
|
||||
return t.defaultValue
|
||||
}
|
||||
|
||||
// FieldTypeMap constructs a map of the field and object types nested within a given type.
|
||||
func FieldTypeMap(path string, t *DeclType) map[string]*DeclType {
|
||||
if t.IsObject() && t.TypeName() != "object" {
|
||||
path = t.TypeName()
|
||||
}
|
||||
types := make(map[string]*DeclType)
|
||||
buildDeclTypes(path, t, types)
|
||||
return types
|
||||
}
|
||||
|
||||
func buildDeclTypes(path string, t *DeclType, types map[string]*DeclType) {
|
||||
// Ensure object types are properly named according to where they appear in the schema.
|
||||
if t.IsObject() {
|
||||
// Hack to ensure that names are uniquely qualified and work well with the type
|
||||
// resolution steps which require fully qualified type names for field resolution
|
||||
// to function properly.
|
||||
types[t.TypeName()] = t
|
||||
for name, field := range t.Fields {
|
||||
fieldPath := fmt.Sprintf("%s.%s", path, name)
|
||||
buildDeclTypes(fieldPath, field.Type, types)
|
||||
}
|
||||
}
|
||||
// Map element properties to type names if needed.
|
||||
if t.IsMap() {
|
||||
mapElemPath := fmt.Sprintf("%s.@elem", path)
|
||||
buildDeclTypes(mapElemPath, t.ElemType, types)
|
||||
types[path] = t
|
||||
}
|
||||
// List element properties.
|
||||
if t.IsList() {
|
||||
listIdxPath := fmt.Sprintf("%s.@idx", path)
|
||||
buildDeclTypes(listIdxPath, t.ElemType, types)
|
||||
types[path] = t
|
||||
}
|
||||
}
|
||||
|
||||
// DeclField describes the name, ordinal, and optionality of a field declaration within a type.
|
||||
type DeclField struct {
|
||||
Name string
|
||||
Type *DeclType
|
||||
Required bool
|
||||
enumValues []interface{}
|
||||
defaultValue interface{}
|
||||
}
|
||||
|
||||
func NewDeclField(name string, declType *DeclType, required bool, enumValues []interface{}, defaultValue interface{}) *DeclField {
|
||||
return &DeclField{
|
||||
Name: name,
|
||||
Type: declType,
|
||||
Required: required,
|
||||
enumValues: enumValues,
|
||||
defaultValue: defaultValue,
|
||||
}
|
||||
}
|
||||
|
||||
// TypeName returns the string type name of the field.
|
||||
func (f *DeclField) TypeName() string {
|
||||
return f.Type.TypeName()
|
||||
}
|
||||
|
||||
// DefaultValue returns the zero value associated with the field.
|
||||
func (f *DeclField) DefaultValue() ref.Val {
|
||||
if f.defaultValue != nil {
|
||||
return types.DefaultTypeAdapter.NativeToValue(f.defaultValue)
|
||||
}
|
||||
return f.Type.DefaultValue()
|
||||
}
|
||||
|
||||
// EnumValues returns the set of values that this field may take.
|
||||
func (f *DeclField) EnumValues() []ref.Val {
|
||||
if f.enumValues == nil || len(f.enumValues) == 0 {
|
||||
return []ref.Val{}
|
||||
}
|
||||
ev := make([]ref.Val, len(f.enumValues))
|
||||
for i, e := range f.enumValues {
|
||||
ev[i] = types.DefaultTypeAdapter.NativeToValue(e)
|
||||
}
|
||||
return ev
|
||||
}
|
||||
|
||||
// NewRuleTypes returns an Open API Schema-based type-system which is CEL compatible.
|
||||
func NewRuleTypes(kind string,
|
||||
declType *DeclType,
|
||||
res Resolver) (*RuleTypes, error) {
|
||||
// Note, if the schema indicates that it's actually based on another proto
|
||||
// then prefer the proto definition. For expressions in the proto, a new field
|
||||
// annotation will be needed to indicate the expected environment and type of
|
||||
// the expression.
|
||||
schemaTypes, err := newSchemaTypeProvider(kind, declType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if schemaTypes == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &RuleTypes{
|
||||
ruleSchemaDeclTypes: schemaTypes,
|
||||
resolver: res,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RuleTypes extends the CEL ref.TypeProvider interface and provides an Open API Schema-based
|
||||
// type-system.
|
||||
type RuleTypes struct {
|
||||
ref.TypeProvider
|
||||
ruleSchemaDeclTypes *schemaTypeProvider
|
||||
typeAdapter ref.TypeAdapter
|
||||
resolver Resolver
|
||||
}
|
||||
|
||||
// EnvOptions returns a set of cel.EnvOption values which includes the declaration set
|
||||
// as well as a custom ref.TypeProvider.
|
||||
//
|
||||
// Note, the standard declaration set includes 'rule' which is defined as the top-level rule-schema
|
||||
// type if one is configured.
|
||||
//
|
||||
// If the RuleTypes value is nil, an empty []cel.EnvOption set is returned.
|
||||
func (rt *RuleTypes) EnvOptions(tp ref.TypeProvider) ([]cel.EnvOption, error) {
|
||||
if rt == nil {
|
||||
return []cel.EnvOption{}, nil
|
||||
}
|
||||
var ta ref.TypeAdapter = types.DefaultTypeAdapter
|
||||
tpa, ok := tp.(ref.TypeAdapter)
|
||||
if ok {
|
||||
ta = tpa
|
||||
}
|
||||
rtWithTypes := &RuleTypes{
|
||||
TypeProvider: tp,
|
||||
typeAdapter: ta,
|
||||
ruleSchemaDeclTypes: rt.ruleSchemaDeclTypes,
|
||||
resolver: rt.resolver,
|
||||
}
|
||||
for name, declType := range rt.ruleSchemaDeclTypes.types {
|
||||
tpType, found := tp.FindType(name)
|
||||
expT, err := declType.ExprType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to get cel type: %s", err)
|
||||
}
|
||||
if found && !proto.Equal(tpType, expT) {
|
||||
return nil, fmt.Errorf(
|
||||
"type %s definition differs between CEL environment and rule", name)
|
||||
}
|
||||
}
|
||||
return []cel.EnvOption{
|
||||
cel.CustomTypeProvider(rtWithTypes),
|
||||
cel.CustomTypeAdapter(rtWithTypes),
|
||||
cel.Variable("rule", rt.ruleSchemaDeclTypes.root.CelType()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FindType attempts to resolve the typeName provided from the rule's rule-schema, or if not
|
||||
// from the embedded ref.TypeProvider.
|
||||
//
|
||||
// FindType overrides the default type-finding behavior of the embedded TypeProvider.
|
||||
//
|
||||
// Note, when the type name is based on the Open API Schema, the name will reflect the object path
|
||||
// where the type definition appears.
|
||||
func (rt *RuleTypes) FindType(typeName string) (*exprpb.Type, bool) {
|
||||
if rt == nil {
|
||||
return nil, false
|
||||
}
|
||||
declType, found := rt.findDeclType(typeName)
|
||||
if found {
|
||||
expT, err := declType.ExprType()
|
||||
if err != nil {
|
||||
return expT, false
|
||||
}
|
||||
return expT, found
|
||||
}
|
||||
return rt.TypeProvider.FindType(typeName)
|
||||
}
|
||||
|
||||
// FindDeclType returns the CPT type description which can be mapped to a CEL type.
|
||||
func (rt *RuleTypes) FindDeclType(typeName string) (*DeclType, bool) {
|
||||
if rt == nil {
|
||||
return nil, false
|
||||
}
|
||||
return rt.findDeclType(typeName)
|
||||
}
|
||||
|
||||
// FindFieldType returns a field type given a type name and field name, if found.
|
||||
//
|
||||
// Note, the type name for an Open API Schema type is likely to be its qualified object path.
|
||||
// If, in the future an object instance rather than a type name were provided, the field
|
||||
// resolution might more accurately reflect the expected type model. However, in this case
|
||||
// concessions were made to align with the existing CEL interfaces.
|
||||
func (rt *RuleTypes) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) {
|
||||
st, found := rt.findDeclType(typeName)
|
||||
if !found {
|
||||
return rt.TypeProvider.FindFieldType(typeName, fieldName)
|
||||
}
|
||||
|
||||
f, found := st.Fields[fieldName]
|
||||
if found {
|
||||
ft := f.Type
|
||||
expT, err := ft.ExprType()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return &ref.FieldType{
|
||||
Type: expT,
|
||||
}, true
|
||||
}
|
||||
// This could be a dynamic map.
|
||||
if st.IsMap() {
|
||||
et := st.ElemType
|
||||
expT, err := et.ExprType()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return &ref.FieldType{
|
||||
Type: expT,
|
||||
}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// NativeToValue is an implementation of the ref.TypeAdapater interface which supports conversion
|
||||
// of rule values to CEL ref.Val instances.
|
||||
func (rt *RuleTypes) NativeToValue(val interface{}) ref.Val {
|
||||
return rt.typeAdapter.NativeToValue(val)
|
||||
}
|
||||
|
||||
// TypeNames returns the list of type names declared within the RuleTypes object.
|
||||
func (rt *RuleTypes) TypeNames() []string {
|
||||
typeNames := make([]string, len(rt.ruleSchemaDeclTypes.types))
|
||||
i := 0
|
||||
for name := range rt.ruleSchemaDeclTypes.types {
|
||||
typeNames[i] = name
|
||||
i++
|
||||
}
|
||||
return typeNames
|
||||
}
|
||||
|
||||
func (rt *RuleTypes) findDeclType(typeName string) (*DeclType, bool) {
|
||||
declType, found := rt.ruleSchemaDeclTypes.types[typeName]
|
||||
if found {
|
||||
return declType, true
|
||||
}
|
||||
declType, found = rt.resolver.FindType(typeName)
|
||||
if found {
|
||||
return declType, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func newSchemaTypeProvider(kind string, declType *DeclType) (*schemaTypeProvider, error) {
|
||||
if declType == nil {
|
||||
return nil, nil
|
||||
}
|
||||
root := declType.MaybeAssignTypeName(kind)
|
||||
types := FieldTypeMap(kind, root)
|
||||
return &schemaTypeProvider{
|
||||
root: root,
|
||||
types: types,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type schemaTypeProvider struct {
|
||||
root *DeclType
|
||||
types map[string]*DeclType
|
||||
}
|
||||
|
||||
var (
|
||||
// AnyType is equivalent to the CEL 'protobuf.Any' type in that the value may have any of the
|
||||
// types supported.
|
||||
AnyType = NewSimpleTypeWithMinSize("any", cel.AnyType, nil, 1)
|
||||
|
||||
// BoolType is equivalent to the CEL 'bool' type.
|
||||
BoolType = NewSimpleTypeWithMinSize("bool", cel.BoolType, types.False, MinBoolSize)
|
||||
|
||||
// BytesType is equivalent to the CEL 'bytes' type.
|
||||
BytesType = NewSimpleTypeWithMinSize("bytes", cel.BytesType, types.Bytes([]byte{}), MinStringSize)
|
||||
|
||||
// DoubleType is equivalent to the CEL 'double' type which is a 64-bit floating point value.
|
||||
DoubleType = NewSimpleTypeWithMinSize("double", cel.DoubleType, types.Double(0), MinNumberSize)
|
||||
|
||||
// DurationType is equivalent to the CEL 'duration' type.
|
||||
DurationType = NewSimpleTypeWithMinSize("duration", cel.DurationType, types.Duration{Duration: time.Duration(0)}, MinDurationSizeJSON)
|
||||
|
||||
// DateType is equivalent to the CEL 'date' type.
|
||||
DateType = NewSimpleTypeWithMinSize("date", cel.TimestampType, types.Timestamp{Time: time.Time{}}, JSONDateSize)
|
||||
|
||||
// DynType is the equivalent of the CEL 'dyn' concept which indicates that the type will be
|
||||
// determined at runtime rather than compile time.
|
||||
DynType = NewSimpleTypeWithMinSize("dyn", cel.DynType, nil, 1)
|
||||
|
||||
// IntType is equivalent to the CEL 'int' type which is a 64-bit signed int.
|
||||
IntType = NewSimpleTypeWithMinSize("int", cel.IntType, types.IntZero, MinNumberSize)
|
||||
|
||||
// NullType is equivalent to the CEL 'null_type'.
|
||||
NullType = NewSimpleTypeWithMinSize("null_type", cel.NullType, types.NullValue, 4)
|
||||
|
||||
// StringType is equivalent to the CEL 'string' type which is expected to be a UTF-8 string.
|
||||
// StringType values may either be string literals or expression strings.
|
||||
StringType = NewSimpleTypeWithMinSize("string", cel.StringType, types.String(""), MinStringSize)
|
||||
|
||||
// TimestampType corresponds to the well-known protobuf.Timestamp type supported within CEL.
|
||||
// Note that both the OpenAPI date and date-time types map onto TimestampType, so not all types
|
||||
// labeled as Timestamp will necessarily have the same MinSerializedSize.
|
||||
TimestampType = NewSimpleTypeWithMinSize("timestamp", cel.TimestampType, types.Timestamp{Time: time.Time{}}, JSONDateSize)
|
||||
|
||||
// UintType is equivalent to the CEL 'uint' type.
|
||||
UintType = NewSimpleTypeWithMinSize("uint", cel.UintType, types.Uint(0), 1)
|
||||
|
||||
// ListType is equivalent to the CEL 'list' type.
|
||||
ListType = NewListType(AnyType, noMaxLength)
|
||||
|
||||
// MapType is equivalent to the CEL 'map' type.
|
||||
MapType = NewMapType(AnyType, AnyType, noMaxLength)
|
||||
)
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTypes_ListType(t *testing.T) {
|
||||
list := NewListType(StringType, -1)
|
||||
if !list.IsList() {
|
||||
t.Error("list type not identifiable as list")
|
||||
}
|
||||
if list.TypeName() != "list" {
|
||||
t.Errorf("got %s, wanted list", list.TypeName())
|
||||
}
|
||||
if list.DefaultValue() == nil {
|
||||
t.Error("got nil zero value for list type")
|
||||
}
|
||||
if list.ElemType.TypeName() != "string" {
|
||||
t.Errorf("got %s, wanted elem type of string", list.ElemType.TypeName())
|
||||
}
|
||||
expT, err := list.ExprType()
|
||||
if err != nil {
|
||||
t.Errorf("fail to get cel type: %s", err)
|
||||
}
|
||||
if expT.GetListType() == nil {
|
||||
t.Errorf("got %v, wanted CEL list type", expT)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypes_MapType(t *testing.T) {
|
||||
mp := NewMapType(StringType, IntType, -1)
|
||||
if !mp.IsMap() {
|
||||
t.Error("map type not identifiable as map")
|
||||
}
|
||||
if mp.TypeName() != "map" {
|
||||
t.Errorf("got %s, wanted map", mp.TypeName())
|
||||
}
|
||||
if mp.DefaultValue() == nil {
|
||||
t.Error("got nil zero value for map type")
|
||||
}
|
||||
if mp.KeyType.TypeName() != "string" {
|
||||
t.Errorf("got %s, wanted key type of string", mp.KeyType.TypeName())
|
||||
}
|
||||
if mp.ElemType.TypeName() != "int" {
|
||||
t.Errorf("got %s, wanted elem type of int", mp.ElemType.TypeName())
|
||||
}
|
||||
expT, err := mp.ExprType()
|
||||
if err != nil {
|
||||
t.Errorf("fail to get cel type: %s", err)
|
||||
}
|
||||
if expT.GetMapType() == nil {
|
||||
t.Errorf("got %v, wanted CEL map type", expT)
|
||||
}
|
||||
}
|
||||
|
||||
func testValue(t *testing.T, id int64, val interface{}) *DynValue {
|
||||
t.Helper()
|
||||
dv, err := NewDynValue(id, val)
|
||||
if err != nil {
|
||||
t.Fatalf("NewDynValue(%d, %v) failed: %v", id, val, err)
|
||||
}
|
||||
return dv
|
||||
}
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/checker/decls"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
)
|
||||
|
||||
// URL provides a CEL representation of a URL.
|
||||
type URL struct {
|
||||
*url.URL
|
||||
}
|
||||
|
||||
var (
|
||||
URLObject = decls.NewObjectType("kubernetes.URL")
|
||||
typeValue = types.NewTypeValue("kubernetes.URL")
|
||||
URLType = cel.ObjectType("kubernetes.URL")
|
||||
)
|
||||
|
||||
// ConvertToNative implements ref.Val.ConvertToNative.
|
||||
func (d URL) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
|
||||
if reflect.TypeOf(d.URL).AssignableTo(typeDesc) {
|
||||
return d.URL, nil
|
||||
}
|
||||
if reflect.TypeOf("").AssignableTo(typeDesc) {
|
||||
return d.URL.String(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("type conversion error from 'URL' to '%v'", typeDesc)
|
||||
}
|
||||
|
||||
// ConvertToType implements ref.Val.ConvertToType.
|
||||
func (d URL) ConvertToType(typeVal ref.Type) ref.Val {
|
||||
switch typeVal {
|
||||
case typeValue:
|
||||
return d
|
||||
case types.TypeType:
|
||||
return typeValue
|
||||
}
|
||||
return types.NewErr("type conversion error from '%s' to '%s'", typeValue, typeVal)
|
||||
}
|
||||
|
||||
// Equal implements ref.Val.Equal.
|
||||
func (d URL) Equal(other ref.Val) ref.Val {
|
||||
otherDur, ok := other.(URL)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(other)
|
||||
}
|
||||
return types.Bool(d.URL.String() == otherDur.URL.String())
|
||||
}
|
||||
|
||||
// Type implements ref.Val.Type.
|
||||
func (d URL) Type() ref.Type {
|
||||
return typeValue
|
||||
}
|
||||
|
||||
// Value implements ref.Val.Value.
|
||||
func (d URL) Value() interface{} {
|
||||
return d.URL
|
||||
}
|
||||
|
|
@ -0,0 +1,769 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
)
|
||||
|
||||
// EncodeStyle is a hint for string encoding of parsed values.
|
||||
type EncodeStyle int
|
||||
|
||||
const (
|
||||
// BlockValueStyle is the default string encoding which preserves whitespace and newlines.
|
||||
BlockValueStyle EncodeStyle = iota
|
||||
|
||||
// FlowValueStyle indicates that the string is an inline representation of complex types.
|
||||
FlowValueStyle
|
||||
|
||||
// FoldedValueStyle is a multiline string with whitespace and newlines trimmed to a single
|
||||
// a whitespace. Repeated newlines are replaced with a single newline rather than a single
|
||||
// whitespace.
|
||||
FoldedValueStyle
|
||||
|
||||
// LiteralStyle is a multiline string that preserves newlines, but trims all other whitespace
|
||||
// to a single character.
|
||||
LiteralStyle
|
||||
)
|
||||
|
||||
// NewEmptyDynValue returns the zero-valued DynValue.
|
||||
func NewEmptyDynValue() *DynValue {
|
||||
// note: 0 is not a valid parse node identifier.
|
||||
dv, _ := NewDynValue(0, nil)
|
||||
return dv
|
||||
}
|
||||
|
||||
// NewDynValue returns a DynValue that corresponds to a parse node id and value.
|
||||
func NewDynValue(id int64, val interface{}) (*DynValue, error) {
|
||||
dv := &DynValue{ID: id}
|
||||
err := dv.SetValue(val)
|
||||
return dv, err
|
||||
}
|
||||
|
||||
// DynValue is a dynamically typed value used to describe unstructured content.
|
||||
// Whether the value has the desired type is determined by where it is used within the Instance or
|
||||
// Template, and whether there are schemas which might enforce a more rigid type definition.
|
||||
type DynValue struct {
|
||||
ID int64
|
||||
EncodeStyle EncodeStyle
|
||||
value interface{}
|
||||
exprValue ref.Val
|
||||
declType *DeclType
|
||||
}
|
||||
|
||||
// DeclType returns the policy model type of the dyn value.
|
||||
func (dv *DynValue) DeclType() *DeclType {
|
||||
return dv.declType
|
||||
}
|
||||
|
||||
// ConvertToNative is an implementation of the CEL ref.Val method used to adapt between CEL types
|
||||
// and Go-native types.
|
||||
//
|
||||
// The default behavior of this method is to first convert to a CEL type which has a well-defined
|
||||
// set of conversion behaviors and proxy to the CEL ConvertToNative method for the type.
|
||||
func (dv *DynValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
|
||||
ev := dv.ExprValue()
|
||||
if types.IsError(ev) {
|
||||
return nil, ev.(*types.Err)
|
||||
}
|
||||
return ev.ConvertToNative(typeDesc)
|
||||
}
|
||||
|
||||
// Equal returns whether the dyn value is equal to a given CEL value.
|
||||
func (dv *DynValue) Equal(other ref.Val) ref.Val {
|
||||
dvType := dv.Type()
|
||||
otherType := other.Type()
|
||||
// Preserve CEL's homogeneous equality constraint.
|
||||
if dvType.TypeName() != otherType.TypeName() {
|
||||
return types.MaybeNoSuchOverloadErr(other)
|
||||
}
|
||||
switch v := dv.value.(type) {
|
||||
case ref.Val:
|
||||
return v.Equal(other)
|
||||
case PlainTextValue:
|
||||
return celBool(string(v) == other.Value().(string))
|
||||
case *MultilineStringValue:
|
||||
return celBool(v.Value == other.Value().(string))
|
||||
case time.Duration:
|
||||
otherDuration := other.Value().(time.Duration)
|
||||
return celBool(v == otherDuration)
|
||||
case time.Time:
|
||||
otherTimestamp := other.Value().(time.Time)
|
||||
return celBool(v.Equal(otherTimestamp))
|
||||
default:
|
||||
return celBool(reflect.DeepEqual(v, other.Value()))
|
||||
}
|
||||
}
|
||||
|
||||
// ExprValue converts the DynValue into a CEL value.
|
||||
func (dv *DynValue) ExprValue() ref.Val {
|
||||
return dv.exprValue
|
||||
}
|
||||
|
||||
// Value returns the underlying value held by this reference.
|
||||
func (dv *DynValue) Value() interface{} {
|
||||
return dv.value
|
||||
}
|
||||
|
||||
// SetValue updates the underlying value held by this reference.
|
||||
func (dv *DynValue) SetValue(value interface{}) error {
|
||||
dv.value = value
|
||||
var err error
|
||||
dv.exprValue, dv.declType, err = exprValue(value)
|
||||
return err
|
||||
}
|
||||
|
||||
// Type returns the CEL type for the given value.
|
||||
func (dv *DynValue) Type() ref.Type {
|
||||
return dv.ExprValue().Type()
|
||||
}
|
||||
|
||||
func exprValue(value interface{}) (ref.Val, *DeclType, error) {
|
||||
switch v := value.(type) {
|
||||
case bool:
|
||||
return types.Bool(v), BoolType, nil
|
||||
case []byte:
|
||||
return types.Bytes(v), BytesType, nil
|
||||
case float64:
|
||||
return types.Double(v), DoubleType, nil
|
||||
case int64:
|
||||
return types.Int(v), IntType, nil
|
||||
case string:
|
||||
return types.String(v), StringType, nil
|
||||
case uint64:
|
||||
return types.Uint(v), UintType, nil
|
||||
case time.Duration:
|
||||
return types.Duration{Duration: v}, DurationType, nil
|
||||
case time.Time:
|
||||
return types.Timestamp{Time: v}, TimestampType, nil
|
||||
case types.Null:
|
||||
return v, NullType, nil
|
||||
case *ListValue:
|
||||
return v, ListType, nil
|
||||
case *MapValue:
|
||||
return v, MapType, nil
|
||||
case *ObjectValue:
|
||||
return v, v.objectType, nil
|
||||
default:
|
||||
return nil, unknownType, fmt.Errorf("unsupported type: (%T)%v", v, v)
|
||||
}
|
||||
}
|
||||
|
||||
// PlainTextValue is a text string literal which must not be treated as an expression.
|
||||
type PlainTextValue string
|
||||
|
||||
// MultilineStringValue is a multiline string value which has been parsed in a way which omits
|
||||
// whitespace as well as a raw form which preserves whitespace.
|
||||
type MultilineStringValue struct {
|
||||
Value string
|
||||
Raw string
|
||||
}
|
||||
|
||||
func newStructValue() *structValue {
|
||||
return &structValue{
|
||||
Fields: []*Field{},
|
||||
fieldMap: map[string]*Field{},
|
||||
}
|
||||
}
|
||||
|
||||
type structValue struct {
|
||||
Fields []*Field
|
||||
fieldMap map[string]*Field
|
||||
}
|
||||
|
||||
// AddField appends a MapField to the MapValue and indexes the field by name.
|
||||
func (sv *structValue) AddField(field *Field) {
|
||||
sv.Fields = append(sv.Fields, field)
|
||||
sv.fieldMap[field.Name] = field
|
||||
}
|
||||
|
||||
// ConvertToNative converts the MapValue type to a native go types.
|
||||
func (sv *structValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
|
||||
if typeDesc.Kind() != reflect.Map &&
|
||||
typeDesc.Kind() != reflect.Struct &&
|
||||
typeDesc.Kind() != reflect.Pointer &&
|
||||
typeDesc.Kind() != reflect.Interface {
|
||||
return nil, fmt.Errorf("type conversion error from object to '%v'", typeDesc)
|
||||
}
|
||||
|
||||
// Unwrap pointers, but track their use.
|
||||
isPtr := false
|
||||
if typeDesc.Kind() == reflect.Pointer {
|
||||
tk := typeDesc
|
||||
typeDesc = typeDesc.Elem()
|
||||
if typeDesc.Kind() == reflect.Pointer {
|
||||
return nil, fmt.Errorf("unsupported type conversion to '%v'", tk)
|
||||
}
|
||||
isPtr = true
|
||||
}
|
||||
|
||||
if typeDesc.Kind() == reflect.Map {
|
||||
keyType := typeDesc.Key()
|
||||
if keyType.Kind() != reflect.String && keyType.Kind() != reflect.Interface {
|
||||
return nil, fmt.Errorf("object fields cannot be converted to type '%v'", keyType)
|
||||
}
|
||||
elemType := typeDesc.Elem()
|
||||
sz := len(sv.fieldMap)
|
||||
ntvMap := reflect.MakeMapWithSize(typeDesc, sz)
|
||||
for name, val := range sv.fieldMap {
|
||||
refVal, err := val.Ref.ConvertToNative(elemType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ntvMap.SetMapIndex(reflect.ValueOf(name), reflect.ValueOf(refVal))
|
||||
}
|
||||
return ntvMap.Interface(), nil
|
||||
}
|
||||
|
||||
if typeDesc.Kind() == reflect.Struct {
|
||||
ntvObjPtr := reflect.New(typeDesc)
|
||||
ntvObj := ntvObjPtr.Elem()
|
||||
for name, val := range sv.fieldMap {
|
||||
f := ntvObj.FieldByName(name)
|
||||
if !f.IsValid() {
|
||||
return nil, fmt.Errorf("type conversion error, no such field %s in type %v",
|
||||
name, typeDesc)
|
||||
}
|
||||
fv, err := val.Ref.ConvertToNative(f.Type())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.Set(reflect.ValueOf(fv))
|
||||
}
|
||||
if isPtr {
|
||||
return ntvObjPtr.Interface(), nil
|
||||
}
|
||||
return ntvObj.Interface(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("type conversion error from object to '%v'", typeDesc)
|
||||
}
|
||||
|
||||
// GetField returns a MapField by name if one exists.
|
||||
func (sv *structValue) GetField(name string) (*Field, bool) {
|
||||
field, found := sv.fieldMap[name]
|
||||
return field, found
|
||||
}
|
||||
|
||||
// IsSet returns whether the given field, which is defined, has also been set.
|
||||
func (sv *structValue) IsSet(key ref.Val) ref.Val {
|
||||
k, ok := key.(types.String)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(key)
|
||||
}
|
||||
name := string(k)
|
||||
_, found := sv.fieldMap[name]
|
||||
return celBool(found)
|
||||
}
|
||||
|
||||
// NewObjectValue creates a struct value with a schema type and returns the empty ObjectValue.
|
||||
func NewObjectValue(sType *DeclType) *ObjectValue {
|
||||
return &ObjectValue{
|
||||
structValue: newStructValue(),
|
||||
objectType: sType,
|
||||
}
|
||||
}
|
||||
|
||||
// ObjectValue is a struct with a custom schema type which indicates the fields and types
|
||||
// associated with the structure.
|
||||
type ObjectValue struct {
|
||||
*structValue
|
||||
objectType *DeclType
|
||||
}
|
||||
|
||||
// ConvertToType is an implementation of the CEL ref.Val interface method.
|
||||
func (o *ObjectValue) ConvertToType(t ref.Type) ref.Val {
|
||||
if t == types.TypeType {
|
||||
return types.NewObjectTypeValue(o.objectType.TypeName())
|
||||
}
|
||||
if t.TypeName() == o.objectType.TypeName() {
|
||||
return o
|
||||
}
|
||||
return types.NewErr("type conversion error from '%s' to '%s'", o.Type(), t)
|
||||
}
|
||||
|
||||
// Equal returns true if the two object types are equal and their field values are equal.
|
||||
func (o *ObjectValue) Equal(other ref.Val) ref.Val {
|
||||
// Preserve CEL's homogeneous equality semantics.
|
||||
if o.objectType.TypeName() != other.Type().TypeName() {
|
||||
return types.MaybeNoSuchOverloadErr(other)
|
||||
}
|
||||
o2 := other.(traits.Indexer)
|
||||
for name := range o.objectType.Fields {
|
||||
k := types.String(name)
|
||||
v := o.Get(k)
|
||||
ov := o2.Get(k)
|
||||
vEq := v.Equal(ov)
|
||||
if vEq != types.True {
|
||||
return vEq
|
||||
}
|
||||
}
|
||||
return types.True
|
||||
}
|
||||
|
||||
// Get returns the value of the specified field.
|
||||
//
|
||||
// If the field is set, its value is returned. If the field is not set, the default value for the
|
||||
// field is returned thus allowing for safe-traversal and preserving proto-like field traversal
|
||||
// semantics for Open API Schema backed types.
|
||||
func (o *ObjectValue) Get(name ref.Val) ref.Val {
|
||||
n, ok := name.(types.String)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(n)
|
||||
}
|
||||
nameStr := string(n)
|
||||
field, found := o.fieldMap[nameStr]
|
||||
if found {
|
||||
return field.Ref.ExprValue()
|
||||
}
|
||||
fieldDef, found := o.objectType.Fields[nameStr]
|
||||
if !found {
|
||||
return types.NewErr("no such field: %s", nameStr)
|
||||
}
|
||||
defValue := fieldDef.DefaultValue()
|
||||
if defValue != nil {
|
||||
return defValue
|
||||
}
|
||||
return types.NewErr("no default for type: %s", fieldDef.TypeName())
|
||||
}
|
||||
|
||||
// Type returns the CEL type value of the object.
|
||||
func (o *ObjectValue) Type() ref.Type {
|
||||
return o.objectType
|
||||
}
|
||||
|
||||
// Value returns the Go-native representation of the object.
|
||||
func (o *ObjectValue) Value() interface{} {
|
||||
return o
|
||||
}
|
||||
|
||||
// NewMapValue returns an empty MapValue.
|
||||
func NewMapValue() *MapValue {
|
||||
return &MapValue{
|
||||
structValue: newStructValue(),
|
||||
}
|
||||
}
|
||||
|
||||
// MapValue declares an object with a set of named fields whose values are dynamically typed.
|
||||
type MapValue struct {
|
||||
*structValue
|
||||
}
|
||||
|
||||
// ConvertToObject produces an ObjectValue from the MapValue with the associated schema type.
|
||||
//
|
||||
// The conversion is shallow and the memory shared between the Object and Map as all references
|
||||
// to the map are expected to be replaced with the Object reference.
|
||||
func (m *MapValue) ConvertToObject(declType *DeclType) *ObjectValue {
|
||||
return &ObjectValue{
|
||||
structValue: m.structValue,
|
||||
objectType: declType,
|
||||
}
|
||||
}
|
||||
|
||||
// Contains returns whether the given key is contained in the MapValue.
|
||||
func (m *MapValue) Contains(key ref.Val) ref.Val {
|
||||
v, found := m.Find(key)
|
||||
if v != nil && types.IsUnknownOrError(v) {
|
||||
return v
|
||||
}
|
||||
return celBool(found)
|
||||
}
|
||||
|
||||
// ConvertToType converts the MapValue to another CEL type, if possible.
|
||||
func (m *MapValue) ConvertToType(t ref.Type) ref.Val {
|
||||
switch t {
|
||||
case types.MapType:
|
||||
return m
|
||||
case types.TypeType:
|
||||
return types.MapType
|
||||
}
|
||||
return types.NewErr("type conversion error from '%s' to '%s'", m.Type(), t)
|
||||
}
|
||||
|
||||
// Equal returns true if the maps are of the same size, have the same keys, and the key-values
|
||||
// from each map are equal.
|
||||
func (m *MapValue) Equal(other ref.Val) ref.Val {
|
||||
oMap, isMap := other.(traits.Mapper)
|
||||
if !isMap {
|
||||
return types.MaybeNoSuchOverloadErr(other)
|
||||
}
|
||||
if m.Size() != oMap.Size() {
|
||||
return types.False
|
||||
}
|
||||
for name, field := range m.fieldMap {
|
||||
k := types.String(name)
|
||||
ov, found := oMap.Find(k)
|
||||
if !found {
|
||||
return types.False
|
||||
}
|
||||
v := field.Ref.ExprValue()
|
||||
vEq := v.Equal(ov)
|
||||
if vEq != types.True {
|
||||
return vEq
|
||||
}
|
||||
}
|
||||
return types.True
|
||||
}
|
||||
|
||||
// Find returns the value for the key in the map, if found.
|
||||
func (m *MapValue) Find(name ref.Val) (ref.Val, bool) {
|
||||
// Currently only maps with string keys are supported as this is best aligned with JSON,
|
||||
// and also much simpler to support.
|
||||
n, ok := name.(types.String)
|
||||
if !ok {
|
||||
return types.MaybeNoSuchOverloadErr(n), true
|
||||
}
|
||||
nameStr := string(n)
|
||||
field, found := m.fieldMap[nameStr]
|
||||
if found {
|
||||
return field.Ref.ExprValue(), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Get returns the value for the key in the map, or error if not found.
|
||||
func (m *MapValue) Get(key ref.Val) ref.Val {
|
||||
v, found := m.Find(key)
|
||||
if found {
|
||||
return v
|
||||
}
|
||||
return types.ValOrErr(key, "no such key: %v", key)
|
||||
}
|
||||
|
||||
// Iterator produces a traits.Iterator which walks over the map keys.
|
||||
//
|
||||
// The Iterator is frequently used within comprehensions.
|
||||
func (m *MapValue) Iterator() traits.Iterator {
|
||||
keys := make([]ref.Val, len(m.fieldMap))
|
||||
i := 0
|
||||
for k := range m.fieldMap {
|
||||
keys[i] = types.String(k)
|
||||
i++
|
||||
}
|
||||
return &baseMapIterator{
|
||||
baseVal: &baseVal{},
|
||||
keys: keys,
|
||||
}
|
||||
}
|
||||
|
||||
// Size returns the number of keys in the map.
|
||||
func (m *MapValue) Size() ref.Val {
|
||||
return types.Int(len(m.Fields))
|
||||
}
|
||||
|
||||
// Type returns the CEL ref.Type for the map.
|
||||
func (m *MapValue) Type() ref.Type {
|
||||
return types.MapType
|
||||
}
|
||||
|
||||
// Value returns the Go-native representation of the MapValue.
|
||||
func (m *MapValue) Value() interface{} {
|
||||
return m
|
||||
}
|
||||
|
||||
type baseMapIterator struct {
|
||||
*baseVal
|
||||
keys []ref.Val
|
||||
idx int
|
||||
}
|
||||
|
||||
// HasNext implements the traits.Iterator interface method.
|
||||
func (it *baseMapIterator) HasNext() ref.Val {
|
||||
if it.idx < len(it.keys) {
|
||||
return types.True
|
||||
}
|
||||
return types.False
|
||||
}
|
||||
|
||||
// Next implements the traits.Iterator interface method.
|
||||
func (it *baseMapIterator) Next() ref.Val {
|
||||
key := it.keys[it.idx]
|
||||
it.idx++
|
||||
return key
|
||||
}
|
||||
|
||||
// Type implements the CEL ref.Val interface metohd.
|
||||
func (it *baseMapIterator) Type() ref.Type {
|
||||
return types.IteratorType
|
||||
}
|
||||
|
||||
// NewField returns a MapField instance with an empty DynValue that refers to the
|
||||
// specified parse node id and field name.
|
||||
func NewField(id int64, name string) *Field {
|
||||
return &Field{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Ref: NewEmptyDynValue(),
|
||||
}
|
||||
}
|
||||
|
||||
// Field specifies a field name and a reference to a dynamic value.
|
||||
type Field struct {
|
||||
ID int64
|
||||
Name string
|
||||
Ref *DynValue
|
||||
}
|
||||
|
||||
// NewListValue returns an empty ListValue instance.
|
||||
func NewListValue() *ListValue {
|
||||
return &ListValue{
|
||||
Entries: []*DynValue{},
|
||||
}
|
||||
}
|
||||
|
||||
// ListValue contains a list of dynamically typed entries.
|
||||
type ListValue struct {
|
||||
Entries []*DynValue
|
||||
initValueSet sync.Once
|
||||
valueSet map[ref.Val]struct{}
|
||||
}
|
||||
|
||||
// Add concatenates two lists together to produce a new CEL list value.
|
||||
func (lv *ListValue) Add(other ref.Val) ref.Val {
|
||||
oArr, isArr := other.(traits.Lister)
|
||||
if !isArr {
|
||||
return types.MaybeNoSuchOverloadErr(other)
|
||||
}
|
||||
szRight := len(lv.Entries)
|
||||
szLeft := int(oArr.Size().(types.Int))
|
||||
sz := szRight + szLeft
|
||||
combo := make([]ref.Val, sz)
|
||||
for i := 0; i < szRight; i++ {
|
||||
combo[i] = lv.Entries[i].ExprValue()
|
||||
}
|
||||
for i := 0; i < szLeft; i++ {
|
||||
combo[i+szRight] = oArr.Get(types.Int(i))
|
||||
}
|
||||
return types.DefaultTypeAdapter.NativeToValue(combo)
|
||||
}
|
||||
|
||||
// Append adds another entry into the ListValue.
|
||||
func (lv *ListValue) Append(entry *DynValue) {
|
||||
lv.Entries = append(lv.Entries, entry)
|
||||
// The append resets all previously built indices.
|
||||
lv.initValueSet = sync.Once{}
|
||||
}
|
||||
|
||||
// Contains returns whether the input `val` is equal to an element in the list.
|
||||
//
|
||||
// If any pair-wise comparison between the input value and the list element is an error, the
|
||||
// operation will return an error.
|
||||
func (lv *ListValue) Contains(val ref.Val) ref.Val {
|
||||
if types.IsUnknownOrError(val) {
|
||||
return val
|
||||
}
|
||||
lv.initValueSet.Do(lv.finalizeValueSet)
|
||||
if lv.valueSet != nil {
|
||||
_, found := lv.valueSet[val]
|
||||
if found {
|
||||
return types.True
|
||||
}
|
||||
// Instead of returning false, ensure that CEL's heterogeneous equality constraint
|
||||
// is satisfied by allowing pair-wise equality behavior to determine the outcome.
|
||||
}
|
||||
var err ref.Val
|
||||
sz := len(lv.Entries)
|
||||
for i := 0; i < sz; i++ {
|
||||
elem := lv.Entries[i]
|
||||
cmp := elem.Equal(val)
|
||||
b, ok := cmp.(types.Bool)
|
||||
if !ok && err == nil {
|
||||
err = types.MaybeNoSuchOverloadErr(cmp)
|
||||
}
|
||||
if b == types.True {
|
||||
return types.True
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return types.False
|
||||
}
|
||||
|
||||
// ConvertToNative is an implementation of the CEL ref.Val method used to adapt between CEL types
|
||||
// and Go-native array-like types.
|
||||
func (lv *ListValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
|
||||
// Non-list conversion.
|
||||
if typeDesc.Kind() != reflect.Slice &&
|
||||
typeDesc.Kind() != reflect.Array &&
|
||||
typeDesc.Kind() != reflect.Interface {
|
||||
return nil, fmt.Errorf("type conversion error from list to '%v'", typeDesc)
|
||||
}
|
||||
|
||||
// If the list is already assignable to the desired type return it.
|
||||
if reflect.TypeOf(lv).AssignableTo(typeDesc) {
|
||||
return lv, nil
|
||||
}
|
||||
|
||||
// List conversion.
|
||||
otherElem := typeDesc.Elem()
|
||||
|
||||
// Allow the element ConvertToNative() function to determine whether conversion is possible.
|
||||
sz := len(lv.Entries)
|
||||
nativeList := reflect.MakeSlice(typeDesc, int(sz), int(sz))
|
||||
for i := 0; i < sz; i++ {
|
||||
elem := lv.Entries[i]
|
||||
nativeElemVal, err := elem.ConvertToNative(otherElem)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nativeList.Index(int(i)).Set(reflect.ValueOf(nativeElemVal))
|
||||
}
|
||||
return nativeList.Interface(), nil
|
||||
}
|
||||
|
||||
// ConvertToType converts the ListValue to another CEL type.
|
||||
func (lv *ListValue) ConvertToType(t ref.Type) ref.Val {
|
||||
switch t {
|
||||
case types.ListType:
|
||||
return lv
|
||||
case types.TypeType:
|
||||
return types.ListType
|
||||
}
|
||||
return types.NewErr("type conversion error from '%s' to '%s'", ListType, t)
|
||||
}
|
||||
|
||||
// Equal returns true if two lists are of the same size, and the values at each index are also
|
||||
// equal.
|
||||
func (lv *ListValue) Equal(other ref.Val) ref.Val {
|
||||
oArr, isArr := other.(traits.Lister)
|
||||
if !isArr {
|
||||
return types.MaybeNoSuchOverloadErr(other)
|
||||
}
|
||||
sz := types.Int(len(lv.Entries))
|
||||
if sz != oArr.Size() {
|
||||
return types.False
|
||||
}
|
||||
for i := types.Int(0); i < sz; i++ {
|
||||
cmp := lv.Get(i).Equal(oArr.Get(i))
|
||||
if cmp != types.True {
|
||||
return cmp
|
||||
}
|
||||
}
|
||||
return types.True
|
||||
}
|
||||
|
||||
// Get returns the value at the given index.
|
||||
//
|
||||
// If the index is negative or greater than the size of the list, an error is returned.
|
||||
func (lv *ListValue) Get(idx ref.Val) ref.Val {
|
||||
iv, isInt := idx.(types.Int)
|
||||
if !isInt {
|
||||
return types.ValOrErr(idx, "unsupported index: %v", idx)
|
||||
}
|
||||
i := int(iv)
|
||||
if i < 0 || i >= len(lv.Entries) {
|
||||
return types.NewErr("index out of bounds: %v", idx)
|
||||
}
|
||||
return lv.Entries[i].ExprValue()
|
||||
}
|
||||
|
||||
// Iterator produces a traits.Iterator suitable for use in CEL comprehension macros.
|
||||
func (lv *ListValue) Iterator() traits.Iterator {
|
||||
return &baseListIterator{
|
||||
getter: lv.Get,
|
||||
sz: len(lv.Entries),
|
||||
}
|
||||
}
|
||||
|
||||
// Size returns the number of elements in the list.
|
||||
func (lv *ListValue) Size() ref.Val {
|
||||
return types.Int(len(lv.Entries))
|
||||
}
|
||||
|
||||
// Type returns the CEL ref.Type for the list.
|
||||
func (lv *ListValue) Type() ref.Type {
|
||||
return types.ListType
|
||||
}
|
||||
|
||||
// Value returns the Go-native value.
|
||||
func (lv *ListValue) Value() interface{} {
|
||||
return lv
|
||||
}
|
||||
|
||||
// finalizeValueSet inspects the ListValue entries in order to make internal optimizations once all list
|
||||
// entries are known.
|
||||
func (lv *ListValue) finalizeValueSet() {
|
||||
valueSet := make(map[ref.Val]struct{})
|
||||
for _, e := range lv.Entries {
|
||||
switch e.value.(type) {
|
||||
case bool, float64, int64, string, uint64, types.Null, PlainTextValue:
|
||||
valueSet[e.ExprValue()] = struct{}{}
|
||||
default:
|
||||
lv.valueSet = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
lv.valueSet = valueSet
|
||||
}
|
||||
|
||||
type baseVal struct{}
|
||||
|
||||
func (*baseVal) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
|
||||
return nil, fmt.Errorf("unsupported native conversion to: %v", typeDesc)
|
||||
}
|
||||
|
||||
func (*baseVal) ConvertToType(t ref.Type) ref.Val {
|
||||
return types.NewErr("unsupported type conversion to: %v", t)
|
||||
}
|
||||
|
||||
func (*baseVal) Equal(other ref.Val) ref.Val {
|
||||
return types.NewErr("unsupported equality test between instances")
|
||||
}
|
||||
|
||||
func (v *baseVal) Value() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
type baseListIterator struct {
|
||||
*baseVal
|
||||
getter func(idx ref.Val) ref.Val
|
||||
sz int
|
||||
idx int
|
||||
}
|
||||
|
||||
func (it *baseListIterator) HasNext() ref.Val {
|
||||
if it.idx < it.sz {
|
||||
return types.True
|
||||
}
|
||||
return types.False
|
||||
}
|
||||
|
||||
func (it *baseListIterator) Next() ref.Val {
|
||||
v := it.getter(types.Int(it.idx))
|
||||
it.idx++
|
||||
return v
|
||||
}
|
||||
|
||||
func (it *baseListIterator) Type() ref.Type {
|
||||
return types.IteratorType
|
||||
}
|
||||
|
||||
func celBool(pred bool) ref.Val {
|
||||
if pred {
|
||||
return types.True
|
||||
}
|
||||
return types.False
|
||||
}
|
||||
|
||||
var unknownType = &DeclType{name: "unknown", MinSerializedSize: 1}
|
||||
|
|
@ -0,0 +1,362 @@
|
|||
/*
|
||||
Copyright 2022 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package cel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/google/cel-go/common/types/traits"
|
||||
)
|
||||
|
||||
func TestConvertToType(t *testing.T) {
|
||||
objType := NewObjectType("TestObject", map[string]*DeclField{})
|
||||
tests := []struct {
|
||||
val interface{}
|
||||
typ ref.Type
|
||||
}{
|
||||
{true, types.BoolType},
|
||||
{float64(1.2), types.DoubleType},
|
||||
{int64(-42), types.IntType},
|
||||
{uint64(63), types.UintType},
|
||||
{time.Duration(300), types.DurationType},
|
||||
{time.Now().UTC(), types.TimestampType},
|
||||
{types.NullValue, types.NullType},
|
||||
{NewListValue(), types.ListType},
|
||||
{NewMapValue(), types.MapType},
|
||||
{[]byte("bytes"), types.BytesType},
|
||||
{NewObjectValue(objType), objType},
|
||||
}
|
||||
for i, tc := range tests {
|
||||
idx := i
|
||||
tst := tc
|
||||
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
|
||||
dv := testValue(t, int64(idx), tst.val)
|
||||
ev := dv.ExprValue()
|
||||
if ev.ConvertToType(types.TypeType).(ref.Type).TypeName() != tst.typ.TypeName() {
|
||||
t.Errorf("got %v, wanted %v type", ev.ConvertToType(types.TypeType), tst.typ)
|
||||
}
|
||||
if ev.ConvertToType(tst.typ).Equal(ev) != types.True {
|
||||
t.Errorf("got %v, wanted input value %v", ev.ConvertToType(tst.typ), ev)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqual(t *testing.T) {
|
||||
vals := []interface{}{
|
||||
true, []byte("bytes"), float64(1.2), int64(-42), uint64(63), time.Duration(300),
|
||||
time.Now().UTC(), types.NullValue, NewListValue(), NewMapValue(),
|
||||
NewObjectValue(NewObjectType("TestObject", map[string]*DeclField{})),
|
||||
}
|
||||
for i, v := range vals {
|
||||
dv := testValue(t, int64(i), v)
|
||||
if dv.Equal(dv.ExprValue()) != types.True {
|
||||
t.Errorf("got %v, wanted dyn value %v equal to itself", dv.Equal(dv.ExprValue()), dv.ExprValue())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListValueAdd(t *testing.T) {
|
||||
lv := NewListValue()
|
||||
lv.Append(testValue(t, 1, "first"))
|
||||
ov := NewListValue()
|
||||
ov.Append(testValue(t, 2, "second"))
|
||||
ov.Append(testValue(t, 3, "third"))
|
||||
llv := NewListValue()
|
||||
llv.Append(testValue(t, 4, lv))
|
||||
lov := NewListValue()
|
||||
lov.Append(testValue(t, 5, ov))
|
||||
var v traits.Lister = llv.Add(lov).(traits.Lister)
|
||||
if v.Size() != types.Int(2) {
|
||||
t.Errorf("got list size %d, wanted 2", v.Size())
|
||||
}
|
||||
complex, err := v.ConvertToNative(reflect.TypeOf([][]string{}))
|
||||
complexList := complex.([][]string)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(complexList, [][]string{{"first"}, {"second", "third"}}) {
|
||||
t.Errorf("got %v, wanted [['first'], ['second', 'third']]", complexList)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListValueContains(t *testing.T) {
|
||||
lv := NewListValue()
|
||||
lv.Append(testValue(t, 1, "first"))
|
||||
lv.Append(testValue(t, 2, "second"))
|
||||
lv.Append(testValue(t, 3, "third"))
|
||||
for i := types.Int(0); i < lv.Size().(types.Int); i++ {
|
||||
e := lv.Get(i)
|
||||
contained := lv.Contains(e)
|
||||
if contained != types.True {
|
||||
t.Errorf("got %v, wanted list contains elem[%v] %v == true", contained, i, e)
|
||||
}
|
||||
}
|
||||
if lv.Contains(types.String("fourth")) != types.False {
|
||||
t.Errorf("got %v, wanted false 'fourth'", lv.Contains(types.String("fourth")))
|
||||
}
|
||||
if !types.IsError(lv.Contains(types.Int(-1))) {
|
||||
t.Errorf("got %v, wanted error for invalid type", lv.Contains(types.Int(-1)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListValueContainsNestedList(t *testing.T) {
|
||||
lvA := NewListValue()
|
||||
lvA.Append(testValue(t, 1, int64(1)))
|
||||
lvA.Append(testValue(t, 2, int64(2)))
|
||||
|
||||
lvB := NewListValue()
|
||||
lvB.Append(testValue(t, 3, int64(3)))
|
||||
|
||||
elemA, elemB := testValue(t, 4, lvA), testValue(t, 5, lvB)
|
||||
lv := NewListValue()
|
||||
lv.Append(elemA)
|
||||
lv.Append(elemB)
|
||||
|
||||
contained := lv.Contains(elemA.ExprValue())
|
||||
if contained != types.True {
|
||||
t.Errorf("got %v, wanted elemA contained in list value", contained)
|
||||
}
|
||||
contained = lv.Contains(elemB.ExprValue())
|
||||
if contained != types.True {
|
||||
t.Errorf("got %v, wanted elemB contained in list value", contained)
|
||||
}
|
||||
contained = lv.Contains(types.DefaultTypeAdapter.NativeToValue([]int32{4}))
|
||||
if contained != types.False {
|
||||
t.Errorf("got %v, wanted empty list not contained", contained)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListValueConvertToNative(t *testing.T) {
|
||||
lv := NewListValue()
|
||||
none, err := lv.ConvertToNative(reflect.TypeOf([]interface{}{}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(none, []interface{}{}) {
|
||||
t.Errorf("got %v, wanted empty list", none)
|
||||
}
|
||||
lv.Append(testValue(t, 1, "first"))
|
||||
one, err := lv.ConvertToNative(reflect.TypeOf([]string{}))
|
||||
oneList := one.([]string)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(oneList) != 1 {
|
||||
t.Errorf("got len(one) == %d, wanted 1", len(oneList))
|
||||
}
|
||||
if !reflect.DeepEqual(oneList, []string{"first"}) {
|
||||
t.Errorf("got %v, wanted string list", oneList)
|
||||
}
|
||||
ov := NewListValue()
|
||||
ov.Append(testValue(t, 2, "second"))
|
||||
ov.Append(testValue(t, 3, "third"))
|
||||
if ov.Size() != types.Int(2) {
|
||||
t.Errorf("got list size %d, wanted 2", ov.Size())
|
||||
}
|
||||
llv := NewListValue()
|
||||
llv.Append(testValue(t, 4, lv))
|
||||
llv.Append(testValue(t, 5, ov))
|
||||
if llv.Size() != types.Int(2) {
|
||||
t.Errorf("got list size %d, wanted 2", llv.Size())
|
||||
}
|
||||
complex, err := llv.ConvertToNative(reflect.TypeOf([][]string{}))
|
||||
complexList := complex.([][]string)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(complexList, [][]string{{"first"}, {"second", "third"}}) {
|
||||
t.Errorf("got %v, wanted [['first'], ['second', 'third']]", complexList)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListValueIterator(t *testing.T) {
|
||||
lv := NewListValue()
|
||||
lv.Append(testValue(t, 1, "first"))
|
||||
lv.Append(testValue(t, 2, "second"))
|
||||
lv.Append(testValue(t, 3, "third"))
|
||||
it := lv.Iterator()
|
||||
if it.Type() != types.IteratorType {
|
||||
t.Errorf("got type %v for iterator, wanted IteratorType", it.Type())
|
||||
}
|
||||
i := types.Int(0)
|
||||
for it.HasNext() == types.True {
|
||||
v := it.Next()
|
||||
if v.Equal(lv.Get(i)) != types.True {
|
||||
t.Errorf("iterator value %v and value %v at index %d not equal", v, lv.Get(i), i)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapValueConvertToNative(t *testing.T) {
|
||||
mv := NewMapValue()
|
||||
none, err := mv.ConvertToNative(reflect.TypeOf(map[string]interface{}{}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(none, map[string]interface{}{}) {
|
||||
t.Errorf("got %v, wanted empty map", none)
|
||||
}
|
||||
none, err = mv.ConvertToNative(reflect.TypeOf(map[interface{}]interface{}{}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(none, map[interface{}]interface{}{}) {
|
||||
t.Errorf("got %v, wanted empty map", none)
|
||||
}
|
||||
mv.AddField(NewField(1, "Test"))
|
||||
tst, _ := mv.GetField("Test")
|
||||
tst.Ref = testValue(t, 2, uint64(12))
|
||||
mv.AddField(NewField(3, "Check"))
|
||||
chk, _ := mv.GetField("Check")
|
||||
chk.Ref = testValue(t, 4, uint64(34))
|
||||
if mv.Size() != types.Int(2) {
|
||||
t.Errorf("got size %d, wanted 2", mv.Size())
|
||||
}
|
||||
if mv.Contains(types.String("Test")) != types.True {
|
||||
t.Error("key 'Test' not found")
|
||||
}
|
||||
if mv.Contains(types.String("Check")) != types.True {
|
||||
t.Error("key 'Check' not found")
|
||||
}
|
||||
if mv.Contains(types.String("Checked")) != types.False {
|
||||
t.Error("key 'Checked' found, wanted not found")
|
||||
}
|
||||
it := mv.Iterator()
|
||||
for it.HasNext() == types.True {
|
||||
k := it.Next()
|
||||
v := mv.Get(k)
|
||||
if k == types.String("Test") && v != types.Uint(12) {
|
||||
t.Errorf("key 'Test' not equal to 12u")
|
||||
}
|
||||
if k == types.String("Check") && v != types.Uint(34) {
|
||||
t.Errorf("key 'Check' not equal to 34u")
|
||||
}
|
||||
}
|
||||
mpStrUint, err := mv.ConvertToNative(reflect.TypeOf(map[string]uint64{}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(mpStrUint, map[string]uint64{
|
||||
"Test": uint64(12),
|
||||
"Check": uint64(34),
|
||||
}) {
|
||||
t.Errorf("got %v, wanted {'Test': 12u, 'Check': 34u}", mpStrUint)
|
||||
}
|
||||
tstStr, err := mv.ConvertToNative(reflect.TypeOf(&tstStruct{}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(tstStr, &tstStruct{
|
||||
Test: uint64(12),
|
||||
Check: uint64(34),
|
||||
}) {
|
||||
t.Errorf("got %v, wanted tstStruct{Test: 12u, Check: 34u}", tstStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapValueEqual(t *testing.T) {
|
||||
mv := NewMapValue()
|
||||
name := NewField(1, "name")
|
||||
name.Ref = testValue(t, 2, "alert")
|
||||
priority := NewField(3, "priority")
|
||||
priority.Ref = testValue(t, 4, int64(4))
|
||||
mv.AddField(name)
|
||||
mv.AddField(priority)
|
||||
if mv.Equal(mv) != types.True {
|
||||
t.Fatalf("map.Equal(map) failed: %v", mv.Equal(mv))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapValueNotEqual(t *testing.T) {
|
||||
mv := NewMapValue()
|
||||
name := NewField(1, "name")
|
||||
name.Ref = testValue(t, 2, "alert")
|
||||
priority := NewField(3, "priority")
|
||||
priority.Ref = testValue(t, 4, int64(4))
|
||||
mv.AddField(name)
|
||||
mv.AddField(priority)
|
||||
|
||||
mv2 := NewMapValue()
|
||||
mv2.AddField(name)
|
||||
if mv.Equal(mv2) != types.False {
|
||||
t.Fatalf("mv.Equal(mv2) failed: %v", mv.Equal(mv2))
|
||||
}
|
||||
|
||||
priority2 := NewField(5, "priority")
|
||||
priority2.Ref = testValue(t, 6, int64(3))
|
||||
mv2.AddField(priority2)
|
||||
if mv.Equal(mv2) != types.False {
|
||||
t.Fatalf("mv.Equal(mv2) failed: %v", mv.Equal(mv2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapValueIsSet(t *testing.T) {
|
||||
mv := NewMapValue()
|
||||
if mv.IsSet(types.String("name")) != types.False {
|
||||
t.Error("map.IsSet('name') returned true for unset key")
|
||||
}
|
||||
mv.AddField(NewField(1, "name"))
|
||||
if mv.IsSet(types.String("name")) != types.True {
|
||||
t.Error("map.IsSet('name') returned false for a set key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestObjectValueEqual(t *testing.T) {
|
||||
objType := NewObjectType("Notice", map[string]*DeclField{
|
||||
"name": {Name: "name", Type: StringType},
|
||||
"priority": {Name: "priority", Type: IntType},
|
||||
"message": {Name: "message", Type: StringType, defaultValue: "<eom>"},
|
||||
})
|
||||
name := NewField(1, "name")
|
||||
name.Ref = testValue(t, 2, "alert")
|
||||
priority := NewField(3, "priority")
|
||||
priority.Ref = testValue(t, 4, int64(4))
|
||||
message := NewField(5, "message")
|
||||
message.Ref = testValue(t, 6, "call immediately")
|
||||
|
||||
mv1 := NewMapValue()
|
||||
mv1.AddField(name)
|
||||
mv1.AddField(priority)
|
||||
obj1 := mv1.ConvertToObject(objType)
|
||||
if obj1.Equal(obj1) != types.True {
|
||||
t.Errorf("obj1.Equal(obj1) failed, got: %v", obj1.Equal(obj1))
|
||||
}
|
||||
|
||||
mv2 := NewMapValue()
|
||||
mv2.AddField(name)
|
||||
mv2.AddField(priority)
|
||||
mv2.AddField(message)
|
||||
obj2 := mv2.ConvertToObject(objType)
|
||||
if obj1.Equal(obj2) == types.True {
|
||||
t.Error("obj1.Equal(obj2) returned true, wanted false")
|
||||
}
|
||||
if obj2.Equal(obj1) == types.True {
|
||||
t.Error("obj2.Equal(obj1) returned true, wanted false")
|
||||
}
|
||||
}
|
||||
|
||||
type tstStruct struct {
|
||||
Test uint64
|
||||
Check uint64
|
||||
}
|
||||
|
|
@ -366,7 +366,7 @@ func NewConfig(codecs serializer.CodecFactory) *Config {
|
|||
// A request body might be encoded in json, and is converted to
|
||||
// proto when persisted in etcd, so we allow 2x as the largest request
|
||||
// body size to be accepted and decoded in a write request.
|
||||
// If this constant is changed, maxRequestSizeBytes in apiextensions-apiserver/pkg/apiserver/schema/cel/model/schemas.go
|
||||
// If this constant is changed, DefaultMaxRequestSizeBytes in k8s.io/apiserver/pkg/cel/limits.go
|
||||
// should be changed to reflect the new value, if the two haven't
|
||||
// been wired together already somehow.
|
||||
MaxRequestBodyBytes: int64(3 * 1024 * 1024),
|
||||
|
|
|
|||
Loading…
Reference in New Issue