feat(sdk): support list of artifacts annotations [list of artifacts support pt. 1] (#8464)

* update get_io_artifact_class function

* update various type annotation tests

* support list of artifacts in component interfaces

* move helper function

* clarify function name

* respond to review feedback
This commit is contained in:
Connor McCarthy 2022-11-29 21:24:15 -08:00 committed by GitHub
parent 249e7af384
commit e7c82c0593
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 229 additions and 41 deletions

View File

@ -164,20 +164,23 @@ def extract_component_interface(
passing_style = None
io_name = parameter.name
if type_annotations.is_artifact_annotation(parameter_type):
if type_annotations.is_Input_Output_artifact_annotation(parameter_type):
# passing_style is either type_annotations.InputAnnotation or
# type_annotations.OutputAnnotation.
passing_style = type_annotations.get_io_artifact_annotation(
parameter_type)
# parameter_type is type_annotations.Artifact or one of its subclasses.
# parameter_type is a type like typing_extensions.Annotated[kfp.components.types.artifact_types.Artifact, <class 'kfp.components.types.type_annotations.OutputAnnotation'>] OR typing_extensions.Annotated[typing.List[kfp.components.types.artifact_types.Artifact], <class 'kfp.components.types.type_annotations.OutputAnnotation'>]
is_artifact_list = type_annotations.is_list_of_artifacts(
parameter_type.__origin__)
parameter_type = type_annotations.get_io_artifact_class(
parameter_type)
if not type_annotations.is_artifact_class(parameter_type):
raise ValueError(
'Input[T] and Output[T] are only supported when T is a '
'subclass of Artifact. Found `{} with type {}`'.format(
io_name, parameter_type))
'Input[T] and Output[T] are only supported when T is an artifact or list of artifacts. Found `{} with type {}`'
.format(io_name, parameter_type))
if parameter.default is not inspect.Parameter.empty:
raise ValueError(
@ -212,7 +215,8 @@ def extract_component_interface(
schema_version = parameter_type.schema_version
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
type_struct, schema_version))
type_struct, schema_version),
is_artifact_list=is_artifact_list)
else:
output_spec = structures.OutputSpec(type=type_struct)
outputs[io_name] = output_spec
@ -223,7 +227,9 @@ def extract_component_interface(
schema_version = parameter_type.schema_version
input_spec = structures.InputSpec(
type=type_utils.create_bundled_artifact_type(
type_struct, schema_version))
type_struct, schema_version),
is_artifact_list=is_artifact_list,
)
else:
if parameter.default is not inspect.Parameter.empty:
input_spec = structures.InputSpec(
@ -248,14 +254,18 @@ def extract_component_interface(
for field_name in return_ann._fields:
output_name = _maybe_make_unique(field_name, output_names)
output_names.add(output_name)
annotation = field_annotations.get(field_name)
if type_annotations.is_artifact_class(annotation):
type_var = field_annotations.get(field_name)
if type_annotations.is_list_of_artifacts(type_var):
raise ValueError(
f'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `{field_name}`.'
)
elif type_annotations.is_artifact_class(type_var):
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
annotation.schema_title, annotation.schema_version))
type_var.schema_title, type_var.schema_version))
else:
type_struct = type_utils._annotation_to_type_struct(
annotation)
type_var)
output_spec = structures.OutputSpec(type=type_struct)
outputs[output_name] = output_spec
# Deprecated dict-based way of declaring multiple outputs. Was only used by
@ -278,10 +288,17 @@ def extract_component_interface(
# `def func(output_path: OutputPath()) -> str: ...`
output_names.add(output_name)
return_ann = signature.return_annotation
if type_annotations.is_artifact_class(signature.return_annotation):
if type_annotations.is_list_of_artifacts(return_ann):
artifact_cls = return_ann.__args__[0]
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
return_ann.schema_title, return_ann.schema_version))
artifact_cls.schema_title, artifact_cls.schema_version),
is_artifact_list=True)
elif type_annotations.is_artifact_class(return_ann):
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
return_ann.schema_title, return_ann.schema_version),
is_artifact_list=False)
else:
type_struct = type_utils._annotation_to_type_struct(return_ann)
output_spec = structures.OutputSpec(type=type_struct)

View File

