Merge pull request #112926 from jiahuif-forks/refactor/cel-out-of-apiextensions

split and move CEL package

Kubernetes-commit: 61ca612cbb85efa13444a6d8ae517cd5e9c151a4
This commit is contained in:
Kubernetes Publisher 2022-10-12 15:03:03 -07:00
commit db8c02bd35
22 changed files with 4009 additions and 11 deletions

15
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/evanphx/json-patch v4.12.0+incompatible
github.com/fsnotify/fsnotify v1.5.4
github.com/gogo/protobuf v1.3.2
github.com/google/cel-go v0.12.5
github.com/google/gnostic v0.5.7-v3refs
github.com/google/go-cmp v0.5.9
github.com/google/gofuzz v1.1.0
@ -36,13 +37,15 @@ require (
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21
google.golang.org/grpc v1.49.0
google.golang.org/protobuf v1.28.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/square/go-jose.v2 v2.2.2
k8s.io/api v0.0.0-20221012035047-0f8110492ea0
k8s.io/api v0.0.0-20221012115127-0184bd884c5e
k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419
k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853
k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282
k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e
k8s.io/klog/v2 v2.80.1
k8s.io/kms v0.0.0-20221012040222-bf322548c086
k8s.io/kube-openapi v0.0.0-20220803162953-67bda5d908f1
@ -56,6 +59,7 @@ require (
require (
cloud.google.com/go v0.97.0 // indirect
github.com/NYTimes/gziphandler v1.1.1 // indirect
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
@ -95,6 +99,7 @@ require (
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/cobra v1.5.0 // indirect
github.com/stoewer/go-strcase v1.2.0 // indirect
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect
go.etcd.io/bbolt v1.3.6 // indirect
@ -111,17 +116,15 @@ require (
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/text v0.3.8 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace (
k8s.io/api => k8s.io/api v0.0.0-20221012035047-0f8110492ea0
k8s.io/api => k8s.io/api v0.0.0-20221012115127-0184bd884c5e
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419
k8s.io/client-go => k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853
k8s.io/component-base => k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282
k8s.io/component-base => k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e
k8s.io/kms => k8s.io/kms v0.0.0-20221012040222-bf322548c086
)

13
go.sum
View File

@ -57,6 +57,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10 h1:yL7+Jz0jTC6yykIK/Wh74gnTJnrGr5AyrNMXuA0gves=
github.com/antlr/antlr4/runtime/Go/antlr v1.4.10/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
@ -220,6 +222,8 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/cel-go v0.12.5 h1:DmzaiSgoaqGCjtpPQWl26/gND+yRpim56H1jCVev6d8=
github.com/google/cel-go v0.12.5/go.mod h1:Jk7ljRzLBhkmiAwBoUxB1sZSCVBAzkqPF25olK/iRDw=
github.com/google/gnostic v0.5.7-v3refs h1:FhTMOKj2VhjpouxvWJAV1TL304uMlb9zcDqkl6cEI54=
github.com/google/gnostic v0.5.7-v3refs/go.mod h1:73MKFl6jIHelAJNaBGFzt3SPtZULs9dYrGFt8OiIsHQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@ -439,6 +443,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU=
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -977,14 +982,14 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
k8s.io/api v0.0.0-20221012035047-0f8110492ea0 h1:zsKtQXVrOo62IwClRKcF/RbgeFI2FQFZ99YhdfypWn8=
k8s.io/api v0.0.0-20221012035047-0f8110492ea0/go.mod h1:Q2jmki3lhHeeAyYtbIiEe/2JoGg2Ge1cJoMgpFkjtzg=
k8s.io/api v0.0.0-20221012115127-0184bd884c5e h1:N2+121lkGMNfLeos4/AlwuujPJd5xJfFkyi15BUzob8=
k8s.io/api v0.0.0-20221012115127-0184bd884c5e/go.mod h1:Q2jmki3lhHeeAyYtbIiEe/2JoGg2Ge1cJoMgpFkjtzg=
k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419 h1:6islJrEgy0CM8YsBcMFdlxm1lXtVC9X5A2AVRK7JEpc=
k8s.io/apimachinery v0.0.0-20221012034848-78d003cc9419/go.mod h1:1b4APDhID8eky0PLpgeoWtdLUptwwJD8Jk5uFOd22CE=
k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853 h1:MA0fOvlQt7fzsEIujzUH4ssakSCKB6+iWRclNvRYAnk=
k8s.io/client-go v0.0.0-20221012035333-e6d958c7a853/go.mod h1:UVW05h29JMuginXCTvlNPTFVS/AyKlU3Q4XoA6UidtU=
k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282 h1:JqSl7JGZ8FwUFz9Isq0YIJNLx5oK5x/g0Vl2VtVntI4=
k8s.io/component-base v0.0.0-20221012040034-5d2a88c65282/go.mod h1:t3mMmVABnDZIpNgn7xDM2AdNUr2jvWA8pleZlSDjAug=
k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e h1:Gg6uxVzmcw4M5KKG4jy6mE+CahWF6VvA39YUsXTaJcI=
k8s.io/component-base v0.0.0-20221012235520-6ecca3322b4e/go.mod h1:nOngF+6Y18fCD8T0L9mxX2FDBo3vzC1t6hsFv6DE3lI=
k8s.io/klog/v2 v2.80.1 h1:atnLQ121W371wYYFawwYx1aEY2eUfs4l3J72wtgAwV4=
k8s.io/klog/v2 v2.80.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0=
k8s.io/kms v0.0.0-20221012040222-bf322548c086 h1:BCrccAO6WPPwlzfgds9ERms+pPGvNJ4QpakEuqEQ5rk=

47
pkg/cel/errors.go Normal file
View File

@ -0,0 +1,47 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
// Error is an implementation of the 'error' interface, which represents a
// XValidation error.
type Error struct {
Type ErrorType
Detail string
}
var _ error = &Error{}
// Error implements the error interface.
func (v *Error) Error() string {
return v.Detail
}
// ErrorType is a machine readable value providing more detail about why
// a XValidation is invalid.
type ErrorType string
const (
// ErrorTypeRequired is used to report withNullable values that are not
// provided (e.g. empty strings, null values, or empty arrays). See
// Required().
ErrorTypeRequired ErrorType = "RuleRequired"
// ErrorTypeInvalid is used to report malformed values
ErrorTypeInvalid ErrorType = "RuleInvalid"
// ErrorTypeInternal is used to report other errors that are not related
// to user input. See InternalError().
ErrorTypeInternal ErrorType = "InternalError"
)

170
pkg/cel/escaping.go Normal file
View File

@ -0,0 +1,170 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"regexp"
"k8s.io/apimachinery/pkg/util/sets"
)
// celReservedSymbols is a list of RESERVED symbols defined in the CEL lexer.
// No identifiers are allowed to collide with these symbols.
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#syntax
var celReservedSymbols = sets.NewString(
"true", "false", "null", "in",
"as", "break", "const", "continue", "else",
"for", "function", "if", "import", "let",
"loop", "package", "namespace", "return", // !! 'namespace' is used heavily in Kubernetes
"var", "void", "while",
)
// expandMatcher matches the escape sequence, characters that are escaped, and characters that are unsupported
var expandMatcher = regexp.MustCompile(`(__|[-./]|[^a-zA-Z0-9-./_])`)
// newCharacterFilter returns a boolean array to indicate the allowed characters
func newCharacterFilter(characters string) []bool {
maxChar := 0
for _, c := range characters {
if maxChar < int(c) {
maxChar = int(c)
}
}
filter := make([]bool, maxChar+1)
for _, c := range characters {
filter[int(c)] = true
}
return filter
}
type escapeCheck struct {
canSkipRegex bool
invalidCharFound bool
}
// skipRegexCheck checks if escape would be skipped.
// if invalidCharFound is true, it must have invalid character; if invalidCharFound is false, not sure if it has invalid character or not
func skipRegexCheck(ident string) escapeCheck {
escapeCheck := escapeCheck{canSkipRegex: true, invalidCharFound: false}
// skip escape if possible
previous_underscore := false
for _, c := range ident {
if c == '/' || c == '-' || c == '.' {
escapeCheck.canSkipRegex = false
return escapeCheck
}
intc := int(c)
if intc < 0 || intc >= len(validCharacterFilter) || !validCharacterFilter[intc] {
escapeCheck.invalidCharFound = true
return escapeCheck
}
if c == '_' && previous_underscore {
escapeCheck.canSkipRegex = false
return escapeCheck
}
previous_underscore = c == '_'
}
return escapeCheck
}
// validCharacterFilter indicates the allowed characters.
var validCharacterFilter = newCharacterFilter("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_")
// Escape escapes ident and returns a CEL identifier (of the form '[a-zA-Z_][a-zA-Z0-9_]*'), or returns
// false if the ident does not match the supported input format of `[a-zA-Z_.-/][a-zA-Z0-9_.-/]*`.
// Escaping Rules:
// - '__' escapes to '__underscores__'
// - '.' escapes to '__dot__'
// - '-' escapes to '__dash__'
// - '/' escapes to '__slash__'
// - Identifiers that exactly match a CEL RESERVED keyword escape to '__{keyword}__'. The keywords are: "true", "false",
// "null", "in", "as", "break", "const", "continue", "else", "for", "function", "if", "import", "let", loop", "package",
// "namespace", "return".
func Escape(ident string) (string, bool) {
if len(ident) == 0 || ('0' <= ident[0] && ident[0] <= '9') {
return "", false
}
if celReservedSymbols.Has(ident) {
return "__" + ident + "__", true
}
escapeCheck := skipRegexCheck(ident)
if escapeCheck.invalidCharFound {
return "", false
}
if escapeCheck.canSkipRegex {
return ident, true
}
ok := true
ident = expandMatcher.ReplaceAllStringFunc(ident, func(s string) string {
switch s {
case "__":
return "__underscores__"
case ".":
return "__dot__"
case "-":
return "__dash__"
case "/":
return "__slash__"
default: // matched a unsupported supported
ok = false
return ""
}
})
if !ok {
return "", false
}
return ident, true
}
var unexpandMatcher = regexp.MustCompile(`(_{2}[^_]+_{2})`)
// Unescape unescapes an CEL identifier containing the escape sequences described in Escape, or return false if the
// string contains invalid escape sequences. The escaped input is expected to be a valid CEL identifier, but is
// not checked.
func Unescape(escaped string) (string, bool) {
ok := true
escaped = unexpandMatcher.ReplaceAllStringFunc(escaped, func(s string) string {
contents := s[2 : len(s)-2]
switch contents {
case "underscores":
return "__"
case "dot":
return "."
case "dash":
return "-"
case "slash":
return "/"
}
if celReservedSymbols.Has(contents) {
if len(s) != len(escaped) {
ok = false
}
return contents
}
ok = false
return ""
})
if !ok {
return "", false
}
return escaped, true
}

206
pkg/cel/escaping_test.go Normal file
View File

@ -0,0 +1,206 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"regexp"
"testing"
fuzz "github.com/google/gofuzz"
)
// TestEscaping tests that property names are escaped as expected.
func TestEscaping(t *testing.T) {
cases := []struct {
unescaped string
escaped string
unescapable bool
}{
// '.', '-', '/' and '__' are escaped since
// CEL only allows identifiers of the form: [a-zA-Z_][a-zA-Z0-9_]*
{unescaped: "a.a", escaped: "a__dot__a"},
{unescaped: "a-a", escaped: "a__dash__a"},
{unescaped: "a__a", escaped: "a__underscores__a"},
{unescaped: "a.-/__a", escaped: "a__dot____dash____slash____underscores__a"},
{unescaped: "a._a", escaped: "a__dot___a"},
{unescaped: "a__.__a", escaped: "a__underscores____dot____underscores__a"},
{unescaped: "a___a", escaped: "a__underscores___a"},
{unescaped: "a____a", escaped: "a__underscores____underscores__a"},
{unescaped: "a__dot__a", escaped: "a__underscores__dot__underscores__a"},
{unescaped: "a__underscores__a", escaped: "a__underscores__underscores__underscores__a"},
// CEL lexer RESERVED keywords must be escaped
{unescaped: "true", escaped: "__true__"},
{unescaped: "false", escaped: "__false__"},
{unescaped: "null", escaped: "__null__"},
{unescaped: "in", escaped: "__in__"},
{unescaped: "as", escaped: "__as__"},
{unescaped: "break", escaped: "__break__"},
{unescaped: "const", escaped: "__const__"},
{unescaped: "continue", escaped: "__continue__"},
{unescaped: "else", escaped: "__else__"},
{unescaped: "for", escaped: "__for__"},
{unescaped: "function", escaped: "__function__"},
{unescaped: "if", escaped: "__if__"},
{unescaped: "import", escaped: "__import__"},
{unescaped: "let", escaped: "__let__"},
{unescaped: "loop", escaped: "__loop__"},
{unescaped: "package", escaped: "__package__"},
{unescaped: "namespace", escaped: "__namespace__"},
{unescaped: "return", escaped: "__return__"},
{unescaped: "var", escaped: "__var__"},
{unescaped: "void", escaped: "__void__"},
{unescaped: "while", escaped: "__while__"},
// Not all property names are escapable
{unescaped: "@", unescapable: true},
{unescaped: "1up", unescapable: true},
{unescaped: "👑", unescapable: true},
// CEL macro and function names do not need to be escaped because the parser keeps identifiers in a
// different namespace than function and macro names.
{unescaped: "has", escaped: "has"},
{unescaped: "all", escaped: "all"},
{unescaped: "exists", escaped: "exists"},
{unescaped: "exists_one", escaped: "exists_one"},
{unescaped: "filter", escaped: "filter"},
{unescaped: "size", escaped: "size"},
{unescaped: "contains", escaped: "contains"},
{unescaped: "startsWith", escaped: "startsWith"},
{unescaped: "endsWith", escaped: "endsWith"},
{unescaped: "matches", escaped: "matches"},
{unescaped: "duration", escaped: "duration"},
{unescaped: "timestamp", escaped: "timestamp"},
{unescaped: "getDate", escaped: "getDate"},
{unescaped: "getDayOfMonth", escaped: "getDayOfMonth"},
{unescaped: "getDayOfWeek", escaped: "getDayOfWeek"},
{unescaped: "getFullYear", escaped: "getFullYear"},
{unescaped: "getHours", escaped: "getHours"},
{unescaped: "getMilliseconds", escaped: "getMilliseconds"},
{unescaped: "getMinutes", escaped: "getMinutes"},
{unescaped: "getMonth", escaped: "getMonth"},
{unescaped: "getSeconds", escaped: "getSeconds"},
// we don't escape a single _
{unescaped: "_if", escaped: "_if"},
{unescaped: "_has", escaped: "_has"},
{unescaped: "_int", escaped: "_int"},
{unescaped: "_anything", escaped: "_anything"},
}
for _, tc := range cases {
t.Run(tc.unescaped, func(t *testing.T) {
e, escapable := Escape(tc.unescaped)
if tc.unescapable {
if escapable {
t.Errorf("Expected escapable=false, but got %t", escapable)
}
return
}
if !escapable {
t.Fatalf("Expected escapable=true, but got %t", escapable)
}
if tc.escaped != e {
t.Errorf("Expected %s to escape to %s, but got %s", tc.unescaped, tc.escaped, e)
}
if !validCelIdent.MatchString(e) {
t.Errorf("Expected %s to escape to a valid CEL identifier, but got %s", tc.unescaped, e)
}
u, ok := Unescape(tc.escaped)
if !ok {
t.Fatalf("Expected %s to be escapable, but it was not", tc.escaped)
}
if tc.unescaped != u {
t.Errorf("Expected %s to unescape to %s, but got %s", tc.escaped, tc.unescaped, u)
}
})
}
}
func TestUnescapeMalformed(t *testing.T) {
for _, s := range []string{"__int__extra", "__illegal__"} {
t.Run(s, func(t *testing.T) {
e, ok := Unescape(s)
if ok {
t.Fatalf("Expected %s to be unescapable, but it escaped to: %s", s, e)
}
})
}
}
func TestEscapingFuzz(t *testing.T) {
fuzzer := fuzz.New()
for i := 0; i < 1000; i++ {
var unescaped string
fuzzer.Fuzz(&unescaped)
t.Run(fmt.Sprintf("%d - '%s'", i, unescaped), func(t *testing.T) {
if len(unescaped) == 0 {
return
}
escaped, ok := Escape(unescaped)
if !ok {
return
}
if !validCelIdent.MatchString(escaped) {
t.Errorf("Expected %s to escape to a valid CEL identifier, but got %s", unescaped, escaped)
}
u, ok := Unescape(escaped)
if !ok {
t.Fatalf("Expected %s to be unescapable, but it was not", escaped)
}
if unescaped != u {
t.Errorf("Expected %s to unescape to %s, but got %s", escaped, unescaped, u)
}
})
}
}
var validCelIdent = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
func TestCanSkipRegex(t *testing.T) {
cases := []struct {
unescaped string
canSkip bool
invalidCharFound bool
}{
{unescaped: "a.a", canSkip: false},
{unescaped: "a-a", canSkip: false},
{unescaped: "a__a", canSkip: false},
{unescaped: "a.-/__a", canSkip: false},
{unescaped: "a_a", canSkip: true},
{unescaped: "a_a_a", canSkip: true},
{unescaped: "@", invalidCharFound: true},
{unescaped: "👑", invalidCharFound: true},
// if escape keyword is detected before invalid character, invalidCharFound
{unescaped: "/👑", canSkip: false},
}
for _, tc := range cases {
t.Run(tc.unescaped, func(t *testing.T) {
escapeCheck := skipRegexCheck(tc.unescaped)
if escapeCheck.invalidCharFound {
if escapeCheck.invalidCharFound != tc.invalidCharFound {
t.Errorf("Expected input validation to be %v, but got %t", tc.invalidCharFound, escapeCheck.invalidCharFound)
}
} else {
if escapeCheck.canSkipRegex != tc.canSkip {
t.Errorf("Expected %v, but got %t", tc.canSkip, escapeCheck.canSkipRegex)
}
}
})
}
}

268
pkg/cel/library/cost.go Normal file
View File

@ -0,0 +1,268 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"math"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
type CostEstimator struct {
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
// calculations to if the size is not well known (i.e. a constant).
SizeEstimator checker.CostEstimator
}
func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 {
switch function {
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
var cost uint64
if len(args) > 0 {
cost += traversalCost(args[0]) // these O(n) operations all cost roughly the cost of a single traversal
}
return &cost
case "url", "lowerAscii", "upperAscii", "substring", "trim":
if len(args) >= 1 {
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "replace", "split":
if len(args) >= 1 {
// cost is the traversal plus the construction of the result
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
return &cost
}
case "join":
if len(args) >= 1 {
cost := uint64(math.Ceil(float64(actualSize(result)) * 2 * common.StringTraversalCostFactor))
return &cost
}
case "find", "findAll":
if len(args) >= 2 {
strCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor))
cost := strCost * regexCost
return &cost
}
}
return nil
}
func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
// WARNING: Any changes to this code impact API compatibility! The estimated cost is used to determine which CEL rules may be written to a
// CRD and any change (cost increases and cost decreases) are breaking.
switch function {
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
if target != nil {
// Charge 1 cost for comparing each element in the list
elCost := checker.CostEstimate{Min: 1, Max: 1}
// If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way
// of estimating the additional comparison cost.
if elNode := l.listElementNode(*target); elNode != nil {
t := elNode.Type().GetPrimitive()
if t == exprpb.Type_STRING || t == exprpb.Type_BYTES {
sz := l.sizeEstimate(elNode)
elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCost(elCost)}
} else { // the target is a string, which is supported by indexOf and lastIndexOf
return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)}
}
}
case "url":
if len(args) == 1 {
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
}
case "lowerAscii", "upperAscii", "substring", "trim":
if target != nil {
sz := l.sizeEstimate(*target)
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
}
case "replace":
if target != nil && len(args) >= 2 {
sz := l.sizeEstimate(*target)
toReplaceSz := l.sizeEstimate(args[0])
replaceWithSz := l.sizeEstimate(args[1])
// smallest possible result: smallest input size composed of the largest possible substrings being replaced by smallest possible replacement
minSz := uint64(math.Ceil(float64(sz.Min)/float64(toReplaceSz.Max))) * replaceWithSz.Min
// largest possible result: largest input size composed of the smallest possible substrings being replaced by largest possible replacement
maxSz := uint64(math.Ceil(float64(sz.Max)/float64(toReplaceSz.Min))) * replaceWithSz.Max
// cost is the traversal plus the construction of the result
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: minSz, Max: maxSz}}
}
case "split":
if target != nil {
sz := l.sizeEstimate(*target)
// Worst case size is where is that a separator of "" is used, and each char is returned as a list element.
max := sz.Max
if len(args) > 1 {
if c := args[1].Expr().GetConstExpr(); c != nil {
max = uint64(c.GetInt64Value())
}
}
// Cost is the traversal plus the construction of the result.
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: 0, Max: max}}
}
case "join":
if target != nil {
var sz checker.SizeEstimate
listSize := l.sizeEstimate(*target)
if elNode := l.listElementNode(*target); elNode != nil {
elemSize := l.sizeEstimate(elNode)
sz = listSize.Multiply(elemSize)
}
if len(args) > 0 {
sepSize := l.sizeEstimate(args[0])
minSeparators := uint64(0)
maxSeparators := uint64(0)
if listSize.Min > 0 {
minSeparators = listSize.Min - 1
}
if listSize.Max > 0 {
maxSeparators = listSize.Max - 1
}
sz = sz.Add(sepSize.Multiply(checker.SizeEstimate{Min: minSeparators, Max: maxSeparators}))
}
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
}
case "find", "findAll":
if target != nil && len(args) >= 1 {
sz := l.sizeEstimate(*target)
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := sz.Add(checker.SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := l.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
// worst case size of result is that every char is returned as separate find result.
return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}}
}
}
return nil
}
func actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
return 1
}
func (l *CostEstimator) sizeEstimate(t checker.AstNode) checker.SizeEstimate {
if sz := t.ComputedSize(); sz != nil {
return *sz
}
if sz := l.EstimateSize(t); sz != nil {
return *sz
}
return checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
}
func (l *CostEstimator) listElementNode(list checker.AstNode) checker.AstNode {
if lt := list.Type().GetListType(); lt != nil {
nodePath := list.Path()
if nodePath != nil {
// Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists
// for this node.
path := make([]string, len(nodePath)+1)
copy(path, nodePath)
path[len(nodePath)] = "@items"
return &itemsNode{path: path, t: lt.GetElemType(), expr: nil}
} else {
// Provide just the type if no path is available so that worst case size can be looked up based on type.
return &itemsNode{t: lt.GetElemType(), expr: nil}
}
}
return nil
}
func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
if l.SizeEstimator != nil {
return l.SizeEstimator.EstimateSize(element)
}
return nil
}
type itemsNode struct {
path []string
t *exprpb.Type
expr *exprpb.Expr
}
func (i *itemsNode) Path() []string {
return i.path
}
func (i *itemsNode) Type() *exprpb.Type {
return i.t
}
func (i *itemsNode) Expr() *exprpb.Expr {
return i.expr
}
func (i *itemsNode) ComputedSize() *checker.SizeEstimate {
return nil
}
// traversalCost computes the cost of traversing a ref.Val as a data tree.
func traversalCost(v ref.Val) uint64 {
// TODO: This could potentially be optimized by sampling maps and lists instead of traversing.
switch vt := v.(type) {
case types.String:
return uint64(float64(len(string(vt))) * common.StringTraversalCostFactor)
case types.Bytes:
return uint64(float64(len([]byte(vt))) * common.StringTraversalCostFactor)
case traits.Lister:
cost := uint64(0)
for it := vt.Iterator(); it.HasNext() == types.True; {
i := it.Next()
cost += traversalCost(i)
}
return cost
case traits.Mapper: // maps and objects
cost := uint64(0)
for it := vt.Iterator(); it.HasNext() == types.True; {
k := it.Next()
cost += traversalCost(k) + traversalCost(vt.Get(k))
}
return cost
default:
return 1
}
}

