feat(sdk): add support for lists of artifacts in Python components [support lists of artifacts pt. 3] (#8465)

* remove unused codepath

* rename to avoid confusing variable names

* implement support in executor

* add comment

* make executor code safer
This commit is contained in:
Connor McCarthy 2022-12-15 13:34:21 -08:00 committed by GitHub
parent 750f05cb7d
commit 08b2408d3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 117 additions and 15 deletions

View File

@ -25,31 +25,43 @@ class Executor():
"""Executor executes v2-based Python function components."""
def __init__(self, executor_input: Dict, function_to_execute: Callable):
if hasattr(function_to_execute, 'python_func'):
self._func = function_to_execute.python_func
else:
self._func = function_to_execute
self._input = executor_input
self._input_artifacts: Dict[str, artifact_types.Artifact] = {}
self._input_artifacts: Dict[str,
Union[artifact_types.Artifact,
List[artifact_types.Artifact]]] = {}
self._output_artifacts: Dict[str, artifact_types.Artifact] = {}
for name, artifacts in self._input.get('inputs',
{}).get('artifacts', {}).items():
artifacts_list = artifacts.get('artifacts')
if artifacts_list:
list_of_artifact_proto_structs = artifacts.get('artifacts')
if list_of_artifact_proto_structs:
annotation = self._func.__annotations__[name]
# InputPath has no attribute __origin__ and also should be handled as a single artifact
if type_annotations.is_Input_Output_artifact_annotation(
annotation) and type_annotations.is_list_of_artifacts(
annotation.__origin__):
self._input_artifacts[name] = [
self.make_artifact(
msg,
name,
self._func,
) for msg in list_of_artifact_proto_structs
]
else:
self._input_artifacts[name] = self.make_artifact(
artifacts_list[0],
list_of_artifact_proto_structs[0],
name,
self._func,
)
for name, artifacts in self._input.get('outputs',
{}).get('artifacts', {}).items():
artifacts_list = artifacts.get('artifacts')
if artifacts_list:
list_of_artifact_proto_structs = artifacts.get('artifacts')
if list_of_artifact_proto_structs:
output_artifact = self.make_artifact(
artifacts_list[0],
list_of_artifact_proto_structs[0],
name,
self._func,
)

View File

@ -1078,6 +1078,96 @@ class ExecutorTest(unittest.TestCase):
os.path.exists(
os.path.join(self._test_dir, 'output_metadata.json')))
def test_single_artifact_input(self):
executor_input = """\
{
"inputs": {
"artifacts": {
"input_artifact": {
"artifacts": [
{
"metadata": {},
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact",
"type": {
"schemaTitle": "system.Artifact"
},
"uri": "gs://some-bucket/output/input_artifact"
}
]
}
}
},
"outputs": {
"outputFile": "%(test_dir)s/output_metadata.json"
}
}
"""
def test_func(input_artifact: Input[Artifact]):
self.assertIsInstance(input_artifact, Artifact)
self.assertEqual(
input_artifact.name,
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact'
)
self.assertEqual(
input_artifact.name,
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact'
)
output_metadata = self.execute_and_load_output_metadata(
test_func, executor_input)
self.assertDictEqual(output_metadata, {})
def test_list_of_artifacts_input(self):
executor_input = """\
{
"inputs": {
"artifacts": {
"input_list": {
"artifacts": [
{
"metadata": {},
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0",
"type": {
"schemaTitle": "system.Artifact"
},
"uri": "gs://some-bucket/output/input_list/0"
},
{
"metadata": {},
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1",
"type": {
"schemaTitle": "system.Artifact"
},
"uri": "gs://some-bucket/output/input_list/1"
}
]
}
}
},
"outputs": {
"outputFile": "%(test_dir)s/output_metadata.json"
}
}
"""
def test_func(input_list: Input[List[Artifact]]):
self.assertEqual(len(input_list), 2)
self.assertEqual(
input_list[0].name,
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0'
)
self.assertEqual(
input_list[1].name,
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1'
)
output_metadata = self.execute_and_load_output_metadata(
test_func, executor_input)
self.assertDictEqual(output_metadata, {})
class VertexDataset:
schema_title = 'google.VertexDataset'