From 90b7a4b0963091a6ca4d991735c667ad5ddbbd54 Mon Sep 17 00:00:00 2001 From: yingjinhui Date: Mon, 14 Nov 2022 19:56:05 +0800 Subject: [PATCH] improve decodeValue for Lua Signed-off-by: yingjinhui --- .../configurableinterpreter/luavm/lua.go | 129 +++++++++---- .../configurableinterpreter/luavm/lua_test.go | 175 +++++++++++++++++- 2 files changed, 263 insertions(+), 41 deletions(-) diff --git a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go index 1d38e2c57..d2f09a085 100644 --- a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go +++ b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "reflect" "time" lua "github.com/yuin/gopher-lua" @@ -50,7 +51,10 @@ func (vm VM) GetReplicas(obj *unstructured.Unstructured, script string) (replica } args := make([]lua.LValue, 1) - args[0] = decodeValue(l, obj.Object) + args[0], err = decodeValue(l, obj.Object) + if err != nil { + return + } err = l.CallByParam(lua.P{Fn: f, NRet: 2, Protect: true}, args...) if err != nil { return 0, nil, err @@ -108,8 +112,14 @@ func (vm VM) ReviseReplica(object *unstructured.Unstructured, replica int64, scr } args := make([]lua.LValue, 2) - args[0] = decodeValue(l, object.Object) - args[1] = decodeValue(l, replica) + args[0], err = decodeValue(l, object.Object) + if err != nil { + return nil, err + } + args[1], err = decodeValue(l, replica) + if err != nil { + return nil, err + } err = l.CallByParam(lua.P{Fn: reviseReplicaLuaFunc, NRet: 1, Protect: true}, args...) if err != nil { return nil, err @@ -178,8 +188,14 @@ func (vm VM) Retain(desired *unstructured.Unstructured, observed *unstructured.U } args := make([]lua.LValue, 2) - args[0] = decodeValue(l, desired.Object) - args[1] = decodeValue(l, observed.Object) + args[0], err = decodeValue(l, desired.Object) + if err != nil { + return + } + args[1], err = decodeValue(l, observed.Object) + if err != nil { + return + } err = l.CallByParam(lua.P{Fn: retainLuaFunc, NRet: 1, Protect: true}, args...) if err != nil { return nil, err @@ -225,8 +241,14 @@ func (vm VM) AggregateStatus(object *unstructured.Unstructured, item []map[strin return nil, fmt.Errorf("can't get function AggregateStatus pleace check the function ") } args := make([]lua.LValue, 2) - args[0] = decodeValue(l, object.Object) - args[1] = decodeValue(l, item) + args[0], err = decodeValue(l, object.Object) + if err != nil { + return nil, err + } + args[1], err = decodeValue(l, item) + if err != nil { + return nil, err + } err = l.CallByParam(lua.P{Fn: f, NRet: 1, Protect: true}, args...) if err != nil { return nil, err @@ -272,7 +294,10 @@ func (vm VM) InterpretHealth(object *unstructured.Unstructured, script string) ( } args := make([]lua.LValue, 1) - args[0] = decodeValue(l, object.Object) + args[0], err = decodeValue(l, object.Object) + if err != nil { + return false, err + } err = l.CallByParam(lua.P{Fn: f, NRet: 1, Protect: true}, args...) if err != nil { return false, err @@ -315,7 +340,10 @@ func (vm VM) ReflectStatus(object *unstructured.Unstructured, script string) (st } args := make([]lua.LValue, 1) - args[0] = decodeValue(l, object.Object) + args[0], err = decodeValue(l, object.Object) + if err != nil { + return + } err = l.CallByParam(lua.P{Fn: f, NRet: 2, Protect: true}, args...) if err != nil { return nil, err @@ -376,7 +404,10 @@ func (vm VM) GetDependencies(object *unstructured.Unstructured, script string) ( } args := make([]lua.LValue, 1) - args[0] = decodeValue(l, object.Object) + args[0], err = decodeValue(l, object.Object) + if err != nil { + return + } err = l.CallByParam(lua.P{Fn: f, NRet: 1, Protect: true}, args...) if err != nil { return nil, err @@ -398,46 +429,64 @@ func (vm VM) GetDependencies(object *unstructured.Unstructured, script string) ( return } -// Took logic from the link below and added the int, int32, and int64 types since the value would have type int64 -// while actually running in the controller and it was not reproducible through testing. -// https://github.com/layeh/gopher-json/blob/97fed8db84274c421dbfffbb28ec859901556b97/json.go#L154 -func decodeValue(L *lua.LState, value interface{}) lua.LValue { +// nolint:gocyclo +func decodeValue(L *lua.LState, value interface{}) (lua.LValue, error) { + // We handle simple type without json for better performance. switch converted := value.(type) { - case bool: - return lua.LBool(converted) - case float64: - return lua.LNumber(converted) - case string: - return lua.LString(converted) - case json.Number: - return lua.LString(converted) - case int: - return lua.LNumber(converted) - case int32: - return lua.LNumber(converted) - case int64: - return lua.LNumber(converted) case []interface{}: arr := L.CreateTable(len(converted), 0) for _, item := range converted { - arr.Append(decodeValue(L, item)) + v, err := decodeValue(L, item) + if err != nil { + return nil, err + } + arr.Append(v) } - return arr - case []map[string]interface{}: - arr := L.CreateTable(len(converted), 0) - for _, item := range converted { - arr.Append(decodeValue(L, item)) - } - return arr + return arr, nil case map[string]interface{}: tbl := L.CreateTable(0, len(converted)) for key, item := range converted { - tbl.RawSetH(lua.LString(key), decodeValue(L, item)) + v, err := decodeValue(L, item) + if err != nil { + return nil, err + } + tbl.RawSetString(key, v) } - return tbl + return tbl, nil case nil: - return lua.LNil + return lua.LNil, nil } - return lua.LNil + v := reflect.ValueOf(value) + switch { + case v.CanInt(): + return lua.LNumber(v.Int()), nil + case v.CanUint(): + return lua.LNumber(v.Uint()), nil + case v.CanFloat(): + return lua.LNumber(v.Float()), nil + } + + switch t := v.Type(); t.Kind() { + case reflect.String: + return lua.LString(v.String()), nil + case reflect.Bool: + return lua.LBool(v.Bool()), nil + case reflect.Pointer: + if v.IsNil() { + return lua.LNil, nil + } + } + + // Other types can't be handled, ask for help from json + data, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("json Marshal obj %#v error: %v", value, err) + } + + lv, err := luajson.Decode(L, data) + if err != nil { + return nil, fmt.Errorf("lua Decode obj %#v error: %v", value, err) + } + return lv, nil } diff --git a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go index ade2458a5..54a2e86ff 100644 --- a/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go +++ b/pkg/resourceinterpreter/configurableinterpreter/luavm/lua_test.go @@ -5,12 +5,14 @@ import ( "reflect" "testing" + lua "github.com/yuin/gopher-lua" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" "k8s.io/klog/v2" + "k8s.io/utils/pointer" workv1alpha2 "github.com/karmada-io/karmada/pkg/apis/work/v1alpha2" "github.com/karmada-io/karmada/pkg/util/helper" @@ -18,7 +20,7 @@ import ( func TestGetReplicas(t *testing.T) { var replicas int32 = 1 - //quantity := *resource.NewQuantity(1000, resource.BinarySI) + // quantity := *resource.NewQuantity(1000, resource.BinarySI) vm := VM{UseOpenLibs: false} tests := []struct { name string @@ -474,3 +476,174 @@ func TestGetDeployPodDependencies(t *testing.T) { t.Logf("res %v", res) } } + +func Test_decodeValue(t *testing.T) { + L := lua.NewState() + + type args struct { + value interface{} + } + tests := []struct { + name string + args args + want lua.LValue + wantErr bool + }{ + { + name: "nil", + args: args{ + value: nil, + }, + want: lua.LNil, + }, + { + name: "nil pointer", + args: args{ + value: (*struct{})(nil), + }, + want: lua.LNil, + }, + { + name: "int pointer", + args: args{ + value: pointer.Int(1), + }, + want: lua.LNumber(1), + }, + { + name: "int", + args: args{ + value: 1, + }, + want: lua.LNumber(1), + }, + { + name: "uint", + args: args{ + value: uint(1), + }, + want: lua.LNumber(1), + }, + { + name: "float", + args: args{ + value: 1.0, + }, + want: lua.LNumber(1), + }, + { + name: "bool", + args: args{ + value: true, + }, + want: lua.LBool(true), + }, + { + name: "string", + args: args{ + value: "foo", + }, + want: lua.LString("foo"), + }, + { + name: "json number", + args: args{ + value: json.Number("1"), + }, + want: lua.LString("1"), + }, + { + name: "slice", + args: args{ + value: []string{"foo", "bar"}, + }, + want: func() lua.LValue { + v := L.CreateTable(2, 0) + v.Append(lua.LString("foo")) + v.Append(lua.LString("bar")) + return v + }(), + }, + { + name: "slice pointer", + args: args{ + value: &[]string{"foo", "bar"}, + }, + want: func() lua.LValue { + v := L.CreateTable(2, 0) + v.Append(lua.LString("foo")) + v.Append(lua.LString("bar")) + return v + }(), + }, + { + name: "struct", + args: args{ + value: struct { + Foo string + }{ + Foo: "foo", + }, + }, + want: func() lua.LValue { + v := L.CreateTable(0, 1) + v.RawSetString("Foo", lua.LString("foo")) + return v + }(), + }, + { + name: "struct pointer", + args: args{ + value: &struct { + Foo string + }{ + Foo: "foo", + }, + }, + want: func() lua.LValue { + v := L.CreateTable(0, 1) + v.RawSetString("Foo", lua.LString("foo")) + return v + }(), + }, + { + name: "[]interface{}", + args: args{ + value: []interface{}{1, 2}, + }, + want: func() lua.LValue { + v := L.CreateTable(2, 0) + v.Append(lua.LNumber(1)) + v.Append(lua.LNumber(2)) + return v + }(), + }, + { + name: "map[string]interface{}", + args: args{ + value: map[string]interface{}{ + "foo": "foo1", + "bar": "bar1", + }, + }, + want: func() lua.LValue { + v := L.CreateTable(0, 2) + v.RawSetString("foo", lua.LString("foo1")) + v.RawSetString("bar", lua.LString("bar1")) + return v + }(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := decodeValue(L, tt.args.value) + if (err != nil) != tt.wantErr { + t.Errorf("decodeValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("decodeValue() got = %v, want %v", got, tt.want) + } + }) + } +}