View File

@ -0,0 +1,363 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"testing"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/ext"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
intListLiteral = "[1, 2, 3, 4, 5]"
uintListLiteral = "[uint(1), uint(2), uint(3), uint(4), uint(5)]"
doubleListLiteral = "[1.0, 2.0, 3.0, 4.0, 5.0]"
boolListLiteral = "[false, true, false, true, false]"
stringListLiteral = "['012345678901', '012345678901', '012345678901', '012345678901', '012345678901']"
bytesListLiteral = "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')]"
durationListLiteral = "[duration('1s'), duration('2s'), duration('3s'), duration('4s'), duration('5s')]"
timestampListLiteral = "[timestamp('2011-01-01T00:00:00.000+01:00'), timestamp('2011-01-02T00:00:00.000+01:00'), " +
"timestamp('2011-01-03T00:00:00.000+01:00'), timestamp('2011-01-04T00:00:00.000+01:00'), " +
"timestamp('2011-01-05T00:00:00.000+01:00')]"
stringLiteral = "'01234567890123456789012345678901234567890123456789'"
)
type comparableCost struct {
comparableLiteral string
expectedEstimatedCost checker.CostEstimate
expectedRuntimeCost uint64
param string
}
func TestListsCost(t *testing.T) {
cases := []struct {
opts []string
costs []comparableCost
}{
{
opts: []string{".sum()"},
// 10 cost for the list declaration, the rest is the due to the function call
costs: []comparableCost{
{
comparableLiteral: intListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: uintListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
},
{
comparableLiteral: doubleListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: durationListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
},
},
},
{
opts: []string{".isSorted()", ".max()", ".min()"},
// 10 cost for the list declaration, the rest is the due to the function call
costs: []comparableCost{
{
comparableLiteral: intListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: uintListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts
},
{
comparableLiteral: doubleListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: boolListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: stringListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons
},
{
comparableLiteral: bytesListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 25, Max: 35}, expectedRuntimeCost: 25, // +10 for casts from string to byte, +5 for byte comparisons
},
{
comparableLiteral: durationListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts
},
{
comparableLiteral: timestampListLiteral,
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
},
},
},
}
for _, tc := range cases {
for _, op := range tc.opts {
for _, typ := range tc.costs {
t.Run(typ.comparableLiteral+op, func(t *testing.T) {
e := typ.comparableLiteral + op
testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost)
})
}
}
}
}
func TestIndexOfCost(t *testing.T) {
cases := []struct {
opts []string
costs []comparableCost
}{
{
opts: []string{".indexOf(%s)", ".lastIndexOf(%s)"},
// 10 cost for the list declaration, the rest is the due to the function call
costs: []comparableCost{
{
comparableLiteral: intListLiteral, param: "3",
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: uintListLiteral, param: "uint(3)",
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +5 for numeric casts
},
{
comparableLiteral: doubleListLiteral, param: "3.0",
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: boolListLiteral, param: "true",
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
},
{
comparableLiteral: stringListLiteral, param: "'x'",
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons
},
{
comparableLiteral: bytesListLiteral, param: "bytes('x')",
expectedEstimatedCost: checker.CostEstimate{Min: 26, Max: 36}, expectedRuntimeCost: 26, // +11 for casts from string to byte, +5 for byte comparisons
},
{
comparableLiteral: durationListLiteral, param: "duration('3s')",
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from duration to byte
},
{
comparableLiteral: timestampListLiteral, param: "timestamp('2011-01-03T00:00:00.000+01:00')",
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from timestamp to byte
},
// index of operations are also defined for strings
{
comparableLiteral: stringLiteral, param: "'123'",
expectedEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, expectedRuntimeCost: 5,
},
},
},
}
for _, tc := range cases {
for _, op := range tc.opts {
for _, typ := range tc.costs {
opWithParam := fmt.Sprintf(op, typ.param)
t.Run(typ.comparableLiteral+opWithParam, func(t *testing.T) {
e := typ.comparableLiteral + opWithParam
testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost)
})
}
}
}
}
func TestURLsCost(t *testing.T) {
cases := []struct {
ops []string
expectEsimatedCost checker.CostEstimate
expectRuntimeCost uint64
}{
{
ops: []string{".getScheme()", ".getHostname()", ".getHost()", ".getPort()", ".getEscapedPath()", ".getQuery()"},
expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4},
expectRuntimeCost: 4,
},
}
for _, tc := range cases {
for _, op := range tc.ops {
t.Run("url."+op, func(t *testing.T) {
testCost(t, "url('https:://kubernetes.io/')"+op, tc.expectEsimatedCost, tc.expectRuntimeCost)
})
}
}
}
func TestStringLibrary(t *testing.T) {
cases := []struct {
name string
expr string
expectEsimatedCost checker.CostEstimate
expectRuntimeCost uint64
}{
{
name: "lowerAscii",
expr: "'ABCDEFGHIJ abcdefghij'.lowerAscii()",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
name: "upperAscii",
expr: "'ABCDEFGHIJ abcdefghij'.upperAscii()",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
name: "replace",
expr: "'abc 123 def 123'.replace('123', '456')",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
name: "replace with limit",
expr: "'abc 123 def 123'.replace('123', '456', 1)",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
name: "split",
expr: "'abc 123 def 123'.split(' ')",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
name: "split with limit",
expr: "'abc 123 def 123'.split(' ', 1)",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
name: "substring",
expr: "'abc 123 def 123'.substring(5)",
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
expectRuntimeCost: 2,
},
{
name: "substring with end",
expr: "'abc 123 def 123'.substring(5, 8)",
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
expectRuntimeCost: 2,
},
{
name: "trim",
expr: "' abc 123 def 123 '.trim()",
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
expectRuntimeCost: 2,
},
{
name: "join with separator",
expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join(' ')",
expectEsimatedCost: checker.CostEstimate{Min: 11, Max: 23},
expectRuntimeCost: 15,
},
{
name: "join",
expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join()",
expectEsimatedCost: checker.CostEstimate{Min: 10, Max: 22},
expectRuntimeCost: 13,
},
{
name: "find",
expr: "'abc 123 def 123'.find('123')",
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
expectRuntimeCost: 2,
},
{
name: "findAll",
expr: "'abc 123 def 123'.findAll('123')",
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
expectRuntimeCost: 2,
},
{
name: "findAll with limit",
expr: "'abc 123 def 123'.findAll('123', 1)",
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
expectRuntimeCost: 2,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
testCost(t, tc.expr, tc.expectEsimatedCost, tc.expectRuntimeCost)
})
}
}
func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) {
est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
env, err := cel.NewEnv(append(k8sExtensionLibs, ext.Strings())...)
if err != nil {
t.Fatalf("%v", err)
}
compiled, issues := env.Compile(expr)
if len(issues.Errors()) > 0 {
t.Fatalf("%v", issues.Errors())
}
estCost, err := env.EstimateCost(compiled, est)
if err != nil {
t.Fatalf("%v", err)
}
if estCost.Min != expectEsimatedCost.Min || estCost.Max != expectEsimatedCost.Max {
t.Errorf("Expected estimated cost of %d..%d but got %d..%d", expectEsimatedCost.Min, expectEsimatedCost.Max, estCost.Min, estCost.Max)
}
prog, err := env.Program(compiled, cel.CostTracking(est))
if err != nil {
t.Fatalf("%v", err)
}
_, details, err := prog.Eval(map[string]interface{}{})
if err != nil {
t.Fatalf("%v", err)
}
cost := details.ActualCost()
if *cost != expectRuntimeCost {
t.Errorf("Expected cost of %d but got %d", expectRuntimeCost, *cost)
}
}
type testCostEstimator struct {
}
func (t *testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
switch t := element.Type().TypeKind.(type) {
case *expr.Type_Primitive:
switch t.Primitive {
case expr.Type_STRING:
return &checker.SizeEstimate{Min: 0, Max: 12}
case expr.Type_BYTES:
return &checker.SizeEstimate{Min: 0, Max: 12}
}
}
return nil
}
func (t *testCostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return nil
}

