From e907b6343ed8453c1d060182a56ce524752aecf6 Mon Sep 17 00:00:00 2001 From: Connor McCarthy Date: Mon, 10 Apr 2023 12:01:36 -0700 Subject: [PATCH] set _ir_type attribute (#9105) --- .../kfp/components/component_factory.py | 5 +- .../container_component_decorator_test.py | 55 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/sdk/python/kfp/components/component_factory.py b/sdk/python/kfp/components/component_factory.py index b914898bd2..bf7ed2763f 100644 --- a/sdk/python/kfp/components/component_factory.py +++ b/sdk/python/kfp/components/component_factory.py @@ -527,7 +527,10 @@ def make_input_for_parameterized_container_component_function( return placeholders.OutputParameterPlaceholder(name) else: - return placeholders.InputValuePlaceholder(name) + placeholder = placeholders.InputValuePlaceholder(name) + # small hack to encode the runtime value's type for a custom json.dumps function + placeholder._ir_type = type_utils.get_parameter_type_name(annotation) + return placeholder def create_container_component_from_func( diff --git a/sdk/python/kfp/components/container_component_decorator_test.py b/sdk/python/kfp/components/container_component_decorator_test.py index 12f8c8a722..5aafa73a8c 100644 --- a/sdk/python/kfp/components/container_component_decorator_test.py +++ b/sdk/python/kfp/components/container_component_decorator_test.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List import unittest from kfp import dsl from kfp.components import container_component +from kfp.dsl import Artifact +from kfp.dsl import Input +from kfp.dsl import Output class TestContainerComponentDecorator(unittest.TestCase): @@ -75,3 +79,54 @@ class TestContainerComponentDecorator(unittest.TestCase): self.assertIsInstance(container_comp_with_artifacts, container_component.ContainerComponent) + + +class TestInputValuePlaceholderIrTypeHack(unittest.TestCase): + + def test(self): + + @dsl.container_component + def comp( + in_artifact: Input[Artifact], + out_artifact: Output[Artifact], + string: str = 'hello', + integer: int = 1, + floating_pt: float = 0.1, + boolean: bool = True, + dictionary: Dict = {'key': 'value'}, + array: List = [1, 2, 3], + hlist: List = [ + { + 'k': 'v' + }, + 1, + ['a'], + 'a', + ], + ): + self.assertEqual(string._ir_type, 'STRING') + self.assertEqual(integer._ir_type, 'NUMBER_INTEGER') + self.assertEqual(floating_pt._ir_type, 'NUMBER_DOUBLE') + self.assertEqual(boolean._ir_type, 'BOOLEAN') + self.assertEqual(dictionary._ir_type, 'STRUCT') + self.assertEqual(array._ir_type, 'LIST') + self.assertEqual(hlist._ir_type, 'LIST') + self.assertFalse(hasattr(in_artifact, '_ir_type')) + self.assertFalse(hasattr(out_artifact, '_ir_type')) + return dsl.ContainerSpec( + image='alpine', + command=[ + 'echo', + ], + args=[ + string, + integer, + floating_pt, + boolean, + dictionary, + array, + hlist, + in_artifact.path, + out_artifact.path, + ], + )