feat(sdk): add support for lists of artifacts in Python components [support lists of artifacts pt. 3] (#8465)
* remove unused codepath * rename to avoid confusing variable names * implement support in executor * add comment * make executor code safer
This commit is contained in:
parent
750f05cb7d
commit
08b2408d3d
|
@ -25,31 +25,43 @@ class Executor():
|
|||
"""Executor executes v2-based Python function components."""
|
||||
|
||||
def __init__(self, executor_input: Dict, function_to_execute: Callable):
|
||||
if hasattr(function_to_execute, 'python_func'):
|
||||
self._func = function_to_execute.python_func
|
||||
else:
|
||||
self._func = function_to_execute
|
||||
|
||||
self._input = executor_input
|
||||
self._input_artifacts: Dict[str, artifact_types.Artifact] = {}
|
||||
self._input_artifacts: Dict[str,
|
||||
Union[artifact_types.Artifact,
|
||||
List[artifact_types.Artifact]]] = {}
|
||||
self._output_artifacts: Dict[str, artifact_types.Artifact] = {}
|
||||
|
||||
for name, artifacts in self._input.get('inputs',
|
||||
{}).get('artifacts', {}).items():
|
||||
artifacts_list = artifacts.get('artifacts')
|
||||
if artifacts_list:
|
||||
list_of_artifact_proto_structs = artifacts.get('artifacts')
|
||||
if list_of_artifact_proto_structs:
|
||||
annotation = self._func.__annotations__[name]
|
||||
# InputPath has no attribute __origin__ and also should be handled as a single artifact
|
||||
if type_annotations.is_Input_Output_artifact_annotation(
|
||||
annotation) and type_annotations.is_list_of_artifacts(
|
||||
annotation.__origin__):
|
||||
self._input_artifacts[name] = [
|
||||
self.make_artifact(
|
||||
msg,
|
||||
name,
|
||||
self._func,
|
||||
) for msg in list_of_artifact_proto_structs
|
||||
]
|
||||
else:
|
||||
self._input_artifacts[name] = self.make_artifact(
|
||||
artifacts_list[0],
|
||||
list_of_artifact_proto_structs[0],
|
||||
name,
|
||||
self._func,
|
||||
)
|
||||
|
||||
for name, artifacts in self._input.get('outputs',
|
||||
{}).get('artifacts', {}).items():
|
||||
artifacts_list = artifacts.get('artifacts')
|
||||
if artifacts_list:
|
||||
list_of_artifact_proto_structs = artifacts.get('artifacts')
|
||||
if list_of_artifact_proto_structs:
|
||||
output_artifact = self.make_artifact(
|
||||
artifacts_list[0],
|
||||
list_of_artifact_proto_structs[0],
|
||||
name,
|
||||
self._func,
|
||||
)
|
||||
|
|
|
@ -1078,6 +1078,96 @@ class ExecutorTest(unittest.TestCase):
|
|||
os.path.exists(
|
||||
os.path.join(self._test_dir, 'output_metadata.json')))
|
||||
|
||||
def test_single_artifact_input(self):
|
||||
executor_input = """\
|
||||
{
|
||||
"inputs": {
|
||||
"artifacts": {
|
||||
"input_artifact": {
|
||||
"artifacts": [
|
||||
{
|
||||
"metadata": {},
|
||||
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact",
|
||||
"type": {
|
||||
"schemaTitle": "system.Artifact"
|
||||
},
|
||||
"uri": "gs://some-bucket/output/input_artifact"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"outputFile": "%(test_dir)s/output_metadata.json"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def test_func(input_artifact: Input[Artifact]):
|
||||
self.assertIsInstance(input_artifact, Artifact)
|
||||
self.assertEqual(
|
||||
input_artifact.name,
|
||||
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact'
|
||||
)
|
||||
self.assertEqual(
|
||||
input_artifact.name,
|
||||
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_artifact'
|
||||
)
|
||||
|
||||
output_metadata = self.execute_and_load_output_metadata(
|
||||
test_func, executor_input)
|
||||
|
||||
self.assertDictEqual(output_metadata, {})
|
||||
|
||||
def test_list_of_artifacts_input(self):
|
||||
executor_input = """\
|
||||
{
|
||||
"inputs": {
|
||||
"artifacts": {
|
||||
"input_list": {
|
||||
"artifacts": [
|
||||
{
|
||||
"metadata": {},
|
||||
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0",
|
||||
"type": {
|
||||
"schemaTitle": "system.Artifact"
|
||||
},
|
||||
"uri": "gs://some-bucket/output/input_list/0"
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1",
|
||||
"type": {
|
||||
"schemaTitle": "system.Artifact"
|
||||
},
|
||||
"uri": "gs://some-bucket/output/input_list/1"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"outputFile": "%(test_dir)s/output_metadata.json"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def test_func(input_list: Input[List[Artifact]]):
|
||||
self.assertEqual(len(input_list), 2)
|
||||
self.assertEqual(
|
||||
input_list[0].name,
|
||||
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0'
|
||||
)
|
||||
self.assertEqual(
|
||||
input_list[1].name,
|
||||
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1'
|
||||
)
|
||||
|
||||
output_metadata = self.execute_and_load_output_metadata(
|
||||
test_func, executor_input)
|
||||
|
||||
self.assertDictEqual(output_metadata, {})
|
||||
|
||||
|
||||
class VertexDataset:
|
||||
schema_title = 'google.VertexDataset'
|
||||
|
|
Loading…
Reference in New Issue