View File

@ -0,0 +1,34 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/ext"
"github.com/google/cel-go/interpreter"
)
// ExtensionLibs declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes.
var ExtensionLibs = append(k8sExtensionLibs, ext.Strings())
var k8sExtensionLibs = []cel.EnvOption{
URLs(),
Regex(),
Lists(),
}
var ExtensionLibRegexOptimizations = []*interpreter.RegexOptimization{FindRegexOptimization, FindAllRegexOptimization}

View File

@ -0,0 +1,58 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"testing"
"github.com/google/cel-go/cel"
)
func TestLibraryCompatibility(t *testing.T) {
functionNames := map[string]struct{}{}
decls := map[cel.Library]map[string][]cel.FunctionOpt{
urlsLib: urlLibraryDecls,
listsLib: listsLibraryDecls,
regexLib: regexLibraryDecls,
}
if len(k8sExtensionLibs) != len(decls) {
t.Errorf("Expected the same number of libraries in the ExtensionLibs as are tested for compatibility")
}
for _, decl := range decls {
for name := range decl {
functionNames[name] = struct{}{}
}
}
// WARN: All library changes must follow
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates
// and must track the functions here along with which Kubernetes version introduced them.
knownFunctions := []string{
// Kubernetes 1.24:
"isSorted", "sum", "max", "min", "indexOf", "lastIndexOf", "find", "findAll", "url", "getScheme", "getHost", "getHostname",
"getPort", "getEscapedPath", "getQuery", "isURL",
// Kubernetes <1.??>:
}
for _, fn := range knownFunctions {
delete(functionNames, fn)
}
if len(functionNames) != 0 {
t.Errorf("Expected all functions in the libraries to be assigned to a kubernetes release, but found the unassigned function names: %v", functionNames)
}
}

312
pkg/cel/library/lists.go Normal file
View File

@ -0,0 +1,312 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter/functions"
)
// Lists provides a CEL function library extension of list utility functions.
//
// isSorted
//
// Returns true if the provided list of comparable elements is sorted, else returns false.
//
// <list<T>>.isSorted() <bool>, T must be a comparable type
//
// Examples:
//
// [1, 2, 3].isSorted() // return true
// ['a', 'b', 'b', 'c'].isSorted() // return true
// [2.0, 1.0].isSorted() // return false
// [1].isSorted() // return true
// [].isSorted() // return true
//
// sum
//
// Returns the sum of the elements of the provided list. Supports CEL number (int, uint, double) and duration types.
//
// <list<T>>.sum() <T>, T must be a numeric type or a duration
//
// Examples:
//
// [1, 3].sum() // returns 4
// [1.0, 3.0].sum() // returns 4.0
// ['1m', '1s'].sum() // returns '1m1s'
// emptyIntList.sum() // returns 0
// emptyDoubleList.sum() // returns 0.0
// [].sum() // returns 0
//
// min / max
//
// Returns the minimum/maximum valued element of the provided list. Supports all comparable types.
// If the list is empty, an error is returned.
//
// <list<T>>.min() <T>, T must be a comparable type
// <list<T>>.max() <T>, T must be a comparable type
//
// Examples:
//
// [1, 3].min() // returns 1
// [1, 3].max() // returns 3
// [].min() // error
// [1].min() // returns 1
// ([0] + emptyList).min() // returns 0
//
// indexOf / lastIndexOf
//
// Returns either the first or last positional index of the provided element in the list.
// If the element is not found, -1 is returned. Supports all equatable types.
//
// <list<T>>.indexOf(<T>) <int>, T must be an equatable type
// <list<T>>.lastIndexOf(<T>) <int>, T must be an equatable type
//
// Examples:
//
// [1, 2, 2, 3].indexOf(2) // returns 1
// ['a', 'b', 'b', 'c'].lastIndexOf('b') // returns 2
// [1.0].indexOf(1.1) // returns -1
// [].indexOf('string') // returns -1
func Lists() cel.EnvOption {
return cel.Lib(listsLib)
}
var listsLib = &lists{}
type lists struct{}
var paramA = cel.TypeParamType("A")
// CEL typeParams can be used to constraint to a specific trait (e.g. traits.ComparableType) if the 1st operand is the type to constrain.
// But the functions we need to constrain are <list<paramType>>, not just <paramType>.
// Make sure the order of overload set is deterministic
type namedCELType struct {
typeName string
celType *cel.Type
}
var summableTypes = []namedCELType{
{typeName: "int", celType: cel.IntType},
{typeName: "uint", celType: cel.UintType},
{typeName: "double", celType: cel.DoubleType},
{typeName: "duration", celType: cel.DurationType},
}
var zeroValuesOfSummableTypes = map[string]ref.Val{
"int": types.Int(0),
"uint": types.Uint(0),
"double": types.Double(0.0),
"duration": types.Duration{Duration: 0},
}
var comparableTypes = []namedCELType{
{typeName: "int", celType: cel.IntType},
{typeName: "uint", celType: cel.UintType},
{typeName: "double", celType: cel.DoubleType},
{typeName: "bool", celType: cel.BoolType},
{typeName: "duration", celType: cel.DurationType},
{typeName: "timestamp", celType: cel.TimestampType},
{typeName: "string", celType: cel.StringType},
{typeName: "bytes", celType: cel.BytesType},
}
// WARNING: All library additions or modifications must follow
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates
var listsLibraryDecls = map[string][]cel.FunctionOpt{
"isSorted": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
return cel.MemberOverload(fmt.Sprintf("list_%s_is_sorted_bool", name),
[]*cel.Type{cel.ListType(paramType)}, cel.BoolType, cel.UnaryBinding(isSorted))
}),
"sum": templatedOverloads(summableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
return cel.MemberOverload(fmt.Sprintf("list_%s_sum_%s", name, name),
[]*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(func(list ref.Val) ref.Val {
return sum(
func() ref.Val {
return zeroValuesOfSummableTypes[name]
})(list)
}))
}),
"max": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
return cel.MemberOverload(fmt.Sprintf("list_%s_max_%s", name, name),
[]*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(max()))
}),
"min": templatedOverloads(comparableTypes, func(name string, paramType *cel.Type) cel.FunctionOpt {
return cel.MemberOverload(fmt.Sprintf("list_%s_min_%s", name, name),
[]*cel.Type{cel.ListType(paramType)}, paramType, cel.UnaryBinding(min()))
}),
"indexOf": {
cel.MemberOverload("list_a_index_of_int", []*cel.Type{cel.ListType(paramA), paramA}, cel.IntType,
cel.BinaryBinding(indexOf)),
},
"lastIndexOf": {
cel.MemberOverload("list_a_last_index_of_int", []*cel.Type{cel.ListType(paramA), paramA}, cel.IntType,
cel.BinaryBinding(lastIndexOf)),
},
}
func (*lists) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{}
for name, overloads := range listsLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*lists) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func isSorted(val ref.Val) ref.Val {
var prev traits.Comparer
iterable, ok := val.(traits.Iterable)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
for it := iterable.Iterator(); it.HasNext() == types.True; {
next := it.Next()
nextCmp, ok := next.(traits.Comparer)
if !ok {
return types.MaybeNoSuchOverloadErr(next)
}
if prev != nil {
cmp := prev.Compare(next)
if cmp == types.IntOne {
return types.False
}
}
prev = nextCmp
}
return types.True
}
func sum(init func() ref.Val) functions.UnaryOp {
return func(val ref.Val) ref.Val {
i := init()
acc, ok := i.(traits.Adder)
if !ok {
// Should never happen since all passed in init values are valid
return types.MaybeNoSuchOverloadErr(i)
}
iterable, ok := val.(traits.Iterable)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
for it := iterable.Iterator(); it.HasNext() == types.True; {
next := it.Next()
nextAdder, ok := next.(traits.Adder)
if !ok {
// Should never happen for type checked CEL programs
return types.MaybeNoSuchOverloadErr(next)
}
if acc != nil {
s := acc.Add(next)
sum, ok := s.(traits.Adder)
if !ok {
// Should never happen for type checked CEL programs
return types.MaybeNoSuchOverloadErr(s)
}
acc = sum
} else {
acc = nextAdder
}
}
return acc.(ref.Val)
}
}
func min() functions.UnaryOp {
return cmp("min", types.IntOne)
}
func max() functions.UnaryOp {
return cmp("max", types.IntNegOne)
}
func cmp(opName string, opPreferCmpResult ref.Val) functions.UnaryOp {
return func(val ref.Val) ref.Val {
var result traits.Comparer
iterable, ok := val.(traits.Iterable)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
for it := iterable.Iterator(); it.HasNext() == types.True; {
next := it.Next()
nextCmp, ok := next.(traits.Comparer)
if !ok {
// Should never happen for type checked CEL programs
return types.MaybeNoSuchOverloadErr(next)
}
if result == nil {
result = nextCmp
} else {
cmp := result.Compare(next)
if cmp == opPreferCmpResult {
result = nextCmp
}
}
}
if result == nil {
return types.NewErr("%s called on empty list", opName)
}
return result.(ref.Val)
}
}
func indexOf(list ref.Val, item ref.Val) ref.Val {
lister, ok := list.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(list)
}
sz := lister.Size().(types.Int)
for i := types.Int(0); i < sz; i++ {
if lister.Get(types.Int(i)).Equal(item) == types.True {
return types.Int(i)
}
}
return types.Int(-1)
}
func lastIndexOf(list ref.Val, item ref.Val) ref.Val {
lister, ok := list.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(list)
}
sz := lister.Size().(types.Int)
for i := sz - 1; i >= 0; i-- {
if lister.Get(types.Int(i)).Equal(item) == types.True {
return types.Int(i)
}
}
return types.Int(-1)
}
// templatedOverloads returns overloads for each of the provided types. The template function is called with each type
// name (map key) and type to construct the overloads.
func templatedOverloads(types []namedCELType, template func(name string, t *cel.Type) cel.FunctionOpt) []cel.FunctionOpt {
overloads := make([]cel.FunctionOpt, len(types))
i := 0
for _, t := range types {
overloads[i] = template(t.typeName, t.celType)
i++
}
return overloads
}