@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import unittest
from kfp import dsl
from kfp.components import component_factory
from kfp.components import structures
from kfp.components.component_decorator import component
from kfp.components.types.type_annotations import OutputPath
@ -69,5 +72,110 @@ class TestInvalidParameterName(unittest.TestCase):
pass
from kfp.components.types.artifact_types import Artifact
from kfp.components.types.artifact_types import Model
from kfp.dsl import Input
from kfp.dsl import Output
class TestExtractComponentInterfaceListofArtifacts(unittest.TestCase):
    """Tests interface extraction for `Input[List[<artifact>]]` parameters."""

    def _assert_single_artifact_list_input(self, component_spec,
                                           artifact_type):
        """Asserts the spec is named `comp`, has no description, and has a
        single input `i` that is a list of `artifact_type` artifacts."""
        self.assertEqual(component_spec.name, 'comp')
        self.assertEqual(component_spec.description, None)
        expected_inputs = {
            'i':
                structures.InputSpec(
                    type=artifact_type, default=None, is_artifact_list=True)
        }
        self.assertEqual(component_spec.inputs, expected_inputs)

    def test_python_component_input(self):

        def comp(i: Input[List[Model]]):
            ...

        extracted = component_factory.extract_component_interface(comp)
        self._assert_single_artifact_list_input(extracted,
                                                'system.Model@0.0.1')

    def test_custom_container_component_input(self):

        def comp(i: Input[List[Artifact]]):
            ...

        # containerized=True exercises the container-component extraction path.
        extracted = component_factory.extract_component_interface(
            comp, containerized=True)
        self._assert_single_artifact_list_input(extracted,
                                                'system.Artifact@0.0.1')

    def test_pipeline_input(self):

        def comp(i: Input[List[Model]]):
            ...

        extracted = component_factory.extract_component_interface(comp)
        self._assert_single_artifact_list_input(extracted,
                                                'system.Model@0.0.1')

    def test_pipeline_with_named_tuple_fn(self):
        from typing import NamedTuple

        def comp(
            i: Input[List[Model]]
        ) -> NamedTuple('outputs', [('output_list', List[Artifact])]):
            ...

        # A list of artifacts as a NamedTuple field must be rejected.
        expected_error = r'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `output_list`\.'
        with self.assertRaisesRegex(ValueError, expected_error):
            component_factory.extract_component_interface(comp)
class TestOutputListsOfArtifactsTemporarilyBlocked(unittest.TestCase):
    """Tests that declaring an output list of artifacts raises
    NotImplementedError for components, container components and pipelines."""

    _NOT_SUPPORTED_REGEX = r'Output lists of artifacts are not yet supported\.'

    def _assert_not_supported(self):
        """Returns a context manager expecting the 'not yet supported' error."""
        return self.assertRaisesRegex(NotImplementedError,
                                      self._NOT_SUPPORTED_REGEX)

    def test_python_component(self):
        with self._assert_not_supported():

            @dsl.component
            def comp(output_list: Output[List[Artifact]]):
                ...

    def test_container_component(self):
        with self._assert_not_supported():

            @dsl.container_component
            def comp(output_list: Output[List[Artifact]]):
                return dsl.ContainerSpec(image='alpine')

    def test_pipeline(self):
        with self._assert_not_supported():

            @dsl.pipeline
            def comp() -> List[Artifact]:
                ...
# Run this module's test cases when the file is executed directly.
if __name__ == '__main__':
    unittest.main()

View File

@ -302,7 +302,7 @@ class Executor():
if value is not None:
func_kwargs[k] = value
elif type_annotations.is_artifact_annotation(v):
elif type_annotations.is_Input_Output_artifact_annotation(v):
if type_annotations.is_input_artifact(v):
func_kwargs[k] = self._get_input_artifact(k)
if type_annotations.is_output_artifact(v):

View File

