From 08b2408d3da7912d78f64fb9a78bc8416586fce0 Mon Sep 17 00:00:00 2001 From: Connor McCarthy Date: Thu, 15 Dec 2022 13:34:21 -0800 Subject: [PATCH] feat(sdk): add support for lists of artifacts in Python components [support lists of artifacts pt. 3] (#8465) * remove unused codepath * rename to avoid confusing variable names * implement support in executor * add comment * make executor code safer --- sdk/python/kfp/components/executor.py | 42 ++++++---- sdk/python/kfp/components/executor_test.py | 90 ++++++++++++++++++++++ 2 files changed, 117 insertions(+), 15 deletions(-) diff --git a/sdk/python/kfp/components/executor.py b/sdk/python/kfp/components/executor.py index d472dcbea7..3736c98118 100644 --- a/sdk/python/kfp/components/executor.py +++ b/sdk/python/kfp/components/executor.py @@ -25,31 +25,43 @@ class Executor(): """Executor executes v2-based Python function components.""" def __init__(self, executor_input: Dict, function_to_execute: Callable): - if hasattr(function_to_execute, 'python_func'): - self._func = function_to_execute.python_func - else: - self._func = function_to_execute + self._func = function_to_execute self._input = executor_input - self._input_artifacts: Dict[str, artifact_types.Artifact] = {} + self._input_artifacts: Dict[str, + Union[artifact_types.Artifact, + List[artifact_types.Artifact]]] = {} self._output_artifacts: Dict[str, artifact_types.Artifact] = {} for name, artifacts in self._input.get('inputs', {}).get('artifacts', {}).items(): - artifacts_list = artifacts.get('artifacts') - if artifacts_list: - self._input_artifacts[name] = self.make_artifact( - artifacts_list[0], - name, - self._func, - ) + list_of_artifact_proto_structs = artifacts.get('artifacts') + if list_of_artifact_proto_structs: + annotation = self._func.__annotations__[name] + # InputPath has no attribute __origin__ and also should be handled as a single artifact + if type_annotations.is_Input_Output_artifact_annotation( + annotation) and type_annotations.is_list_of_artifacts( + annotation.__origin__): + self._input_artifacts[name] = [ + self.make_artifact( + msg, + name, + self._func, + ) for msg in list_of_artifact_proto_structs + ] + else: + self._input_artifacts[name] = self.make_artifact( + list_of_artifact_proto_structs[0], + name, + self._func, + ) for name, artifacts in self._input.get('outputs', {}).get('artifacts', {}).items(): - artifacts_list = artifacts.get('artifacts') - if artifacts_list: + list_of_artifact_proto_structs = artifacts.get('artifacts') + if list_of_artifact_proto_structs: output_artifact = self.make_artifact( - artifacts_list[0], + list_of_artifact_proto_structs[0], name, self._func, ) diff --git a/sdk/python/kfp/components/executor_test.py b/sdk/python/kfp/components/executor_test.py index 835d18d56c..3762644cf9 100644 --- a/sdk/python/kfp/components/executor_test.py +++ b/sdk/python/kfp/components/executor_test.py @@ -1078,6 +1078,96 @@ class ExecutorTest(unittest.TestCase): os.path.exists( os.path.join(self._test_dir, 'output_metadata.json'))) + def test_single_artifact_input(self): + executor_input = """\ + { + "inputs": { + "artifacts": { + "input_artifact": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact", + "type": { + "schemaTitle": "system.Artifact" + }, + "uri": "gs://some-bucket/output/input_artifact" + } + ] + } + } + }, + "outputs": { + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ + + def test_func(input_artifact: Input[Artifact]): + self.assertIsInstance(input_artifact, Artifact) + self.assertEqual( + input_artifact.name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact' + ) + self.assertEqual( + input_artifact.name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact' + ) + + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + + self.assertDictEqual(output_metadata, {}) + + def test_list_of_artifacts_input(self): + executor_input = """\ + { + "inputs": { + "artifacts": { + "input_list": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0", + "type": { + "schemaTitle": "system.Artifact" + }, + "uri": "gs://some-bucket/output/input_list/0" + }, + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1", + "type": { + "schemaTitle": "system.Artifact" + }, + "uri": "gs://some-bucket/output/input_list/1" + } + ] + } + } + }, + "outputs": { + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ + + def test_func(input_list: Input[List[Artifact]]): + self.assertEqual(len(input_list), 2) + self.assertEqual( + input_list[0].name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0' + ) + self.assertEqual( + input_list[1].name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1' + ) + + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + + self.assertDictEqual(output_metadata, {}) + class VertexDataset: schema_title = 'google.VertexDataset'