187
pkg/cel/library/regex.go Normal file
View File

@ -0,0 +1,187 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"regexp"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
)
// Regex provides a CEL function library extension of regex utility functions.
//
// find / findAll
//
// Returns substrings that match the provided regular expression. find returns the first match. findAll may optionally
// be provided a limit. If the limit is set and >= 0, no more than the limit number of matches are returned.
//
// <string>.find(<string>) <string>
// <string>.findAll(<string>) <list <string>>
// <string>.findAll(<string>, <int>) <list <string>>
//
// Examples:
//
// "abc 123".find('[0-9]*') // returns '123'
// "abc 123".find('xyz') // returns ''
// "123 abc 456".findAll('[0-9]*') // returns ['123', '456']
// "123 abc 456".findAll('[0-9]*', 1) // returns ['123']
// "123 abc 456".findAll('xyz') // returns []
func Regex() cel.EnvOption {
return cel.Lib(regexLib)
}
var regexLib = &regex{}
type regex struct{}
var regexLibraryDecls = map[string][]cel.FunctionOpt{
"find": {
cel.MemberOverload("string_find_string", []*cel.Type{cel.StringType, cel.StringType}, cel.StringType,
cel.BinaryBinding(find))},
"findAll": {
cel.MemberOverload("string_find_all_string", []*cel.Type{cel.StringType, cel.StringType},
cel.ListType(cel.StringType),
cel.BinaryBinding(func(str, regex ref.Val) ref.Val {
return findAll(str, regex, types.Int(-1))
})),
cel.MemberOverload("string_find_all_string_int",
[]*cel.Type{cel.StringType, cel.StringType, cel.IntType},
cel.ListType(cel.StringType),
cel.FunctionBinding(findAll)),
},
}
func (*regex) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{}
for name, overloads := range regexLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*regex) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func find(strVal ref.Val, regexVal ref.Val) ref.Val {
str, ok := strVal.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(strVal)
}
regex, ok := regexVal.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(regexVal)
}
re, err := regexp.Compile(regex)
if err != nil {
return types.NewErr("Illegal regex: %v", err.Error())
}
result := re.FindString(str)
return types.String(result)
}
func findAll(args ...ref.Val) ref.Val {
argn := len(args)
if argn < 2 || argn > 3 {
return types.NoSuchOverloadErr()
}
str, ok := args[0].Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(args[0])
}
regex, ok := args[1].Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(args[1])
}
n := int64(-1)
if argn == 3 {
n, ok = args[2].Value().(int64)
if !ok {
return types.MaybeNoSuchOverloadErr(args[2])
}
}
re, err := regexp.Compile(regex)
if err != nil {
return types.NewErr("Illegal regex: %v", err.Error())
}
result := re.FindAllString(str, int(n))
return types.NewStringList(types.DefaultTypeAdapter, result)
}
// FindRegexOptimization optimizes the 'find' function by compiling the regex pattern and
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
// call invocations.
var FindRegexOptimization = &interpreter.RegexOptimization{
Function: "find",
RegexIndex: 1,
Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
compiledRegex, err := regexp.Compile(regexPattern)
if err != nil {
return nil, err
}
return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val {
if len(args) != 2 {
return types.NoSuchOverloadErr()
}
in, ok := args[0].Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(args[0])
}
return types.String(compiledRegex.FindString(in))
}), nil
},
}
// FindAllRegexOptimization optimizes the 'findAll' function by compiling the regex pattern and
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
// call invocations.
var FindAllRegexOptimization = &interpreter.RegexOptimization{
Function: "findAll",
RegexIndex: 1,
Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
compiledRegex, err := regexp.Compile(regexPattern)
if err != nil {
return nil, err
}
return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val {
argn := len(args)
if argn < 2 || argn > 3 {
return types.NoSuchOverloadErr()
}
str, ok := args[0].Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(args[0])
}
n := int64(-1)
if argn == 3 {
n, ok = args[2].Value().(int64)
if !ok {
return types.MaybeNoSuchOverloadErr(args[2])
}
}
result := compiledRegex.FindAllString(str, int(n))
return types.NewStringList(types.DefaultTypeAdapter, result)
}), nil
},
}

236
pkg/cel/library/urls.go Normal file
View File

@ -0,0 +1,236 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"net/url"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
apiservercel "k8s.io/apiserver/pkg/cel"
)
// URLs provides a CEL function library extension of URL parsing functions.
//
// url
//
// Converts a string to a URL or results in an error if the string is not a valid URL. The URL must be an absolute URI
// or an absolute path.
//
// url(<string>) <URL>
//
// Examples:
//
// url('https://user:pass@example.com:80/path?query=val#fragment') // returns a URL
// url('/absolute-path') // returns a URL
// url('https://a:b:c/') // error
// url('../relative-path') // error
//
// isURL
//
// Returns true if a string is a valid URL. The URL must be an absolute URI or an absolute path.
//
// isURL( <string>) <bool>
//
// Examples:
//
// isURL('https://user:pass@example.com:80/path?query=val#fragment') // returns true
// isURL('/absolute-path') // returns true
// isURL('https://a:b:c/') // returns false
// isURL('../relative-path') // returns false
//
// getScheme / getHost / getHostname / getPort / getEscapedPath / getQuery
//
// Return the parsed components of a URL.
//
// - getScheme: If absent in the URL, returns an empty string.
//
// - getHostname: IPv6 addresses are returned with braces, e.g. "[::1]". If absent in the URL, returns an empty string.
//
// - getHost: IPv6 addresses are returned without braces, e.g. "::1". If absent in the URL, returns an empty string.
//
// - getEscapedPath: The string returned by getEscapedPath is URL escaped, e.g. "with space" becomes "with%20space".
// If absent in the URL, returns an empty string.
//
// - getPort: If absent in the URL, returns an empty string.
//
// - getQuery: Returns the query parameters in "matrix" form where a repeated query key is interpreted to
// mean that there are multiple values for that key. The keys and values are returned unescaped.
// If absent in the URL, returns an empty map.
//
// <URL>.getScheme() <string>
// <URL>.getHost() <string>
// <URL>.getHostname() <string>
// <URL>.getPort() <string>
// <URL>.getEscapedPath() <string>
// <URL>.getQuery() <map <string>, <list <string>>
//
// Examples:
//
// url('/path').getScheme() // returns ''
// url('https://example.com/').getScheme() // returns 'https'
// url('https://example.com:80/').getHost() // returns 'example.com:80'
// url('https://example.com/').getHost() // returns 'example.com'
// url('https://[::1]:80/').getHost() // returns '[::1]:80'
// url('https://[::1]/').getHost() // returns '[::1]'
// url('/path').getHost() // returns ''
// url('https://example.com:80/').getHostname() // returns 'example.com'
// url('https://127.0.0.1:80/').getHostname() // returns '127.0.0.1'
// url('https://[::1]:80/').getHostname() // returns '::1'
// url('/path').getHostname() // returns ''
// url('https://example.com:80/').getPort() // returns '80'
// url('https://example.com/').getPort() // returns ''
// url('/path').getPort() // returns ''
// url('https://example.com/path').getEscapedPath() // returns '/path'
// url('https://example.com/path with spaces/').getEscapedPath() // returns '/path%20with%20spaces/'
// url('https://example.com').getEscapedPath() // returns ''
// url('https://example.com/path?k1=a&k2=b&k2=c').getQuery() // returns { 'k1': ['a'], 'k2': ['b', 'c']}
// url('https://example.com/path?key with spaces=value with spaces').getQuery() // returns { 'key with spaces': ['value with spaces']}
// url('https://example.com/path?').getQuery() // returns {}
// url('https://example.com/path').getQuery() // returns {}
func URLs() cel.EnvOption {
return cel.Lib(urlsLib)
}
var urlsLib = &urls{}
type urls struct{}
var urlLibraryDecls = map[string][]cel.FunctionOpt{
"url": {
cel.Overload("string_to_url", []*cel.Type{cel.StringType}, apiservercel.URLType,
cel.UnaryBinding(stringToUrl))},
"getScheme": {
cel.MemberOverload("url_get_scheme", []*cel.Type{apiservercel.URLType}, cel.StringType,
cel.UnaryBinding(getScheme))},
"getHost": {
cel.MemberOverload("url_get_host", []*cel.Type{apiservercel.URLType}, cel.StringType,
cel.UnaryBinding(getHost))},
"getHostname": {
cel.MemberOverload("url_get_hostname", []*cel.Type{apiservercel.URLType}, cel.StringType,
cel.UnaryBinding(getHostname))},
"getPort": {
cel.MemberOverload("url_get_port", []*cel.Type{apiservercel.URLType}, cel.StringType,
cel.UnaryBinding(getPort))},
"getEscapedPath": {
cel.MemberOverload("url_get_escaped_path", []*cel.Type{apiservercel.URLType}, cel.StringType,
cel.UnaryBinding(getEscapedPath))},
"getQuery": {
cel.MemberOverload("url_get_query", []*cel.Type{apiservercel.URLType},
cel.MapType(cel.StringType, cel.ListType(cel.StringType)),
cel.UnaryBinding(getQuery))},
"isURL": {
cel.Overload("is_url_string", []*cel.Type{cel.StringType}, cel.BoolType,
cel.UnaryBinding(isURL))},
}
func (*urls) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{}
for name, overloads := range urlLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*urls) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func stringToUrl(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
// Use ParseRequestURI to check the URL before conversion.
// ParseRequestURI requires absolute URLs and is used by the OpenAPIv3 'uri' format.
_, err := url.ParseRequestURI(s)
if err != nil {
return types.NewErr("URL parse error during conversion from string: %v", err)
}
// We must parse again with Parse since ParseRequestURI incorrectly parses URLs that contain a fragment
// part and will incorrectly append the fragment to either the path or the query, depending on which it was adjacent to.
u, err := url.Parse(s)
if err != nil {
// Errors are not expected here since Parse is a more lenient parser than ParseRequestURI.
return types.NewErr("URL parse error during conversion from string: %v", err)
}
return apiservercel.URL{URL: u}
}
func getScheme(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Scheme)
}
func getHost(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Host)
}
func getHostname(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Hostname())
}
func getPort(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Port())
}
func getEscapedPath(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.EscapedPath())
}
func getQuery(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
result := map[ref.Val]ref.Val{}
for k, v := range u.Query() {
result[types.String(k)] = types.NewStringList(types.DefaultTypeAdapter, v)
}
return types.NewRefValMap(types.DefaultTypeAdapter, result)
}
func isURL(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
_, err := url.ParseRequestURI(s)
return types.Bool(err == nil)
}

48
pkg/cel/limits.go Normal file
View File

@ -0,0 +1,48 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
const (
// DefaultMaxRequestSizeBytes is the size of the largest request that will be accepted
DefaultMaxRequestSizeBytes = int64(3 * 1024 * 1024)
// MaxDurationSizeJSON
// OpenAPI duration strings follow RFC 3339, section 5.6 - see the comment on maxDatetimeSizeJSON
MaxDurationSizeJSON = 32
// MaxDatetimeSizeJSON
// OpenAPI datetime strings follow RFC 3339, section 5.6, and the longest possible
// such string is 9999-12-31T23:59:59.999999999Z, which has length 30 - we add 2
// to allow for quotation marks
MaxDatetimeSizeJSON = 32
// MinDurationSizeJSON
// Golang allows a string of 0 to be parsed as a duration, so that plus 2 to account for
// quotation marks makes 3
MinDurationSizeJSON = 3
// JSONDateSize is the size of a date serialized as part of a JSON object
// RFC 3339 dates require YYYY-MM-DD, and then we add 2 to allow for quotation marks
JSONDateSize = 12
// MinDatetimeSizeJSON is the minimal length of a datetime formatted as RFC 3339
// RFC 3339 datetimes require a full date (YYYY-MM-DD) and full time (HH:MM:SS), and we add 3 for
// quotation marks like always in addition to the capital T that separates the date and time
MinDatetimeSizeJSON = 21
// MinStringSize is the size of literal ""
MinStringSize = 2
// MinBoolSize is the length of literal true
MinBoolSize = 4
// MinNumberSize is the length of literal 0
MinNumberSize = 1
)

View File

@ -0,0 +1,72 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package metrics
import (
"time"
"k8s.io/component-base/metrics"
"k8s.io/component-base/metrics/legacyregistry"
)
// TODO(jiahuif) CEL is to be used in multiple components, revise naming when that happens.
const (
namespace = "apiserver"
subsystem = "cel"
)
// Metrics provides access to CEL metrics.
var Metrics = newCelMetrics()
type CelMetrics struct {
compilationTime *metrics.Histogram
evaluationTime *metrics.Histogram
}
func newCelMetrics() *CelMetrics {
m := &CelMetrics{
compilationTime: metrics.NewHistogram(&metrics.HistogramOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "compilation_duration_seconds",
StabilityLevel: metrics.ALPHA,
}),
evaluationTime: metrics.NewHistogram(&metrics.HistogramOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "evaluation_duration_seconds",
StabilityLevel: metrics.ALPHA,
}),
}
legacyregistry.MustRegister(m.compilationTime)
legacyregistry.MustRegister(m.evaluationTime)
return m
}
// ObserveCompilation records a CEL compilation with the time the compilation took.
func (m *CelMetrics) ObserveCompilation(elapsed time.Duration) {
seconds := elapsed.Seconds()
m.compilationTime.Observe(seconds)
}
// ObserveEvaluation records a CEL evaluation with the time the evaluation took.
func (m *CelMetrics) ObserveEvaluation(elapsed time.Duration) {
seconds := elapsed.Seconds()
m.evaluationTime.Observe(seconds)
}

