set _ir_type attribute (#9105)

This commit is contained in:
Connor McCarthy 2023-04-10 12:01:36 -07:00 committed by GitHub
parent 3428ff9a5f
commit e907b6343e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 1 deletions

View File

@ -527,7 +527,10 @@ def make_input_for_parameterized_container_component_function(
return placeholders.OutputParameterPlaceholder(name)
else:
return placeholders.InputValuePlaceholder(name)
placeholder = placeholders.InputValuePlaceholder(name)
# small hack to encode the runtime value's type for a custom json.dumps function
placeholder._ir_type = type_utils.get_parameter_type_name(annotation)
return placeholder
def create_container_component_from_func(

View File

@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
import unittest
from kfp import dsl
from kfp.components import container_component
from kfp.dsl import Artifact
from kfp.dsl import Input
from kfp.dsl import Output
class TestContainerComponentDecorator(unittest.TestCase):
@ -75,3 +79,54 @@ class TestContainerComponentDecorator(unittest.TestCase):
self.assertIsInstance(container_comp_with_artifacts,
container_component.ContainerComponent)
class TestInputValuePlaceholderIrTypeHack(unittest.TestCase):
def test(self):
@dsl.container_component
def comp(
in_artifact: Input[Artifact],
out_artifact: Output[Artifact],
string: str = 'hello',
integer: int = 1,
floating_pt: float = 0.1,
boolean: bool = True,
dictionary: Dict = {'key': 'value'},
array: List = [1, 2, 3],
hlist: List = [
{
'k': 'v'
},
1,
['a'],
'a',
],
):
self.assertEqual(string._ir_type, 'STRING')
self.assertEqual(integer._ir_type, 'NUMBER_INTEGER')
self.assertEqual(floating_pt._ir_type, 'NUMBER_DOUBLE')
self.assertEqual(boolean._ir_type, 'BOOLEAN')
self.assertEqual(dictionary._ir_type, 'STRUCT')
self.assertEqual(array._ir_type, 'LIST')
self.assertEqual(hlist._ir_type, 'LIST')
self.assertFalse(hasattr(in_artifact, '_ir_type'))
self.assertFalse(hasattr(out_artifact, '_ir_type'))
return dsl.ContainerSpec(
image='alpine',
command=[
'echo',
],
args=[
string,
integer,
floating_pt,
boolean,
dictionary,
array,
hlist,
in_artifact.path,
out_artifact.path,
],
)