feat(sdk): support list of artifacts annotations [list of artifacts support pt. 1] (#8464)
* update get_io_artifact_class function * update various type annotation tests * support list of artifacts in component interfaces * move helper function * clarify function name * respond to review feedback
This commit is contained in:
parent
249e7af384
commit
e7c82c0593
|
@ -164,20 +164,23 @@ def extract_component_interface(
|
|||
passing_style = None
|
||||
io_name = parameter.name
|
||||
|
||||
if type_annotations.is_artifact_annotation(parameter_type):
|
||||
if type_annotations.is_Input_Output_artifact_annotation(parameter_type):
|
||||
# passing_style is either type_annotations.InputAnnotation or
|
||||
# type_annotations.OutputAnnotation.
|
||||
passing_style = type_annotations.get_io_artifact_annotation(
|
||||
parameter_type)
|
||||
|
||||
# parameter_type is type_annotations.Artifact or one of its subclasses.
|
||||
# parameter_type is a type like typing_extensions.Annotated[kfp.components.types.artifact_types.Artifact, <class 'kfp.components.types.type_annotations.OutputAnnotation'>] OR typing_extensions.Annotated[typing.List[kfp.components.types.artifact_types.Artifact], <class 'kfp.components.types.type_annotations.OutputAnnotation'>]
|
||||
|
||||
is_artifact_list = type_annotations.is_list_of_artifacts(
|
||||
parameter_type.__origin__)
|
||||
|
||||
parameter_type = type_annotations.get_io_artifact_class(
|
||||
parameter_type)
|
||||
if not type_annotations.is_artifact_class(parameter_type):
|
||||
raise ValueError(
|
||||
'Input[T] and Output[T] are only supported when T is a '
|
||||
'subclass of Artifact. Found `{} with type {}`'.format(
|
||||
io_name, parameter_type))
|
||||
'Input[T] and Output[T] are only supported when T is an artifact or list of artifacts. Found `{} with type {}`'
|
||||
.format(io_name, parameter_type))
|
||||
|
||||
if parameter.default is not inspect.Parameter.empty:
|
||||
raise ValueError(
|
||||
|
@ -212,7 +215,8 @@ def extract_component_interface(
|
|||
schema_version = parameter_type.schema_version
|
||||
output_spec = structures.OutputSpec(
|
||||
type=type_utils.create_bundled_artifact_type(
|
||||
type_struct, schema_version))
|
||||
type_struct, schema_version),
|
||||
is_artifact_list=is_artifact_list)
|
||||
else:
|
||||
output_spec = structures.OutputSpec(type=type_struct)
|
||||
outputs[io_name] = output_spec
|
||||
|
@ -223,7 +227,9 @@ def extract_component_interface(
|
|||
schema_version = parameter_type.schema_version
|
||||
input_spec = structures.InputSpec(
|
||||
type=type_utils.create_bundled_artifact_type(
|
||||
type_struct, schema_version))
|
||||
type_struct, schema_version),
|
||||
is_artifact_list=is_artifact_list,
|
||||
)
|
||||
else:
|
||||
if parameter.default is not inspect.Parameter.empty:
|
||||
input_spec = structures.InputSpec(
|
||||
|
@ -248,14 +254,18 @@ def extract_component_interface(
|
|||
for field_name in return_ann._fields:
|
||||
output_name = _maybe_make_unique(field_name, output_names)
|
||||
output_names.add(output_name)
|
||||
annotation = field_annotations.get(field_name)
|
||||
if type_annotations.is_artifact_class(annotation):
|
||||
type_var = field_annotations.get(field_name)
|
||||
if type_annotations.is_list_of_artifacts(type_var):
|
||||
raise ValueError(
|
||||
f'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `{field_name}`.'
|
||||
)
|
||||
elif type_annotations.is_artifact_class(type_var):
|
||||
output_spec = structures.OutputSpec(
|
||||
type=type_utils.create_bundled_artifact_type(
|
||||
annotation.schema_title, annotation.schema_version))
|
||||
type_var.schema_title, type_var.schema_version))
|
||||
else:
|
||||
type_struct = type_utils._annotation_to_type_struct(
|
||||
annotation)
|
||||
type_var)
|
||||
output_spec = structures.OutputSpec(type=type_struct)
|
||||
outputs[output_name] = output_spec
|
||||
# Deprecated dict-based way of declaring multiple outputs. Was only used by
|
||||
|
@ -278,10 +288,17 @@ def extract_component_interface(
|
|||
# `def func(output_path: OutputPath()) -> str: ...`
|
||||
output_names.add(output_name)
|
||||
return_ann = signature.return_annotation
|
||||
if type_annotations.is_artifact_class(signature.return_annotation):
|
||||
if type_annotations.is_list_of_artifacts(return_ann):
|
||||
artifact_cls = return_ann.__args__[0]
|
||||
output_spec = structures.OutputSpec(
|
||||
type=type_utils.create_bundled_artifact_type(
|
||||
return_ann.schema_title, return_ann.schema_version))
|
||||
artifact_cls.schema_title, artifact_cls.schema_version),
|
||||
is_artifact_list=True)
|
||||
elif type_annotations.is_artifact_class(return_ann):
|
||||
output_spec = structures.OutputSpec(
|
||||
type=type_utils.create_bundled_artifact_type(
|
||||
return_ann.schema_title, return_ann.schema_version),
|
||||
is_artifact_list=False)
|
||||
else:
|
||||
type_struct = type_utils._annotation_to_type_struct(return_ann)
|
||||
output_spec = structures.OutputSpec(type=type_struct)
|
||||
|
|
|
@ -12,9 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
import unittest
|
||||
|
||||
from kfp import dsl
|
||||
from kfp.components import component_factory
|
||||
from kfp.components import structures
|
||||
from kfp.components.component_decorator import component
|
||||
from kfp.components.types.type_annotations import OutputPath
|
||||
|
||||
|
@ -69,5 +72,110 @@ class TestInvalidParameterName(unittest.TestCase):
|
|||
pass
|
||||
|
||||
|
||||
from kfp.components.types.artifact_types import Artifact
|
||||
from kfp.components.types.artifact_types import Model
|
||||
from kfp.dsl import Input
|
||||
from kfp.dsl import Output
|
||||
|
||||
|
||||
class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase):
|
||||
|
||||
def test_python_component_input(self):
|
||||
|
||||
def comp(i: Input[List[Model]]):
|
||||
...
|
||||
|
||||
component_spec = component_factory.extract_component_interface(comp)
|
||||
self.assertEqual(component_spec.name, 'comp')
|
||||
self.assertEqual(component_spec.description, None)
|
||||
self.assertEqual(
|
||||
component_spec.inputs, {
|
||||
'i':
|
||||
structures.InputSpec(
|
||||
type='system.Model@0.0.1',
|
||||
default=None,
|
||||
is_artifact_list=True)
|
||||
})
|
||||
|
||||
def test_custom_container_component_input(self):
|
||||
|
||||
def comp(i: Input[List[Artifact]]):
|
||||
...
|
||||
|
||||
component_spec = component_factory.extract_component_interface(
|
||||
comp, containerized=True)
|
||||
self.assertEqual(component_spec.name, 'comp')
|
||||
self.assertEqual(component_spec.description, None)
|
||||
self.assertEqual(
|
||||
component_spec.inputs, {
|
||||
'i':
|
||||
structures.InputSpec(
|
||||
type='system.Artifact@0.0.1',
|
||||
default=None,
|
||||
is_artifact_list=True)
|
||||
})
|
||||
|
||||
def test_pipeline_input(self):
|
||||
|
||||
def comp(i: Input[List[Model]]):
|
||||
...
|
||||
|
||||
component_spec = component_factory.extract_component_interface(comp)
|
||||
self.assertEqual(component_spec.name, 'comp')
|
||||
self.assertEqual(component_spec.description, None)
|
||||
self.assertEqual(
|
||||
component_spec.inputs, {
|
||||
'i':
|
||||
structures.InputSpec(
|
||||
type='system.Model@0.0.1',
|
||||
default=None,
|
||||
is_artifact_list=True)
|
||||
})
|
||||
|
||||
def test_pipeline_with_named_tuple_fn(self):
|
||||
from typing import NamedTuple
|
||||
|
||||
def comp(
|
||||
i: Input[List[Model]]
|
||||
) -> NamedTuple('outputs', [('output_list', List[Artifact])]):
|
||||
...
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `output_list`\.'
|
||||
):
|
||||
component_factory.extract_component_interface(comp)
|
||||
|
||||
|
||||
class TestOutputListsOfArtifactsTemporarilyBlocked(unittest.TestCase):
|
||||
|
||||
def test_python_component(self):
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r'Output lists of artifacts are not yet supported\.'):
|
||||
|
||||
@dsl.component
|
||||
def comp(output_list: Output[List[Artifact]]):
|
||||
...
|
||||
|
||||
def test_container_component(self):
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r'Output lists of artifacts are not yet supported\.'):
|
||||
|
||||
@dsl.container_component
|
||||
def comp(output_list: Output[List[Artifact]]):
|
||||
return dsl.ContainerSpec(image='alpine')
|
||||
|
||||
def test_pipeline(self):
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
r'Output lists of artifacts are not yet supported\.'):
|
||||
|
||||
@dsl.pipeline
|
||||
def comp() -> List[Artifact]:
|
||||
...
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -302,7 +302,7 @@ class Executor():
|
|||
if value is not None:
|
||||
func_kwargs[k] = value
|
||||
|
||||
elif type_annotations.is_artifact_annotation(v):
|
||||
elif type_annotations.is_Input_Output_artifact_annotation(v):
|
||||
if type_annotations.is_input_artifact(v):
|
||||
func_kwargs[k] = self._get_input_artifact(k)
|
||||
if type_annotations.is_output_artifact(v):
|
||||
|
|
|
@ -44,10 +44,13 @@ class InputSpec:
|
|||
type: The type of the input.
|
||||
default (optional): the default value for the input.
|
||||
optional: Wether the input is optional. An input is optional when it has an explicit default value.
|
||||
is_artifact_list: True if `type` represents a list of the artifact type. Only applies when `type` is an artifact.
|
||||
"""
|
||||
type: Union[str, dict]
|
||||
default: Optional[Any] = None
|
||||
optional: bool = False
|
||||
# This special flag for lists of artifacts allows type to be used the same way for list of artifacts and single artifacts. This is aligned with how IR represents lists of artifacts (same as for single artifacts), as well as simplifies downstream type handling/checking operations in the SDK since we don't need to parse the string `type` to determine if single artifact or list.
|
||||
is_artifact_list: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate_type()
|
||||
|
@ -132,11 +135,16 @@ class OutputSpec:
|
|||
|
||||
Attributes:
|
||||
type: The type of the output.
|
||||
is_artifact_list: True if `type` represents a list of the artifact type. Only applies when `type` is an artifact.
|
||||
"""
|
||||
type: Union[str, dict]
|
||||
# This special flag for lists of artifacts allows type to be used the same way for list of artifacts and single artifacts. This is aligned with how IR represents lists of artifacts (same as for single artifacts), as well as simplifies downstream type handling/checking operations in the SDK since we don't need to parse the string `type` to determine if single artifact or list.
|
||||
is_artifact_list: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate_type()
|
||||
# TODO: remove this method when we support output lists of artifacts
|
||||
self._prevent_using_output_lists_of_artifacts()
|
||||
|
||||
@classmethod
|
||||
def from_ir_component_outputs_dict(
|
||||
|
@ -196,6 +204,11 @@ class OutputSpec:
|
|||
if not spec_type_is_parameter(self.type):
|
||||
type_utils.validate_bundled_artifact_type(self.type)
|
||||
|
||||
def _prevent_using_output_lists_of_artifacts(self):
|
||||
if self.is_artifact_list:
|
||||
raise NotImplementedError(
|
||||
'Output lists of artifacts are not yet supported.')
|
||||
|
||||
|
||||
def spec_type_is_parameter(type_: str) -> bool:
|
||||
in_memory_type = type_annotations.maybe_strip_optional_from_annotation_string(
|
||||
|
|
|
@ -49,7 +49,7 @@ def get_param_to_custom_artifact_class(func: Callable) -> Dict[str, type]:
|
|||
signature = inspect.signature(func)
|
||||
for name, param in signature.parameters.items():
|
||||
annotation = param.annotation
|
||||
if type_annotations.is_artifact_annotation(annotation):
|
||||
if type_annotations.is_Input_Output_artifact_annotation(annotation):
|
||||
artifact_class = type_annotations.get_io_artifact_class(annotation)
|
||||
if artifact_class not in kfp_artifact_classes:
|
||||
param_to_artifact_cls[name] = artifact_class
|
||||
|
|
|
@ -17,8 +17,9 @@ These are only compatible with v2 Pipelines.
|
|||
"""
|
||||
|
||||
import re
|
||||
from typing import Type, TypeVar, Union
|
||||
from typing import List, Type, TypeVar, Union
|
||||
|
||||
from kfp.components.types import artifact_types
|
||||
from kfp.components.types import type_annotations
|
||||
from kfp.components.types import type_utils
|
||||
|
||||
|
@ -120,7 +121,7 @@ class OutputAnnotation:
|
|||
"""Marker type for output artifacts."""
|
||||
|
||||
|
||||
def is_artifact_annotation(typ) -> bool:
|
||||
def is_Input_Output_artifact_annotation(typ) -> bool:
|
||||
if not hasattr(typ, '__metadata__'):
|
||||
return False
|
||||
|
||||
|
@ -132,7 +133,7 @@ def is_artifact_annotation(typ) -> bool:
|
|||
|
||||
def is_input_artifact(typ) -> bool:
|
||||
"""Returns True if typ is of type Input[T]."""
|
||||
if not is_artifact_annotation(typ):
|
||||
if not is_Input_Output_artifact_annotation(typ):
|
||||
return False
|
||||
|
||||
return typ.__metadata__[0] == InputAnnotation
|
||||
|
@ -140,7 +141,7 @@ def is_input_artifact(typ) -> bool:
|
|||
|
||||
def is_output_artifact(typ) -> bool:
|
||||
"""Returns True if typ is of type Output[T]."""
|
||||
if not is_artifact_annotation(typ):
|
||||
if not is_Input_Output_artifact_annotation(typ):
|
||||
return False
|
||||
|
||||
return typ.__metadata__[0] == OutputAnnotation
|
||||
|
@ -149,16 +150,21 @@ def is_output_artifact(typ) -> bool:
|
|||
def get_io_artifact_class(typ):
|
||||
from kfp.dsl import Input
|
||||
from kfp.dsl import Output
|
||||
if not is_artifact_annotation(typ):
|
||||
if not is_Input_Output_artifact_annotation(typ):
|
||||
return None
|
||||
if typ == Input or typ == Output:
|
||||
return None
|
||||
|
||||
return typ.__args__[0]
|
||||
# extract inner type from list of artifacts
|
||||
inner = typ.__args__[0]
|
||||
if hasattr(inner, '__origin__') and inner.__origin__ == list:
|
||||
return inner.__args__[0]
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def get_io_artifact_annotation(typ):
|
||||
if not is_artifact_annotation(typ):
|
||||
if not is_Input_Output_artifact_annotation(typ):
|
||||
return None
|
||||
|
||||
return typ.__metadata__[0]
|
||||
|
@ -223,3 +229,13 @@ def is_artifact_class(artifact_class_or_instance: Type) -> bool:
|
|||
# we do not yet support non-pre-registered custom artifact types with instance_schema attribute
|
||||
return hasattr(artifact_class_or_instance, 'schema_title') and hasattr(
|
||||
artifact_class_or_instance, 'schema_version')
|
||||
|
||||
|
||||
def is_list_of_artifacts(
|
||||
type_var: Union[Type[List[artifact_types.Artifact]],
|
||||
Type[artifact_types.Artifact]]
|
||||
) -> bool:
|
||||
# the type annotation for this function's `type_var` parameter may not actually be a subclass of the KFP SDK's Artifact class for custom artifact types
|
||||
return getattr(type_var, '__origin__',
|
||||
None) == list and type_annotations.is_artifact_class(
|
||||
type_var.__args__[0])
|
||||
|
|
|
@ -30,35 +30,63 @@ from kfp.dsl import Output
|
|||
|
||||
class AnnotationsTest(parameterized.TestCase):
|
||||
|
||||
def test_is_artifact_annotation(self):
|
||||
self.assertTrue(type_annotations.is_artifact_annotation(Input[Model]))
|
||||
self.assertTrue(type_annotations.is_artifact_annotation(Output[Model]))
|
||||
@parameterized.parameters([
|
||||
Input[Model],
|
||||
Output[Model],
|
||||
Output[List[Model]],
|
||||
Output['MyArtifact'],
|
||||
])
|
||||
def test_is_artifact_annotation(self, annotation):
|
||||
self.assertTrue(
|
||||
type_annotations.is_artifact_annotation(Output['MyArtifact']))
|
||||
type_annotations.is_Input_Output_artifact_annotation(annotation))
|
||||
|
||||
self.assertFalse(type_annotations.is_artifact_annotation(Model))
|
||||
self.assertFalse(type_annotations.is_artifact_annotation(int))
|
||||
self.assertFalse(type_annotations.is_artifact_annotation('Dataset'))
|
||||
self.assertFalse(type_annotations.is_artifact_annotation(List[str]))
|
||||
self.assertFalse(type_annotations.is_artifact_annotation(Optional[str]))
|
||||
@parameterized.parameters([
|
||||
Model,
|
||||
int,
|
||||
'Dataset',
|
||||
List[str],
|
||||
Optional[str],
|
||||
])
|
||||
def test_is_not_artifact_annotation(self, annotation):
|
||||
self.assertFalse(
|
||||
type_annotations.is_Input_Output_artifact_annotation(annotation))
|
||||
|
||||
def test_is_input_artifact(self):
|
||||
self.assertTrue(type_annotations.is_input_artifact(Input[Model]))
|
||||
self.assertTrue(type_annotations.is_input_artifact(Input))
|
||||
@parameterized.parameters([
|
||||
Input[Model],
|
||||
Input,
|
||||
])
|
||||
def test_is_input_artifact(self, annotation):
|
||||
self.assertTrue(type_annotations.is_input_artifact(annotation))
|
||||
|
||||
self.assertFalse(type_annotations.is_input_artifact(Output[Model]))
|
||||
self.assertFalse(type_annotations.is_input_artifact(Output))
|
||||
@parameterized.parameters([
|
||||
Output[Model],
|
||||
Output,
|
||||
])
|
||||
def test_is_not_input_artifact(self, annotation):
|
||||
self.assertFalse(type_annotations.is_input_artifact(annotation))
|
||||
|
||||
def test_is_output_artifact(self):
|
||||
self.assertTrue(type_annotations.is_output_artifact(Output[Model]))
|
||||
self.assertTrue(type_annotations.is_output_artifact(Output))
|
||||
@parameterized.parameters([
|
||||
Output[Model],
|
||||
Output[List[Model]],
|
||||
])
|
||||
def test_is_output_artifact(self, annotation):
|
||||
self.assertTrue(type_annotations.is_output_artifact(annotation))
|
||||
|
||||
self.assertFalse(type_annotations.is_output_artifact(Input[Model]))
|
||||
self.assertFalse(type_annotations.is_output_artifact(Input))
|
||||
@parameterized.parameters([
|
||||
Input[Model],
|
||||
Input[List[Model]],
|
||||
Input,
|
||||
])
|
||||
def test_is_not_output_artifact(self, annotation):
|
||||
self.assertFalse(type_annotations.is_output_artifact(annotation))
|
||||
|
||||
def test_get_io_artifact_class(self):
|
||||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_class(Output[Model]), Model)
|
||||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_class(Output[List[Model]]), Model)
|
||||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_class(Input[List[Model]]), Model)
|
||||
|
||||
self.assertEqual(type_annotations.get_io_artifact_class(Input), None)
|
||||
self.assertEqual(type_annotations.get_io_artifact_class(Output), None)
|
||||
|
@ -69,9 +97,15 @@ class AnnotationsTest(parameterized.TestCase):
|
|||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_annotation(Output[Model]),
|
||||
OutputAnnotation)
|
||||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_annotation(Output[List[Model]]),
|
||||
OutputAnnotation)
|
||||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_annotation(Input[Model]),
|
||||
InputAnnotation)
|
||||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_annotation(Input[List[Model]]),
|
||||
InputAnnotation)
|
||||
self.assertEqual(
|
||||
type_annotations.get_io_artifact_annotation(Input), InputAnnotation)
|
||||
self.assertEqual(
|
||||
|
|
Loading…
Reference in New Issue