View File

@ -0,0 +1,68 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package metrics
import (
"math"
"testing"
"time"
"k8s.io/component-base/metrics/legacyregistry"
)
func TestObserveCompilation(t *testing.T) {
defer legacyregistry.Reset()
Metrics.ObserveCompilation(2 * time.Second)
c, s := gatherHistogram(t, "apiserver_cel_compilation_duration_seconds")
if c != 1 {
t.Errorf("unexpected count: %v", c)
}
if math.Abs(s-2.0) > 1e-7 {
t.Fatalf("incorrect sum: %v", s)
}
}
func TestObserveEvaluation(t *testing.T) {
defer legacyregistry.Reset()
Metrics.ObserveEvaluation(2 * time.Second)
c, s := gatherHistogram(t, "apiserver_cel_evaluation_duration_seconds")
if c != 1 {
t.Errorf("unexpected count: %v", c)
}
if math.Abs(s-2.0) > 1e-7 {
t.Fatalf("incorrect sum: %v", s)
}
}
func gatherHistogram(t *testing.T, name string) (count uint64, sum float64) {
metrics, err := legacyregistry.DefaultGatherer.Gather()
if err != nil {
t.Fatalf("Failed to gather metrics: %s", err)
}
for _, mf := range metrics {
if mf.GetName() == name {
for _, m := range mf.GetMetric() {
h := m.GetHistogram()
count += h.GetSampleCount()
sum += h.GetSampleSum()
}
return
}
}
t.Fatalf("metric not found: %v", name)
return 0, 0
}

79
pkg/cel/registry.go Normal file
View File

@ -0,0 +1,79 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"sync"
"github.com/google/cel-go/cel"
)
// Resolver declares methods to find policy templates and related configuration objects.
type Resolver interface {
// FindType returns a DeclType instance corresponding to the given fully-qualified name, if
// present.
FindType(name string) (*DeclType, bool)
}
// NewRegistry create a registry for keeping track of environments and types
// from a base cel.Env expression environment.
func NewRegistry(stdExprEnv *cel.Env) *Registry {
return &Registry{
exprEnvs: map[string]*cel.Env{"": stdExprEnv},
types: map[string]*DeclType{
BoolType.TypeName(): BoolType,
BytesType.TypeName(): BytesType,
DoubleType.TypeName(): DoubleType,
DurationType.TypeName(): DurationType,
IntType.TypeName(): IntType,
NullType.TypeName(): NullType,
StringType.TypeName(): StringType,
TimestampType.TypeName(): TimestampType,
UintType.TypeName(): UintType,
ListType.TypeName(): ListType,
MapType.TypeName(): MapType,
},
}
}
// Registry defines a repository of environment, schema, template, and type definitions.
//
// Registry instances are concurrency-safe.
type Registry struct {
rwMux sync.RWMutex
exprEnvs map[string]*cel.Env
types map[string]*DeclType
}
// FindType implements the Resolver interface method.
func (r *Registry) FindType(name string) (*DeclType, bool) {
r.rwMux.RLock()
defer r.rwMux.RUnlock()
typ, found := r.types[name]
if found {
return typ, true
}
return typ, found
}
// SetType registers a DeclType descriptor by its fully qualified name.
func (r *Registry) SetType(name string, declType *DeclType) error {
r.rwMux.Lock()
defer r.rwMux.Unlock()
r.types[name] = declType
return nil
}

552
pkg/cel/types.go Normal file
View File

@ -0,0 +1,552 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"math"
"time"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/protobuf/proto"
)
const (
noMaxLength = math.MaxInt
)
// NewListType returns a parameterized list type with a specified element type.
func NewListType(elem *DeclType, maxItems int64) *DeclType {
return &DeclType{
name: "list",
ElemType: elem,
MaxElements: maxItems,
celType: cel.ListType(elem.CelType()),
defaultValue: NewListValue(),
// a list can always be represented as [] in JSON, so hardcode the min size
// to 2
MinSerializedSize: 2,
}
}
// NewMapType returns a parameterized map type with the given key and element types.
func NewMapType(key, elem *DeclType, maxProperties int64) *DeclType {
return &DeclType{
name: "map",
KeyType: key,
ElemType: elem,
MaxElements: maxProperties,
celType: cel.MapType(key.CelType(), elem.CelType()),
defaultValue: NewMapValue(),
// a map can always be represented as {} in JSON, so hardcode the min size
// to 2
MinSerializedSize: 2,
}
}
// NewObjectType creates an object type with a qualified name and a set of field declarations.
func NewObjectType(name string, fields map[string]*DeclField) *DeclType {
t := &DeclType{
name: name,
Fields: fields,
celType: cel.ObjectType(name),
traitMask: traits.FieldTesterType | traits.IndexerType,
// an object could potentially be larger than the min size we default to here ({}),
// but we rely upon the caller to change MinSerializedSize accordingly if they add
// properties to the object
MinSerializedSize: 2,
}
t.defaultValue = NewObjectValue(t)
return t
}
func NewSimpleTypeWithMinSize(name string, celType *cel.Type, zeroVal ref.Val, minSize int64) *DeclType {
return &DeclType{
name: name,
celType: celType,
defaultValue: zeroVal,
MinSerializedSize: minSize,
}
}
// DeclType represents the universal type descriptor for OpenAPIv3 types.
type DeclType struct {
fmt.Stringer
name string
// Fields contains a map of escaped CEL identifier field names to field declarations.
Fields map[string]*DeclField
KeyType *DeclType
ElemType *DeclType
TypeParam bool
Metadata map[string]string
MaxElements int64
// MinSerializedSize represents the smallest possible size in bytes that
// the DeclType could be serialized to in JSON.
MinSerializedSize int64
celType *cel.Type
traitMask int
defaultValue ref.Val
}
// MaybeAssignTypeName attempts to set the DeclType name to a fully qualified name, if the type
// is of `object` type.
//
// The DeclType must return true for `IsObject` or this assignment will error.
func (t *DeclType) MaybeAssignTypeName(name string) *DeclType {
if t.IsObject() {
objUpdated := false
if t.name != "object" {
name = t.name
} else {
objUpdated = true
}
fieldMap := make(map[string]*DeclField, len(t.Fields))
for fieldName, field := range t.Fields {
fieldType := field.Type
fieldTypeName := fmt.Sprintf("%s.%s", name, fieldName)
updated := fieldType.MaybeAssignTypeName(fieldTypeName)
if updated == fieldType {
fieldMap[fieldName] = field
continue
}
objUpdated = true
fieldMap[fieldName] = &DeclField{
Name: fieldName,
Type: updated,
Required: field.Required,
enumValues: field.enumValues,
defaultValue: field.defaultValue,
}
}
if !objUpdated {
return t
}
return &DeclType{
name: name,
Fields: fieldMap,
KeyType: t.KeyType,
ElemType: t.ElemType,
TypeParam: t.TypeParam,
Metadata: t.Metadata,
celType: cel.ObjectType(name),
traitMask: t.traitMask,
defaultValue: t.defaultValue,
MinSerializedSize: t.MinSerializedSize,
}
}
if t.IsMap() {
elemTypeName := fmt.Sprintf("%s.@elem", name)
updated := t.ElemType.MaybeAssignTypeName(elemTypeName)
if updated == t.ElemType {
return t
}
return NewMapType(t.KeyType, updated, t.MaxElements)
}
if t.IsList() {
elemTypeName := fmt.Sprintf("%s.@idx", name)
updated := t.ElemType.MaybeAssignTypeName(elemTypeName)
if updated == t.ElemType {
return t
}
return NewListType(updated, t.MaxElements)
}
return t
}
// ExprType returns the CEL expression type of this declaration.
func (t *DeclType) ExprType() (*exprpb.Type, error) {
return cel.TypeToExprType(t.celType)
}
// CelType returns the CEL type of this declaration.
func (t *DeclType) CelType() *cel.Type {
return t.celType
}
// FindField returns the DeclField with the given name if present.
func (t *DeclType) FindField(name string) (*DeclField, bool) {
f, found := t.Fields[name]
return f, found
}
// HasTrait implements the CEL ref.Type interface making this type declaration suitable for use
// within the CEL evaluator.
func (t *DeclType) HasTrait(trait int) bool {
if t.traitMask&trait == trait {
return true
}
if t.defaultValue == nil {
return false
}
_, isDecl := t.defaultValue.Type().(*DeclType)
if isDecl {
return false
}
return t.defaultValue.Type().HasTrait(trait)
}
// IsList returns whether the declaration is a `list` type which defines a parameterized element
// type, but not a parameterized key type or fields.
func (t *DeclType) IsList() bool {
return t.KeyType == nil && t.ElemType != nil && t.Fields == nil
}
// IsMap returns whether the declaration is a 'map' type which defines parameterized key and
// element types, but not fields.
func (t *DeclType) IsMap() bool {
return t.KeyType != nil && t.ElemType != nil && t.Fields == nil
}
// IsObject returns whether the declartion is an 'object' type which defined a set of typed fields.
func (t *DeclType) IsObject() bool {
return t.KeyType == nil && t.ElemType == nil && t.Fields != nil
}
// String implements the fmt.Stringer interface method.
func (t *DeclType) String() string {
return t.name
}
// TypeName returns the fully qualified type name for the DeclType.
func (t *DeclType) TypeName() string {
return t.name
}
// DefaultValue returns the CEL ref.Val representing the default value for this object type,
// if one exists.
func (t *DeclType) DefaultValue() ref.Val {
return t.defaultValue
}
// FieldTypeMap constructs a map of the field and object types nested within a given type.
func FieldTypeMap(path string, t *DeclType) map[string]*DeclType {
if t.IsObject() && t.TypeName() != "object" {
path = t.TypeName()
}
types := make(map[string]*DeclType)
buildDeclTypes(path, t, types)
return types
}
func buildDeclTypes(path string, t *DeclType, types map[string]*DeclType) {
// Ensure object types are properly named according to where they appear in the schema.
if t.IsObject() {
// Hack to ensure that names are uniquely qualified and work well with the type
// resolution steps which require fully qualified type names for field resolution
// to function properly.
types[t.TypeName()] = t
for name, field := range t.Fields {
fieldPath := fmt.Sprintf("%s.%s", path, name)
buildDeclTypes(fieldPath, field.Type, types)
}
}
// Map element properties to type names if needed.
if t.IsMap() {
mapElemPath := fmt.Sprintf("%s.@elem", path)
buildDeclTypes(mapElemPath, t.ElemType, types)
types[path] = t
}
// List element properties.
if t.IsList() {
listIdxPath := fmt.Sprintf("%s.@idx", path)
buildDeclTypes(listIdxPath, t.ElemType, types)
types[path] = t
}
}
// DeclField describes the name, ordinal, and optionality of a field declaration within a type.
type DeclField struct {
Name string
Type *DeclType
Required bool
enumValues []interface{}
defaultValue interface{}
}
func NewDeclField(name string, declType *DeclType, required bool, enumValues []interface{}, defaultValue interface{}) *DeclField {
return &DeclField{
Name: name,
Type: declType,
Required: required,
enumValues: enumValues,
defaultValue: defaultValue,
}
}
// TypeName returns the string type name of the field.
func (f *DeclField) TypeName() string {
return f.Type.TypeName()
}
// DefaultValue returns the zero value associated with the field.
func (f *DeclField) DefaultValue() ref.Val {
if f.defaultValue != nil {
return types.DefaultTypeAdapter.NativeToValue(f.defaultValue)
}
return f.Type.DefaultValue()
}
// EnumValues returns the set of values that this field may take.
func (f *DeclField) EnumValues() []ref.Val {
if f.enumValues == nil || len(f.enumValues) == 0 {
return []ref.Val{}
}
ev := make([]ref.Val, len(f.enumValues))
for i, e := range f.enumValues {
ev[i] = types.DefaultTypeAdapter.NativeToValue(e)
}
return ev
}
// NewRuleTypes returns an Open API Schema-based type-system which is CEL compatible.
func NewRuleTypes(kind string,
declType *DeclType,
res Resolver) (*RuleTypes, error) {
// Note, if the schema indicates that it's actually based on another proto
// then prefer the proto definition. For expressions in the proto, a new field
// annotation will be needed to indicate the expected environment and type of
// the expression.
schemaTypes, err := newSchemaTypeProvider(kind, declType)
if err != nil {
return nil, err
}
if schemaTypes == nil {
return nil, nil
}
return &RuleTypes{
ruleSchemaDeclTypes: schemaTypes,
resolver: res,
}, nil
}
// RuleTypes extends the CEL ref.TypeProvider interface and provides an Open API Schema-based
// type-system.
type RuleTypes struct {
ref.TypeProvider
ruleSchemaDeclTypes *schemaTypeProvider
typeAdapter ref.TypeAdapter
resolver Resolver
}
// EnvOptions returns a set of cel.EnvOption values which includes the declaration set
// as well as a custom ref.TypeProvider.
//
// Note, the standard declaration set includes 'rule' which is defined as the top-level rule-schema
// type if one is configured.
//
// If the RuleTypes value is nil, an empty []cel.EnvOption set is returned.
func (rt *RuleTypes) EnvOptions(tp ref.TypeProvider) ([]cel.EnvOption, error) {
if rt == nil {
return []cel.EnvOption{}, nil
}
var ta ref.TypeAdapter = types.DefaultTypeAdapter
tpa, ok := tp.(ref.TypeAdapter)
if ok {
ta = tpa
}
rtWithTypes := &RuleTypes{
TypeProvider: tp,
typeAdapter: ta,
ruleSchemaDeclTypes: rt.ruleSchemaDeclTypes,
resolver: rt.resolver,
}
for name, declType := range rt.ruleSchemaDeclTypes.types {
tpType, found := tp.FindType(name)
expT, err := declType.ExprType()
if err != nil {
return nil, fmt.Errorf("fail to get cel type: %s", err)
}
if found && !proto.Equal(tpType, expT) {
return nil, fmt.Errorf(
"type %s definition differs between CEL environment and rule", name)
}
}
return []cel.EnvOption{
cel.CustomTypeProvider(rtWithTypes),
cel.CustomTypeAdapter(rtWithTypes),
cel.Variable("rule", rt.ruleSchemaDeclTypes.root.CelType()),
}, nil
}
// FindType attempts to resolve the typeName provided from the rule's rule-schema, or if not
// from the embedded ref.TypeProvider.
//
// FindType overrides the default type-finding behavior of the embedded TypeProvider.
//
// Note, when the type name is based on the Open API Schema, the name will reflect the object path
// where the type definition appears.
func (rt *RuleTypes) FindType(typeName string) (*exprpb.Type, bool) {
if rt == nil {
return nil, false
}
declType, found := rt.findDeclType(typeName)
if found {
expT, err := declType.ExprType()
if err != nil {
return expT, false
}
return expT, found
}
return rt.TypeProvider.FindType(typeName)
}
// FindDeclType returns the CPT type description which can be mapped to a CEL type.
func (rt *RuleTypes) FindDeclType(typeName string) (*DeclType, bool) {
if rt == nil {
return nil, false
}
return rt.findDeclType(typeName)
}
// FindFieldType returns a field type given a type name and field name, if found.
//
// Note, the type name for an Open API Schema type is likely to be its qualified object path.
// If, in the future an object instance rather than a type name were provided, the field
// resolution might more accurately reflect the expected type model. However, in this case
// concessions were made to align with the existing CEL interfaces.
func (rt *RuleTypes) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) {
st, found := rt.findDeclType(typeName)
if !found {
return rt.TypeProvider.FindFieldType(typeName, fieldName)
}
f, found := st.Fields[fieldName]
if found {
ft := f.Type
expT, err := ft.ExprType()
if err != nil {
return nil, false
}
return &ref.FieldType{
Type: expT,
}, true
}
// This could be a dynamic map.
if st.IsMap() {
et := st.ElemType
expT, err := et.ExprType()
if err != nil {
return nil, false
}
return &ref.FieldType{
Type: expT,
}, true
}
return nil, false
}
// NativeToValue is an implementation of the ref.TypeAdapater interface which supports conversion
// of rule values to CEL ref.Val instances.
func (rt *RuleTypes) NativeToValue(val interface{}) ref.Val {
return rt.typeAdapter.NativeToValue(val)
}
// TypeNames returns the list of type names declared within the RuleTypes object.
func (rt *RuleTypes) TypeNames() []string {
typeNames := make([]string, len(rt.ruleSchemaDeclTypes.types))
i := 0
for name := range rt.ruleSchemaDeclTypes.types {
typeNames[i] = name
i++
}
return typeNames
}
func (rt *RuleTypes) findDeclType(typeName string) (*DeclType, bool) {
declType, found := rt.ruleSchemaDeclTypes.types[typeName]
if found {
return declType, true
}
declType, found = rt.resolver.FindType(typeName)
if found {
return declType, true
}
return nil, false
}
func newSchemaTypeProvider(kind string, declType *DeclType) (*schemaTypeProvider, error) {
if declType == nil {
return nil, nil
}
root := declType.MaybeAssignTypeName(kind)
types := FieldTypeMap(kind, root)
return &schemaTypeProvider{
root: root,
types: types,
}, nil
}
type schemaTypeProvider struct {
root *DeclType
types map[string]*DeclType
}
var (
// AnyType is equivalent to the CEL 'protobuf.Any' type in that the value may have any of the
// types supported.
AnyType = NewSimpleTypeWithMinSize("any", cel.AnyType, nil, 1)
// BoolType is equivalent to the CEL 'bool' type.
BoolType = NewSimpleTypeWithMinSize("bool", cel.BoolType, types.False, MinBoolSize)
// BytesType is equivalent to the CEL 'bytes' type.
BytesType = NewSimpleTypeWithMinSize("bytes", cel.BytesType, types.Bytes([]byte{}), MinStringSize)
// DoubleType is equivalent to the CEL 'double' type which is a 64-bit floating point value.
DoubleType = NewSimpleTypeWithMinSize("double", cel.DoubleType, types.Double(0), MinNumberSize)
// DurationType is equivalent to the CEL 'duration' type.
DurationType = NewSimpleTypeWithMinSize("duration", cel.DurationType, types.Duration{Duration: time.Duration(0)}, MinDurationSizeJSON)
// DateType is equivalent to the CEL 'date' type.
DateType = NewSimpleTypeWithMinSize("date", cel.TimestampType, types.Timestamp{Time: time.Time{}}, JSONDateSize)
// DynType is the equivalent of the CEL 'dyn' concept which indicates that the type will be
// determined at runtime rather than compile time.
DynType = NewSimpleTypeWithMinSize("dyn", cel.DynType, nil, 1)
// IntType is equivalent to the CEL 'int' type which is a 64-bit signed int.
IntType = NewSimpleTypeWithMinSize("int", cel.IntType, types.IntZero, MinNumberSize)
// NullType is equivalent to the CEL 'null_type'.
NullType = NewSimpleTypeWithMinSize("null_type", cel.NullType, types.NullValue, 4)
// StringType is equivalent to the CEL 'string' type which is expected to be a UTF-8 string.
// StringType values may either be string literals or expression strings.
StringType = NewSimpleTypeWithMinSize("string", cel.StringType, types.String(""), MinStringSize)
// TimestampType corresponds to the well-known protobuf.Timestamp type supported within CEL.
// Note that both the OpenAPI date and date-time types map onto TimestampType, so not all types
// labeled as Timestamp will necessarily have the same MinSerializedSize.
TimestampType = NewSimpleTypeWithMinSize("timestamp", cel.TimestampType, types.Timestamp{Time: time.Time{}}, JSONDateSize)
// UintType is equivalent to the CEL 'uint' type.
UintType = NewSimpleTypeWithMinSize("uint", cel.UintType, types.Uint(0), 1)
// ListType is equivalent to the CEL 'list' type.
ListType = NewListType(AnyType, noMaxLength)
// MapType is equivalent to the CEL 'map' type.
MapType = NewMapType(AnyType, AnyType, noMaxLength)
)