@ -44,10 +44,13 @@ class InputSpec:
type: The type of the input.
default (optional): the default value for the input.
optional: Whether the input is optional. An input is optional when it has an explicit default value.
is_artifact_list: True if `type` represents a list of the artifact type. Only applies when `type` is an artifact.
"""
type: Union[str, dict]
default: Optional[Any] = None
optional: bool = False
# This special flag for lists of artifacts allows type to be used the same way for list of artifacts and single artifacts. This is aligned with how IR represents lists of artifacts (same as for single artifacts), as well as simplifies downstream type handling/checking operations in the SDK since we don't need to parse the string `type` to determine if single artifact or list.
is_artifact_list: bool = False
def __post_init__(self) -> None:
self._validate_type()
@ -132,11 +135,16 @@ class OutputSpec:
Attributes:
type: The type of the output.
is_artifact_list: True if `type` represents a list of the artifact type. Only applies when `type` is an artifact.
"""
type: Union[str, dict]
# This special flag for lists of artifacts allows type to be used the same way for list of artifacts and single artifacts. This is aligned with how IR represents lists of artifacts (same as for single artifacts), as well as simplifies downstream type handling/checking operations in the SDK since we don't need to parse the string `type` to determine if single artifact or list.
is_artifact_list: bool = False
def __post_init__(self) -> None:
self._validate_type()
# TODO: remove this method when we support output lists of artifacts
self._prevent_using_output_lists_of_artifacts()
@classmethod
def from_ir_component_outputs_dict(
@ -196,6 +204,11 @@ class OutputSpec:
if not spec_type_is_parameter(self.type):
type_utils.validate_bundled_artifact_type(self.type)
def _prevent_using_output_lists_of_artifacts(self):
    """Rejects specs flagged as a list of artifacts.

    Raises:
        NotImplementedError: If `self.is_artifact_list` is truthy.
    """
    if not self.is_artifact_list:
        return
    raise NotImplementedError(
        'Output lists of artifacts are not yet supported.')
def spec_type_is_parameter(type_: str) -> bool:
in_memory_type = type_annotations.maybe_strip_optional_from_annotation_string(

View File

@ -49,7 +49,7 @@ def get_param_to_custom_artifact_class(func: Callable) -> Dict[str, type]:
signature = inspect.signature(func)
for name, param in signature.parameters.items():
annotation = param.annotation
if type_annotations.is_artifact_annotation(annotation):
if type_annotations.is_Input_Output_artifact_annotation(annotation):
artifact_class = type_annotations.get_io_artifact_class(annotation)
if artifact_class not in kfp_artifact_classes:
param_to_artifact_cls[name] = artifact_class

View File

@ -17,8 +17,9 @@ These are only compatible with v2 Pipelines.
"""
import re
from typing import Type, TypeVar, Union
from typing import List, Type, TypeVar, Union
from kfp.components.types import artifact_types
from kfp.components.types import type_annotations
from kfp.components.types import type_utils
@ -120,7 +121,7 @@ class OutputAnnotation:
"""Marker type for output artifacts."""
def is_artifact_annotation(typ) -> bool:
def is_Input_Output_artifact_annotation(typ) -> bool:
if not hasattr(typ, '__metadata__'):
return False
@ -132,7 +133,7 @@ def is_artifact_annotation(typ) -> bool:
def is_input_artifact(typ) -> bool:
"""Returns True if typ is of type Input[T]."""
if not is_artifact_annotation(typ):
if not is_Input_Output_artifact_annotation(typ):
return False
return typ.__metadata__[0] == InputAnnotation
@ -140,7 +141,7 @@ def is_input_artifact(typ) -> bool:
def is_output_artifact(typ) -> bool:
"""Returns True if typ is of type Output[T]."""
if not is_artifact_annotation(typ):
if not is_Input_Output_artifact_annotation(typ):
return False
return typ.__metadata__[0] == OutputAnnotation
@ -149,16 +150,21 @@ def is_output_artifact(typ) -> bool:
def get_io_artifact_class(typ):
from kfp.dsl import Input
from kfp.dsl import Output
if not is_artifact_annotation(typ):
if not is_Input_Output_artifact_annotation(typ):
return None
if typ == Input or typ == Output:
return None
return typ.__args__[0]
# extract inner type from list of artifacts
inner = typ.__args__[0]
if hasattr(inner, '__origin__') and inner.__origin__ == list:
return inner.__args__[0]
return inner
def get_io_artifact_annotation(typ):
if not is_artifact_annotation(typ):
if not is_Input_Output_artifact_annotation(typ):
return None
return typ.__metadata__[0]
@ -223,3 +229,13 @@ def is_artifact_class(artifact_class_or_instance: Type) -> bool:
# we do not yet support non-pre-registered custom artifact types with instance_schema attribute
return hasattr(artifact_class_or_instance, 'schema_title') and hasattr(
artifact_class_or_instance, 'schema_version')
def is_list_of_artifacts(
    type_var: Union[Type[List[artifact_types.Artifact]],
                    Type[artifact_types.Artifact]]
) -> bool:
    """Returns True if type_var is a list annotation of an artifact class.

    Args:
        type_var: The annotation to inspect, e.g. `List[Model]` or `Model`.

    Returns:
        True only when `type_var.__origin__` is `list` and its first type
        argument passes `type_annotations.is_artifact_class`. Anything else,
        including a list annotation with no type arguments, yields False.
    """
    # the type annotation for this function's `type_var` parameter may not actually be a subclass of the KFP SDK's Artifact class for custom artifact types
    if getattr(type_var, '__origin__', None) == list:
        # An unparameterized list annotation (e.g. bare `typing.List`) may
        # expose `__origin__` without any `__args__`; treat it as "not a list
        # of artifacts" rather than failing on the attribute/index lookup.
        type_args = getattr(type_var, '__args__', None)
        if type_args:
            return type_annotations.is_artifact_class(type_args[0])
    return False

View File

@ -30,35 +30,63 @@ from kfp.dsl import Output
class AnnotationsTest(parameterized.TestCase):
def test_is_artifact_annotation(self):
self.assertTrue(type_annotations.is_artifact_annotation(Input[Model]))
self.assertTrue(type_annotations.is_artifact_annotation(Output[Model]))
@parameterized.parameters([
Input[Model],
Output[Model],
Output[List[Model]],
Output['MyArtifact'],
])
def test_is_artifact_annotation(self, annotation):
self.assertTrue(
type_annotations.is_artifact_annotation(Output['MyArtifact']))
type_annotations.is_Input_Output_artifact_annotation(annotation))
self.assertFalse(type_annotations.is_artifact_annotation(Model))
self.assertFalse(type_annotations.is_artifact_annotation(int))
self.assertFalse(type_annotations.is_artifact_annotation('Dataset'))
self.assertFalse(type_annotations.is_artifact_annotation(List[str]))
self.assertFalse(type_annotations.is_artifact_annotation(Optional[str]))
@parameterized.parameters([
Model,
int,
'Dataset',
List[str],
Optional[str],
])
def test_is_not_artifact_annotation(self, annotation):
self.assertFalse(
type_annotations.is_Input_Output_artifact_annotation(annotation))
def test_is_input_artifact(self):
self.assertTrue(type_annotations.is_input_artifact(Input[Model]))
self.assertTrue(type_annotations.is_input_artifact(Input))
@parameterized.parameters([
Input[Model],
Input,
])
def test_is_input_artifact(self, annotation):
self.assertTrue(type_annotations.is_input_artifact(annotation))
self.assertFalse(type_annotations.is_input_artifact(Output[Model]))
self.assertFalse(type_annotations.is_input_artifact(Output))
@parameterized.parameters([
Output[Model],
Output,
])
def test_is_not_input_artifact(self, annotation):
self.assertFalse(type_annotations.is_input_artifact(annotation))
def test_is_output_artifact(self):
self.assertTrue(type_annotations.is_output_artifact(Output[Model]))
self.assertTrue(type_annotations.is_output_artifact(Output))
@parameterized.parameters([
Output[Model],
Output[List[Model]],
])
def test_is_output_artifact(self, annotation):
self.assertTrue(type_annotations.is_output_artifact(annotation))
self.assertFalse(type_annotations.is_output_artifact(Input[Model]))
self.assertFalse(type_annotations.is_output_artifact(Input))
@parameterized.parameters([
Input[Model],
Input[List[Model]],
Input,
])
def test_is_not_output_artifact(self, annotation):
self.assertFalse(type_annotations.is_output_artifact(annotation))
def test_get_io_artifact_class(self):
self.assertEqual(
type_annotations.get_io_artifact_class(Output[Model]), Model)
self.assertEqual(
type_annotations.get_io_artifact_class(Output[List[Model]]), Model)
self.assertEqual(
type_annotations.get_io_artifact_class(Input[List[Model]]), Model)
self.assertEqual(type_annotations.get_io_artifact_class(Input), None)
self.assertEqual(type_annotations.get_io_artifact_class(Output), None)
@ -69,9 +97,15 @@ class AnnotationsTest(parameterized.TestCase):
self.assertEqual(
type_annotations.get_io_artifact_annotation(Output[Model]),
OutputAnnotation)
self.assertEqual(
type_annotations.get_io_artifact_annotation(Output[List[Model]]),
OutputAnnotation)
self.assertEqual(
type_annotations.get_io_artifact_annotation(Input[Model]),
InputAnnotation)
self.assertEqual(
type_annotations.get_io_artifact_annotation(Input[List[Model]]),
InputAnnotation)
self.assertEqual(
type_annotations.get_io_artifact_annotation(Input), InputAnnotation)
self.assertEqual(