79
pkg/cel/types_test.go Normal file
View File

@ -0,0 +1,79 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"testing"
)
func TestTypes_ListType(t *testing.T) {
list := NewListType(StringType, -1)
if !list.IsList() {
t.Error("list type not identifiable as list")
}
if list.TypeName() != "list" {
t.Errorf("got %s, wanted list", list.TypeName())
}
if list.DefaultValue() == nil {
t.Error("got nil zero value for list type")
}
if list.ElemType.TypeName() != "string" {
t.Errorf("got %s, wanted elem type of string", list.ElemType.TypeName())
}
expT, err := list.ExprType()
if err != nil {
t.Errorf("fail to get cel type: %s", err)
}
if expT.GetListType() == nil {
t.Errorf("got %v, wanted CEL list type", expT)
}
}
func TestTypes_MapType(t *testing.T) {
mp := NewMapType(StringType, IntType, -1)
if !mp.IsMap() {
t.Error("map type not identifiable as map")
}
if mp.TypeName() != "map" {
t.Errorf("got %s, wanted map", mp.TypeName())
}
if mp.DefaultValue() == nil {
t.Error("got nil zero value for map type")
}
if mp.KeyType.TypeName() != "string" {
t.Errorf("got %s, wanted key type of string", mp.KeyType.TypeName())
}
if mp.ElemType.TypeName() != "int" {
t.Errorf("got %s, wanted elem type of int", mp.ElemType.TypeName())
}
expT, err := mp.ExprType()
if err != nil {
t.Errorf("fail to get cel type: %s", err)
}
if expT.GetMapType() == nil {
t.Errorf("got %v, wanted CEL map type", expT)
}
}
func testValue(t *testing.T, id int64, val interface{}) *DynValue {
t.Helper()
dv, err := NewDynValue(id, val)
if err != nil {
t.Fatalf("NewDynValue(%d, %v) failed: %v", id, val, err)
}
return dv
}

80
pkg/cel/url.go Normal file
View File

@ -0,0 +1,80 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"net/url"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// URL provides a CEL representation of a URL.
type URL struct {
*url.URL
}
var (
URLObject = decls.NewObjectType("kubernetes.URL")
typeValue = types.NewTypeValue("kubernetes.URL")
URLType = cel.ObjectType("kubernetes.URL")
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (d URL) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
if reflect.TypeOf(d.URL).AssignableTo(typeDesc) {
return d.URL, nil
}
if reflect.TypeOf("").AssignableTo(typeDesc) {
return d.URL.String(), nil
}
return nil, fmt.Errorf("type conversion error from 'URL' to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (d URL) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case typeValue:
return d
case types.TypeType:
return typeValue
}
return types.NewErr("type conversion error from '%s' to '%s'", typeValue, typeVal)
}
// Equal implements ref.Val.Equal.
func (d URL) Equal(other ref.Val) ref.Val {
otherDur, ok := other.(URL)
if !ok {
return types.MaybeNoSuchOverloadErr(other)
}
return types.Bool(d.URL.String() == otherDur.URL.String())
}
// Type implements ref.Val.Type.
func (d URL) Type() ref.Type {
return typeValue
}
// Value implements ref.Val.Value.
func (d URL) Value() interface{} {
return d.URL
}

769
pkg/cel/value.go Normal file
View File

@ -0,0 +1,769 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"reflect"
"sync"
"time"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// EncodeStyle is a hint for string encoding of parsed values.
type EncodeStyle int
const (
// BlockValueStyle is the default string encoding which preserves whitespace and newlines.
BlockValueStyle EncodeStyle = iota
// FlowValueStyle indicates that the string is an inline representation of complex types.
FlowValueStyle
// FoldedValueStyle is a multiline string with whitespace and newlines trimmed to a single
// a whitespace. Repeated newlines are replaced with a single newline rather than a single
// whitespace.
FoldedValueStyle
// LiteralStyle is a multiline string that preserves newlines, but trims all other whitespace
// to a single character.
LiteralStyle
)
// NewEmptyDynValue returns the zero-valued DynValue.
func NewEmptyDynValue() *DynValue {
// note: 0 is not a valid parse node identifier.
dv, _ := NewDynValue(0, nil)
return dv
}
// NewDynValue returns a DynValue that corresponds to a parse node id and value.
func NewDynValue(id int64, val interface{}) (*DynValue, error) {
dv := &DynValue{ID: id}
err := dv.SetValue(val)
return dv, err
}
// DynValue is a dynamically typed value used to describe unstructured content.
// Whether the value has the desired type is determined by where it is used within the Instance or
// Template, and whether there are schemas which might enforce a more rigid type definition.
type DynValue struct {
ID int64
EncodeStyle EncodeStyle
value interface{}
exprValue ref.Val
declType *DeclType
}
// DeclType returns the policy model type of the dyn value.
func (dv *DynValue) DeclType() *DeclType {
return dv.declType
}
// ConvertToNative is an implementation of the CEL ref.Val method used to adapt between CEL types
// and Go-native types.
//
// The default behavior of this method is to first convert to a CEL type which has a well-defined
// set of conversion behaviors and proxy to the CEL ConvertToNative method for the type.
func (dv *DynValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
ev := dv.ExprValue()
if types.IsError(ev) {
return nil, ev.(*types.Err)
}
return ev.ConvertToNative(typeDesc)
}
// Equal returns whether the dyn value is equal to a given CEL value.
func (dv *DynValue) Equal(other ref.Val) ref.Val {
dvType := dv.Type()
otherType := other.Type()
// Preserve CEL's homogeneous equality constraint.
if dvType.TypeName() != otherType.TypeName() {
return types.MaybeNoSuchOverloadErr(other)
}
switch v := dv.value.(type) {
case ref.Val:
return v.Equal(other)
case PlainTextValue:
return celBool(string(v) == other.Value().(string))
case *MultilineStringValue:
return celBool(v.Value == other.Value().(string))
case time.Duration:
otherDuration := other.Value().(time.Duration)
return celBool(v == otherDuration)
case time.Time:
otherTimestamp := other.Value().(time.Time)
return celBool(v.Equal(otherTimestamp))
default:
return celBool(reflect.DeepEqual(v, other.Value()))
}
}
// ExprValue converts the DynValue into a CEL value.
func (dv *DynValue) ExprValue() ref.Val {
return dv.exprValue
}
// Value returns the underlying value held by this reference.
func (dv *DynValue) Value() interface{} {
return dv.value
}
// SetValue updates the underlying value held by this reference.
func (dv *DynValue) SetValue(value interface{}) error {
dv.value = value
var err error
dv.exprValue, dv.declType, err = exprValue(value)
return err
}
// Type returns the CEL type for the given value.
func (dv *DynValue) Type() ref.Type {
return dv.ExprValue().Type()
}
func exprValue(value interface{}) (ref.Val, *DeclType, error) {
switch v := value.(type) {
case bool:
return types.Bool(v), BoolType, nil
case []byte:
return types.Bytes(v), BytesType, nil
case float64:
return types.Double(v), DoubleType, nil
case int64:
return types.Int(v), IntType, nil
case string:
return types.String(v), StringType, nil
case uint64:
return types.Uint(v), UintType, nil
case time.Duration:
return types.Duration{Duration: v}, DurationType, nil
case time.Time:
return types.Timestamp{Time: v}, TimestampType, nil
case types.Null:
return v, NullType, nil
case *ListValue:
return v, ListType, nil
case *MapValue:
return v, MapType, nil
case *ObjectValue:
return v, v.objectType, nil
default:
return nil, unknownType, fmt.Errorf("unsupported type: (%T)%v", v, v)
}
}
// PlainTextValue is a text string literal which must not be treated as an expression.
type PlainTextValue string
// MultilineStringValue is a multiline string value which has been parsed in a way which omits
// whitespace as well as a raw form which preserves whitespace.
type MultilineStringValue struct {
Value string
Raw string
}
func newStructValue() *structValue {
return &structValue{
Fields: []*Field{},
fieldMap: map[string]*Field{},
}
}
type structValue struct {
Fields []*Field
fieldMap map[string]*Field
}
// AddField appends a MapField to the MapValue and indexes the field by name.
func (sv *structValue) AddField(field *Field) {
sv.Fields = append(sv.Fields, field)
sv.fieldMap[field.Name] = field
}
// ConvertToNative converts the MapValue type to a native go types.
func (sv *structValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
if typeDesc.Kind() != reflect.Map &&
typeDesc.Kind() != reflect.Struct &&
typeDesc.Kind() != reflect.Pointer &&
typeDesc.Kind() != reflect.Interface {
return nil, fmt.Errorf("type conversion error from object to '%v'", typeDesc)
}
// Unwrap pointers, but track their use.
isPtr := false
if typeDesc.Kind() == reflect.Pointer {
tk := typeDesc
typeDesc = typeDesc.Elem()
if typeDesc.Kind() == reflect.Pointer {
return nil, fmt.Errorf("unsupported type conversion to '%v'", tk)
}
isPtr = true
}
if typeDesc.Kind() == reflect.Map {
keyType := typeDesc.Key()
if keyType.Kind() != reflect.String && keyType.Kind() != reflect.Interface {
return nil, fmt.Errorf("object fields cannot be converted to type '%v'", keyType)
}
elemType := typeDesc.Elem()
sz := len(sv.fieldMap)
ntvMap := reflect.MakeMapWithSize(typeDesc, sz)
for name, val := range sv.fieldMap {
refVal, err := val.Ref.ConvertToNative(elemType)
if err != nil {
return nil, err
}
ntvMap.SetMapIndex(reflect.ValueOf(name), reflect.ValueOf(refVal))
}
return ntvMap.Interface(), nil
}
if typeDesc.Kind() == reflect.Struct {
ntvObjPtr := reflect.New(typeDesc)
ntvObj := ntvObjPtr.Elem()
for name, val := range sv.fieldMap {
f := ntvObj.FieldByName(name)
if !f.IsValid() {
return nil, fmt.Errorf("type conversion error, no such field %s in type %v",
name, typeDesc)
}
fv, err := val.Ref.ConvertToNative(f.Type())
if err != nil {
return nil, err
}
f.Set(reflect.ValueOf(fv))
}
if isPtr {
return ntvObjPtr.Interface(), nil
}
return ntvObj.Interface(), nil
}
return nil, fmt.Errorf("type conversion error from object to '%v'", typeDesc)
}
// GetField returns a MapField by name if one exists.
func (sv *structValue) GetField(name string) (*Field, bool) {
field, found := sv.fieldMap[name]
return field, found
}
// IsSet returns whether the given field, which is defined, has also been set.
func (sv *structValue) IsSet(key ref.Val) ref.Val {
k, ok := key.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(key)
}
name := string(k)
_, found := sv.fieldMap[name]
return celBool(found)
}
// NewObjectValue creates a struct value with a schema type and returns the empty ObjectValue.
func NewObjectValue(sType *DeclType) *ObjectValue {
return &ObjectValue{
structValue: newStructValue(),
objectType: sType,
}
}
// ObjectValue is a struct with a custom schema type which indicates the fields and types
// associated with the structure.
type ObjectValue struct {
*structValue
objectType *DeclType
}
// ConvertToType is an implementation of the CEL ref.Val interface method.
func (o *ObjectValue) ConvertToType(t ref.Type) ref.Val {
if t == types.TypeType {
return types.NewObjectTypeValue(o.objectType.TypeName())
}
if t.TypeName() == o.objectType.TypeName() {
return o
}
return types.NewErr("type conversion error from '%s' to '%s'", o.Type(), t)
}
// Equal returns true if the two object types are equal and their field values are equal.
func (o *ObjectValue) Equal(other ref.Val) ref.Val {
// Preserve CEL's homogeneous equality semantics.
if o.objectType.TypeName() != other.Type().TypeName() {
return types.MaybeNoSuchOverloadErr(other)
}
o2 := other.(traits.Indexer)
for name := range o.objectType.Fields {
k := types.String(name)
v := o.Get(k)
ov := o2.Get(k)
vEq := v.Equal(ov)
if vEq != types.True {
return vEq
}
}
return types.True
}
// Get returns the value of the specified field.
//
// If the field is set, its value is returned. If the field is not set, the default value for the
// field is returned thus allowing for safe-traversal and preserving proto-like field traversal
// semantics for Open API Schema backed types.
func (o *ObjectValue) Get(name ref.Val) ref.Val {
n, ok := name.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(n)
}
nameStr := string(n)
field, found := o.fieldMap[nameStr]
if found {
return field.Ref.ExprValue()
}
fieldDef, found := o.objectType.Fields[nameStr]
if !found {
return types.NewErr("no such field: %s", nameStr)
}
defValue := fieldDef.DefaultValue()
if defValue != nil {
return defValue
}
return types.NewErr("no default for type: %s", fieldDef.TypeName())
}
// Type returns the CEL type value of the object.
func (o *ObjectValue) Type() ref.Type {
return o.objectType
}
// Value returns the Go-native representation of the object.
func (o *ObjectValue) Value() interface{} {
return o
}
// NewMapValue returns an empty MapValue.
func NewMapValue() *MapValue {
return &MapValue{
structValue: newStructValue(),
}
}
// MapValue declares an object with a set of named fields whose values are dynamically typed.
type MapValue struct {
*structValue
}
// ConvertToObject produces an ObjectValue from the MapValue with the associated schema type.
//
// The conversion is shallow and the memory shared between the Object and Map as all references
// to the map are expected to be replaced with the Object reference.
func (m *MapValue) ConvertToObject(declType *DeclType) *ObjectValue {
return &ObjectValue{
structValue: m.structValue,
objectType: declType,
}
}
// Contains returns whether the given key is contained in the MapValue.
func (m *MapValue) Contains(key ref.Val) ref.Val {
v, found := m.Find(key)
if v != nil && types.IsUnknownOrError(v) {
return v
}
return celBool(found)
}
// ConvertToType converts the MapValue to another CEL type, if possible.
func (m *MapValue) ConvertToType(t ref.Type) ref.Val {
switch t {
case types.MapType:
return m
case types.TypeType:
return types.MapType
}
return types.NewErr("type conversion error from '%s' to '%s'", m.Type(), t)
}
// Equal returns true if the maps are of the same size, have the same keys, and the key-values
// from each map are equal.
func (m *MapValue) Equal(other ref.Val) ref.Val {
oMap, isMap := other.(traits.Mapper)
if !isMap {
return types.MaybeNoSuchOverloadErr(other)
}
if m.Size() != oMap.Size() {
return types.False
}
for name, field := range m.fieldMap {
k := types.String(name)
ov, found := oMap.Find(k)
if !found {
return types.False
}
v := field.Ref.ExprValue()
vEq := v.Equal(ov)
if vEq != types.True {
return vEq
}
}
return types.True
}
// Find returns the value for the key in the map, if found.
func (m *MapValue) Find(name ref.Val) (ref.Val, bool) {
// Currently only maps with string keys are supported as this is best aligned with JSON,
// and also much simpler to support.
n, ok := name.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(n), true
}
nameStr := string(n)
field, found := m.fieldMap[nameStr]
if found {
return field.Ref.ExprValue(), true
}
return nil, false
}
// Get returns the value for the key in the map, or error if not found.
func (m *MapValue) Get(key ref.Val) ref.Val {
v, found := m.Find(key)
if found {
return v
}
return types.ValOrErr(key, "no such key: %v", key)
}
// Iterator produces a traits.Iterator which walks over the map keys.
//
// The Iterator is frequently used within comprehensions.
func (m *MapValue) Iterator() traits.Iterator {
keys := make([]ref.Val, len(m.fieldMap))
i := 0
for k := range m.fieldMap {
keys[i] = types.String(k)
i++
}
return &baseMapIterator{
baseVal: &baseVal{},
keys: keys,
}
}
// Size returns the number of keys in the map.
func (m *MapValue) Size() ref.Val {
return types.Int(len(m.Fields))
}
// Type returns the CEL ref.Type for the map.
func (m *MapValue) Type() ref.Type {
return types.MapType
}
// Value returns the Go-native representation of the MapValue.
func (m *MapValue) Value() interface{} {
return m
}
type baseMapIterator struct {
*baseVal
keys []ref.Val
idx int
}
// HasNext implements the traits.Iterator interface method.
func (it *baseMapIterator) HasNext() ref.Val {
if it.idx < len(it.keys) {
return types.True
}
return types.False
}
// Next implements the traits.Iterator interface method.
func (it *baseMapIterator) Next() ref.Val {
key := it.keys[it.idx]
it.idx++
return key
}
// Type implements the CEL ref.Val interface metohd.
func (it *baseMapIterator) Type() ref.Type {
return types.IteratorType
}
// NewField returns a MapField instance with an empty DynValue that refers to the
// specified parse node id and field name.
func NewField(id int64, name string) *Field {
return &Field{
ID: id,
Name: name,
Ref: NewEmptyDynValue(),
}
}
// Field specifies a field name and a reference to a dynamic value.
type Field struct {
ID int64
Name string
Ref *DynValue
}
// NewListValue returns an empty ListValue instance.
func NewListValue() *ListValue {
return &ListValue{
Entries: []*DynValue{},
}
}
// ListValue contains a list of dynamically typed entries.
type ListValue struct {
Entries []*DynValue
initValueSet sync.Once
valueSet map[ref.Val]struct{}
}
// Add concatenates two lists together to produce a new CEL list value.
func (lv *ListValue) Add(other ref.Val) ref.Val {
oArr, isArr := other.(traits.Lister)
if !isArr {
return types.MaybeNoSuchOverloadErr(other)
}
szRight := len(lv.Entries)
szLeft := int(oArr.Size().(types.Int))
sz := szRight + szLeft
combo := make([]ref.Val, sz)
for i := 0; i < szRight; i++ {
combo[i] = lv.Entries[i].ExprValue()
}
for i := 0; i < szLeft; i++ {
combo[i+szRight] = oArr.Get(types.Int(i))
}
return types.DefaultTypeAdapter.NativeToValue(combo)
}
// Append adds another entry into the ListValue.
func (lv *ListValue) Append(entry *DynValue) {
lv.Entries = append(lv.Entries, entry)
// The append resets all previously built indices.
lv.initValueSet = sync.Once{}
}
// Contains returns whether the input `val` is equal to an element in the list.
//
// If any pair-wise comparison between the input value and the list element is an error, the
// operation will return an error.
func (lv *ListValue) Contains(val ref.Val) ref.Val {
if types.IsUnknownOrError(val) {
return val
}
lv.initValueSet.Do(lv.finalizeValueSet)
if lv.valueSet != nil {
_, found := lv.valueSet[val]
if found {
return types.True
}
// Instead of returning false, ensure that CEL's heterogeneous equality constraint
// is satisfied by allowing pair-wise equality behavior to determine the outcome.
}
var err ref.Val
sz := len(lv.Entries)
for i := 0; i < sz; i++ {
elem := lv.Entries[i]
cmp := elem.Equal(val)
b, ok := cmp.(types.Bool)
if !ok && err == nil {
err = types.MaybeNoSuchOverloadErr(cmp)
}
if b == types.True {
return types.True
}
}
if err != nil {
return err
}
return types.False
}
// ConvertToNative is an implementation of the CEL ref.Val method used to adapt between CEL types
// and Go-native array-like types.
func (lv *ListValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
// Non-list conversion.
if typeDesc.Kind() != reflect.Slice &&
typeDesc.Kind() != reflect.Array &&
typeDesc.Kind() != reflect.Interface {
return nil, fmt.Errorf("type conversion error from list to '%v'", typeDesc)
}
// If the list is already assignable to the desired type return it.
if reflect.TypeOf(lv).AssignableTo(typeDesc) {
return lv, nil
}
// List conversion.
otherElem := typeDesc.Elem()
// Allow the element ConvertToNative() function to determine whether conversion is possible.
sz := len(lv.Entries)
nativeList := reflect.MakeSlice(typeDesc, int(sz), int(sz))
for i := 0; i < sz; i++ {
elem := lv.Entries[i]
nativeElemVal, err := elem.ConvertToNative(otherElem)
if err != nil {
return nil, err
}
nativeList.Index(int(i)).Set(reflect.ValueOf(nativeElemVal))
}
return nativeList.Interface(), nil
}
// ConvertToType converts the ListValue to another CEL type.
func (lv *ListValue) ConvertToType(t ref.Type) ref.Val {
switch t {
case types.ListType:
return lv
case types.TypeType:
return types.ListType
}
return types.NewErr("type conversion error from '%s' to '%s'", ListType, t)
}
// Equal returns true if two lists are of the same size, and the values at each index are also
// equal.
func (lv *ListValue) Equal(other ref.Val) ref.Val {
oArr, isArr := other.(traits.Lister)
if !isArr {
return types.MaybeNoSuchOverloadErr(other)
}
sz := types.Int(len(lv.Entries))
if sz != oArr.Size() {
return types.False
}
for i := types.Int(0); i < sz; i++ {
cmp := lv.Get(i).Equal(oArr.Get(i))
if cmp != types.True {
return cmp
}
}
return types.True
}
// Get returns the value at the given index.
//
// If the index is negative or greater than the size of the list, an error is returned.
func (lv *ListValue) Get(idx ref.Val) ref.Val {
iv, isInt := idx.(types.Int)
if !isInt {
return types.ValOrErr(idx, "unsupported index: %v", idx)
}
i := int(iv)
if i < 0 || i >= len(lv.Entries) {
return types.NewErr("index out of bounds: %v", idx)
}
return lv.Entries[i].ExprValue()
}
// Iterator produces a traits.Iterator suitable for use in CEL comprehension macros.
func (lv *ListValue) Iterator() traits.Iterator {
return &baseListIterator{
getter: lv.Get,
sz: len(lv.Entries),
}
}
// Size returns the number of elements in the list.
func (lv *ListValue) Size() ref.Val {
return types.Int(len(lv.Entries))
}
// Type returns the CEL ref.Type for the list.
func (lv *ListValue) Type() ref.Type {
return types.ListType
}
// Value returns the Go-native value.
func (lv *ListValue) Value() interface{} {
return lv
}
// finalizeValueSet inspects the ListValue entries in order to make internal optimizations once all list
// entries are known.
func (lv *ListValue) finalizeValueSet() {
valueSet := make(map[ref.Val]struct{})
for _, e := range lv.Entries {
switch e.value.(type) {
case bool, float64, int64, string, uint64, types.Null, PlainTextValue:
valueSet[e.ExprValue()] = struct{}{}
default:
lv.valueSet = nil
return
}
}
lv.valueSet = valueSet
}
type baseVal struct{}
func (*baseVal) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
return nil, fmt.Errorf("unsupported native conversion to: %v", typeDesc)
}
func (*baseVal) ConvertToType(t ref.Type) ref.Val {
return types.NewErr("unsupported type conversion to: %v", t)
}
func (*baseVal) Equal(other ref.Val) ref.Val {
return types.NewErr("unsupported equality test between instances")
}
func (v *baseVal) Value() interface{} {
return nil
}
type baseListIterator struct {
*baseVal
getter func(idx ref.Val) ref.Val
sz int
idx int
}
func (it *baseListIterator) HasNext() ref.Val {
if it.idx < it.sz {
return types.True
}
return types.False
}
func (it *baseListIterator) Next() ref.Val {
v := it.getter(types.Int(it.idx))
it.idx++
return v
}
func (it *baseListIterator) Type() ref.Type {
return types.IteratorType
}
func celBool(pred bool) ref.Val {
if pred {
return types.True
}
return types.False
}
var unknownType = &DeclType{name: "unknown", MinSerializedSize: 1}

362
pkg/cel/value_test.go Normal file
View File

@ -0,0 +1,362 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
func TestConvertToType(t *testing.T) {
objType := NewObjectType("TestObject", map[string]*DeclField{})
tests := []struct {
val interface{}
typ ref.Type
}{
{true, types.BoolType},
{float64(1.2), types.DoubleType},
{int64(-42), types.IntType},
{uint64(63), types.UintType},
{time.Duration(300), types.DurationType},
{time.Now().UTC(), types.TimestampType},
{types.NullValue, types.NullType},
{NewListValue(), types.ListType},
{NewMapValue(), types.MapType},
{[]byte("bytes"), types.BytesType},
{NewObjectValue(objType), objType},
}
for i, tc := range tests {
idx := i
tst := tc
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
dv := testValue(t, int64(idx), tst.val)
ev := dv.ExprValue()
if ev.ConvertToType(types.TypeType).(ref.Type).TypeName() != tst.typ.TypeName() {
t.Errorf("got %v, wanted %v type", ev.ConvertToType(types.TypeType), tst.typ)
}
if ev.ConvertToType(tst.typ).Equal(ev) != types.True {
t.Errorf("got %v, wanted input value %v", ev.ConvertToType(tst.typ), ev)
}
})
}
}
func TestEqual(t *testing.T) {
vals := []interface{}{
true, []byte("bytes"), float64(1.2), int64(-42), uint64(63), time.Duration(300),
time.Now().UTC(), types.NullValue, NewListValue(), NewMapValue(),
NewObjectValue(NewObjectType("TestObject", map[string]*DeclField{})),
}
for i, v := range vals {
dv := testValue(t, int64(i), v)
if dv.Equal(dv.ExprValue()) != types.True {
t.Errorf("got %v, wanted dyn value %v equal to itself", dv.Equal(dv.ExprValue()), dv.ExprValue())
}
}
}
func TestListValueAdd(t *testing.T) {
lv := NewListValue()
lv.Append(testValue(t, 1, "first"))
ov := NewListValue()
ov.Append(testValue(t, 2, "second"))
ov.Append(testValue(t, 3, "third"))
llv := NewListValue()
llv.Append(testValue(t, 4, lv))
lov := NewListValue()
lov.Append(testValue(t, 5, ov))
var v traits.Lister = llv.Add(lov).(traits.Lister)
if v.Size() != types.Int(2) {
t.Errorf("got list size %d, wanted 2", v.Size())
}
complex, err := v.ConvertToNative(reflect.TypeOf([][]string{}))
complexList := complex.([][]string)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(complexList, [][]string{{"first"}, {"second", "third"}}) {
t.Errorf("got %v, wanted [['first'], ['second', 'third']]", complexList)
}
}
func TestListValueContains(t *testing.T) {
lv := NewListValue()
lv.Append(testValue(t, 1, "first"))
lv.Append(testValue(t, 2, "second"))
lv.Append(testValue(t, 3, "third"))
for i := types.Int(0); i < lv.Size().(types.Int); i++ {
e := lv.Get(i)
contained := lv.Contains(e)
if contained != types.True {
t.Errorf("got %v, wanted list contains elem[%v] %v == true", contained, i, e)
}
}
if lv.Contains(types.String("fourth")) != types.False {
t.Errorf("got %v, wanted false 'fourth'", lv.Contains(types.String("fourth")))
}
if !types.IsError(lv.Contains(types.Int(-1))) {
t.Errorf("got %v, wanted error for invalid type", lv.Contains(types.Int(-1)))
}
}
func TestListValueContainsNestedList(t *testing.T) {
lvA := NewListValue()
lvA.Append(testValue(t, 1, int64(1)))
lvA.Append(testValue(t, 2, int64(2)))
lvB := NewListValue()
lvB.Append(testValue(t, 3, int64(3)))
elemA, elemB := testValue(t, 4, lvA), testValue(t, 5, lvB)
lv := NewListValue()
lv.Append(elemA)
lv.Append(elemB)
contained := lv.Contains(elemA.ExprValue())
if contained != types.True {
t.Errorf("got %v, wanted elemA contained in list value", contained)
}
contained = lv.Contains(elemB.ExprValue())
if contained != types.True {
t.Errorf("got %v, wanted elemB contained in list value", contained)
}
contained = lv.Contains(types.DefaultTypeAdapter.NativeToValue([]int32{4}))
if contained != types.False {
t.Errorf("got %v, wanted empty list not contained", contained)
}
}
func TestListValueConvertToNative(t *testing.T) {
lv := NewListValue()
none, err := lv.ConvertToNative(reflect.TypeOf([]interface{}{}))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(none, []interface{}{}) {
t.Errorf("got %v, wanted empty list", none)
}
lv.Append(testValue(t, 1, "first"))
one, err := lv.ConvertToNative(reflect.TypeOf([]string{}))
oneList := one.([]string)
if err != nil {
t.Fatal(err)
}
if len(oneList) != 1 {
t.Errorf("got len(one) == %d, wanted 1", len(oneList))
}
if !reflect.DeepEqual(oneList, []string{"first"}) {
t.Errorf("got %v, wanted string list", oneList)
}
ov := NewListValue()
ov.Append(testValue(t, 2, "second"))
ov.Append(testValue(t, 3, "third"))
if ov.Size() != types.Int(2) {
t.Errorf("got list size %d, wanted 2", ov.Size())
}
llv := NewListValue()
llv.Append(testValue(t, 4, lv))
llv.Append(testValue(t, 5, ov))
if llv.Size() != types.Int(2) {
t.Errorf("got list size %d, wanted 2", llv.Size())
}
complex, err := llv.ConvertToNative(reflect.TypeOf([][]string{}))
complexList := complex.([][]string)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(complexList, [][]string{{"first"}, {"second", "third"}}) {
t.Errorf("got %v, wanted [['first'], ['second', 'third']]", complexList)
}
}
func TestListValueIterator(t *testing.T) {
lv := NewListValue()
lv.Append(testValue(t, 1, "first"))
lv.Append(testValue(t, 2, "second"))
lv.Append(testValue(t, 3, "third"))
it := lv.Iterator()
if it.Type() != types.IteratorType {
t.Errorf("got type %v for iterator, wanted IteratorType", it.Type())
}
i := types.Int(0)
for it.HasNext() == types.True {
v := it.Next()
if v.Equal(lv.Get(i)) != types.True {
t.Errorf("iterator value %v and value %v at index %d not equal", v, lv.Get(i), i)
}
i++
}
}
func TestMapValueConvertToNative(t *testing.T) {
mv := NewMapValue()
none, err := mv.ConvertToNative(reflect.TypeOf(map[string]interface{}{}))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(none, map[string]interface{}{}) {
t.Errorf("got %v, wanted empty map", none)
}
none, err = mv.ConvertToNative(reflect.TypeOf(map[interface{}]interface{}{}))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(none, map[interface{}]interface{}{}) {
t.Errorf("got %v, wanted empty map", none)
}
mv.AddField(NewField(1, "Test"))
tst, _ := mv.GetField("Test")
tst.Ref = testValue(t, 2, uint64(12))
mv.AddField(NewField(3, "Check"))
chk, _ := mv.GetField("Check")
chk.Ref = testValue(t, 4, uint64(34))
if mv.Size() != types.Int(2) {
t.Errorf("got size %d, wanted 2", mv.Size())
}
if mv.Contains(types.String("Test")) != types.True {
t.Error("key 'Test' not found")
}
if mv.Contains(types.String("Check")) != types.True {
t.Error("key 'Check' not found")
}
if mv.Contains(types.String("Checked")) != types.False {
t.Error("key 'Checked' found, wanted not found")
}
it := mv.Iterator()
for it.HasNext() == types.True {
k := it.Next()
v := mv.Get(k)
if k == types.String("Test") && v != types.Uint(12) {
t.Errorf("key 'Test' not equal to 12u")
}
if k == types.String("Check") && v != types.Uint(34) {
t.Errorf("key 'Check' not equal to 34u")
}
}
mpStrUint, err := mv.ConvertToNative(reflect.TypeOf(map[string]uint64{}))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(mpStrUint, map[string]uint64{
"Test": uint64(12),
"Check": uint64(34),
}) {
t.Errorf("got %v, wanted {'Test': 12u, 'Check': 34u}", mpStrUint)
}
tstStr, err := mv.ConvertToNative(reflect.TypeOf(&tstStruct{}))
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(tstStr, &tstStruct{
Test: uint64(12),
Check: uint64(34),
}) {
t.Errorf("got %v, wanted tstStruct{Test: 12u, Check: 34u}", tstStr)
}
}
func TestMapValueEqual(t *testing.T) {
mv := NewMapValue()
name := NewField(1, "name")
name.Ref = testValue(t, 2, "alert")
priority := NewField(3, "priority")
priority.Ref = testValue(t, 4, int64(4))
mv.AddField(name)
mv.AddField(priority)
if mv.Equal(mv) != types.True {
t.Fatalf("map.Equal(map) failed: %v", mv.Equal(mv))
}
}
func TestMapValueNotEqual(t *testing.T) {
mv := NewMapValue()
name := NewField(1, "name")
name.Ref = testValue(t, 2, "alert")
priority := NewField(3, "priority")
priority.Ref = testValue(t, 4, int64(4))
mv.AddField(name)
mv.AddField(priority)
mv2 := NewMapValue()
mv2.AddField(name)
if mv.Equal(mv2) != types.False {
t.Fatalf("mv.Equal(mv2) failed: %v", mv.Equal(mv2))
}
priority2 := NewField(5, "priority")
priority2.Ref = testValue(t, 6, int64(3))
mv2.AddField(priority2)
if mv.Equal(mv2) != types.False {
t.Fatalf("mv.Equal(mv2) failed: %v", mv.Equal(mv2))
}
}
func TestMapValueIsSet(t *testing.T) {
mv := NewMapValue()
if mv.IsSet(types.String("name")) != types.False {
t.Error("map.IsSet('name') returned true for unset key")
}
mv.AddField(NewField(1, "name"))
if mv.IsSet(types.String("name")) != types.True {
t.Error("map.IsSet('name') returned false for a set key")
}
}
func TestObjectValueEqual(t *testing.T) {
objType := NewObjectType("Notice", map[string]*DeclField{
"name": {Name: "name", Type: StringType},
"priority": {Name: "priority", Type: IntType},
"message": {Name: "message", Type: StringType, defaultValue: "<eom>"},
})
name := NewField(1, "name")
name.Ref = testValue(t, 2, "alert")
priority := NewField(3, "priority")
priority.Ref = testValue(t, 4, int64(4))
message := NewField(5, "message")
message.Ref = testValue(t, 6, "call immediately")
mv1 := NewMapValue()
mv1.AddField(name)
mv1.AddField(priority)
obj1 := mv1.ConvertToObject(objType)
if obj1.Equal(obj1) != types.True {
t.Errorf("obj1.Equal(obj1) failed, got: %v", obj1.Equal(obj1))
}
mv2 := NewMapValue()
mv2.AddField(name)
mv2.AddField(priority)
mv2.AddField(message)
obj2 := mv2.ConvertToObject(objType)
if obj1.Equal(obj2) == types.True {
t.Error("obj1.Equal(obj2) returned true, wanted false")
}
if obj2.Equal(obj1) == types.True {
t.Error("obj2.Equal(obj1) returned true, wanted false")
}
}
type tstStruct struct {
Test uint64
Check uint64
}

View File

@ -366,7 +366,7 @@ func NewConfig(codecs serializer.CodecFactory) *Config {
// A request body might be encoded in json, and is converted to
// proto when persisted in etcd, so we allow 2x as the largest request
// body size to be accepted and decoded in a write request.
// If this constant is changed, maxRequestSizeBytes in apiextensions-apiserver/pkg/apiserver/schema/cel/model/schemas.go
// If this constant is changed, DefaultMaxRequestSizeBytes in k8s.io/apiserver/pkg/cel/limits.go
// should be changed to reflect the new value, if the two haven't
// been wired together already somehow.
MaxRequestBodyBytes: int64(3 * 1024 * 1024),