refactor(sdk): remove dead BaseModel code (#8415)

* remove dead basemodel code

* remove unused deserialization logic

* remove more dead code

* remove basemodel entirely
This commit is contained in:
Connor McCarthy 2022-11-03 14:58:15 -07:00 committed by GitHub
parent 6e95e8988b
commit 08a8b1df97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 927 deletions

View File

@ -1,425 +0,0 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from collections import abc
import dataclasses
import inspect
import json
import pprint
from typing import (Any, Dict, ForwardRef, Iterable, Iterator, Mapping,
MutableMapping, MutableSequence, Optional, OrderedDict,
Sequence, Tuple, Type, TypeVar, Union)
# Leaf types; loaded by direct typecasting (e.g. int(data)) in
# _load_basemodel_helper.
PRIMITIVE_TYPES = {int, str, float, bool}

# typing.Optional.__origin__ is typing.Union
UNION_TYPES = {Union}

# do not need typing.List, because __origin__ is list
ITERABLE_TYPES = {
    list,
    abc.Sequence,
    abc.MutableSequence,
    Sequence,
    MutableSequence,
    Iterable,
}

# do not need typing.Dict, because __origin__ is dict
MAPPING_TYPES = {
    dict, abc.Mapping, abc.MutableMapping, Mapping, MutableMapping, OrderedDict,
    collections.OrderedDict
}

# NoneType and Any are accepted as field annotations but handled specially
# when loading data.
OTHER_SUPPORTED_TYPES = {type(None), Any}

# Every annotation a BaseModel subclass field may use must resolve into this
# set (checked by BaseModel._recursively_validate_type_is_supported).
SUPPORTED_TYPES = PRIMITIVE_TYPES | UNION_TYPES | ITERABLE_TYPES | MAPPING_TYPES | OTHER_SUPPORTED_TYPES

# TypeVar bound to BaseModel so classmethod constructors can be typed as
# returning the concrete subclass.
BaseModelType = TypeVar('BaseModelType', bound='BaseModel')
class BaseModel:
    """Base class for KFP structures.

    Subclasses are converted to dataclasses at subclass definition time
    (see ``__init_subclass__``) and gain methods for converting to and
    from dict/JSON, field type enforcement, and user-defined transform
    and validation hooks (see ``__post_init__``).
    """
    # Maps attribute name -> serialized field name. Subclasses override this
    # when their dict/JSON representation uses different field names.
    _aliases: Dict[str, str] = {}

    # this quiets the mypy "Unexpected keyword argument..." errors on subclass construction
    # TODO: find a way to propagate type info to subclasses
    def __init__(self, *args, **kwargs):
        pass

    def __init_subclass__(cls) -> None:
        """Hook called at subclass definition time and at instance construction
        time.

        Converts the subclass into a dataclass and validates that the
        field-to-type mapping provided at subclass definition time is
        supported by BaseModel.
        """
        # dataclasses.dataclass mutates cls in place and returns it.
        cls = dataclasses.dataclass(cls)
        # print(inspect.signature(cls.__init__))
        for field in dataclasses.fields(cls):
            cls._recursively_validate_type_is_supported(field.type)

    def to_dict(self, by_alias: bool = False) -> Dict[str, Any]:
        """Recursively converts to a dictionary.

        Args:
            by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when converting to dictionary. Defaults to False.

        Returns:
            Dict[str, Any]: Dictionary representation of the object.
        """
        return convert_object_to_dict(self, by_alias=by_alias)

    def to_json(self, by_alias: bool = False) -> str:
        """Recursively converts to a JSON string.

        Args:
            by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when converting to JSON. Defaults to False.

        Returns:
            str: JSON representation of the object.
        """
        return json.dumps(self.to_dict(by_alias=by_alias))

    @classmethod
    def from_dict(cls,
                  data: Dict[str, Any],
                  by_alias: bool = False) -> BaseModelType:
        """Recursively loads object from a dictionary.

        Args:
            data (Dict[str, Any]): Dictionary representation of the object.
            by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when reading in dictionary. Defaults to False.

        Returns:
            BaseModelType: Subclass of BaseModel.
        """
        return _load_basemodel_helper(cls, data, by_alias=by_alias)

    @classmethod
    def from_json(cls, text: str, by_alias: bool = False) -> 'BaseModel':
        """Recursively loads object from a JSON string.

        Args:
            text (str): JSON representation of the object.
            by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when reading in JSON. Defaults to False.

        Returns:
            BaseModelType: Subclass of BaseModel.
        """
        return _load_basemodel_helper(cls, json.loads(text), by_alias=by_alias)

    @property
    def types(self) -> Dict[str, type]:
        """Dictionary mapping field names to types."""
        return {field.name: field.type for field in dataclasses.fields(self)}

    @property
    def fields(self) -> Tuple[dataclasses.Field, ...]:
        """The dataclass fields."""
        return dataclasses.fields(self)

    @classmethod
    def _recursively_validate_type_is_supported(cls, type_: type) -> None:
        """Walks the type definition (generics and subtypes) and checks if it
        is supported by downstream BaseModel operations.

        Args:
            type_ (type): Type to check.

        Raises:
            TypeError: If type is unsupported.
        """
        # Forward references cannot be resolved at definition time; accept.
        if isinstance(type_, ForwardRef):
            return
        if type_ in SUPPORTED_TYPES or _is_basemodel(type_):
            return
        if _get_origin_py37(type_) not in SUPPORTED_TYPES:
            raise TypeError(
                f'Type {type_} is not a supported type fields on child class of {BaseModel.__name__}: {cls.__name__}.'
            )
        # Bare generics (e.g. List without arguments) default to Any.
        args = _get_args_py37(type_) or [Any, Any]
        for arg in args:
            cls._recursively_validate_type_is_supported(arg)

    def __post_init__(self) -> None:
        """Hook called after object is instantiated from BaseModel.

        Transforms data and validates data using user-defined logic by
        calling all methods prefixed with `_transform_`, then all
        methods prefixed with `_validate_`.
        """
        # Discover hooks dynamically so subclasses only need to follow the
        # naming convention; transforms run before validations.
        validate_methods = [
            method for method in dir(self)
            if method.startswith('_transform_') and
            callable(getattr(self, method))
        ]
        for method in validate_methods:
            getattr(self, method)()
        validate_methods = [
            method for method in dir(self) if
            method.startswith('_validate_') and callable(getattr(self, method))
        ]
        for method in validate_methods:
            getattr(self, method)()

    def __str__(self) -> str:
        """Returns a readable representation of the BaseModel subclass."""
        return base_model_format(self)
def base_model_format(x: BaseModelType) -> str:
    """Formats a BaseModel object for improved readability.

    Args:
        x (BaseModelType): The subclass of BaseModel to format.

    Returns:
        str: Readable string representation of the object.
    """
    # Indentation of the top-level object; nested values are indented via
    # the closures below.
    CHARS = 0

    def first_level_indent(string: str, chars: int = 1) -> str:
        # Indent every line of `string` by `chars` spaces.
        return '\n'.join(' ' * chars + p for p in string.split('\n'))

    def body_level_indent(string: str, chars=4) -> str:
        # Keep the first line as-is; indent continuation lines by `chars`
        # so multi-line values line up under their field name.
        a, *b = string.split('\n')
        return a + '\n' + first_level_indent(
            '\n'.join(b),
            chars=chars,
        ) if b else a

    def parts() -> Iterator[str]:
        # Dataclasses render as Name(field=value, ...); any other value
        # falls back to pprint's formatting.
        if dataclasses.is_dataclass(x):
            yield type(x).__name__ + '('

            def fields() -> Iterator[str]:
                for field in dataclasses.fields(x):
                    # Continuation lines align under the value, past
                    # "name=".
                    nindent = CHARS + len(field.name) + 4
                    value = getattr(x, field.name)
                    rep_value = base_model_format(value)
                    yield (' ' * (CHARS + 3) + body_level_indent(
                        f'{field.name}={rep_value}', chars=nindent))

            yield ',\n'.join(fields())
            yield ' ' * CHARS + ')'
        else:
            yield pprint.pformat(x)

    return '\n'.join(parts())
def convert_object_to_dict(obj: 'BaseModelType',
                           by_alias: bool) -> Dict[str, Any]:
    """Recursion helper function for converting a BaseModel and data structures
    therein to a dictionary. Converts all fields that do not start with an
    underscore.

    Args:
        obj (BaseModelType): The object to convert to a dictionary. Initially called with subclass of BaseModel.
        by_alias (bool): Whether to use the attribute name to alias field mapping provided by cls._aliases when converting to dictionary.

    Returns:
        Dict[str, Any]: The dictionary representation of the object.
    """
    signature = inspect.signature(obj.__init__)
    result = {}
    for attr_name in signature.parameters:
        # Internal constructor parameters are not serialized.
        if attr_name.startswith('_'):
            continue
        # Raises AttributeError if the attribute is missing, which is the
        # clearest possible failure for a malformed object.
        value = getattr(obj, attr_name)
        field_name = attr_name
        if by_alias and hasattr(obj, '_aliases'):
            field_name = obj._aliases.get(attr_name, attr_name)
        if hasattr(value, 'to_dict'):
            # Nested BaseModel (or anything exposing to_dict).
            result[field_name] = value.to_dict(by_alias=by_alias)
        elif isinstance(value, list):
            result[field_name] = [
                (x.to_dict(by_alias=by_alias) if hasattr(x, 'to_dict') else x)
                for x in value
            ]
        elif isinstance(value, dict):
            result[field_name] = {
                k:
                (v.to_dict(by_alias=by_alias) if hasattr(v, 'to_dict') else v)
                for k, v in value.items()
            }
        else:
            # Fix: the original wrote
            # `value if value != param.default else param.default`, which is
            # a tautology (both branches equal `value`), and had an
            # unreachable missing-value ValueError branch (`param` is always
            # found because attr_name iterates signature.parameters, and
            # getattr above already raises for absent attributes).
            result[field_name] = value
    return result
def _is_basemodel(obj: Any) -> bool:
"""Checks if object is a subclass of BaseModel.
Args:
obj (Any): Any object
Returns:
bool: Is a subclass of BaseModel.
"""
return inspect.isclass(obj) and issubclass(obj, BaseModel)
def _get_origin_py37(type_: Type) -> Optional[Type]:
"""typing.get_origin is introduced in Python 3.8, but we need a get_origin
that is compatible with 3.7.
Args:
type_ (Type): A type.
Returns:
Type: The origin of `type_`.
"""
# uses typing for types
return type_.__origin__ if hasattr(type_, '__origin__') else None
def _get_args_py37(type_: Type) -> Tuple[Type]:
"""typing.get_args is introduced in Python 3.8, but we need a get_args that
is compatible with 3.7.
Args:
type_ (Type): A type.
Returns:
Tuple[Type]: The type arguments of `type_`.
"""
# uses typing for types
return type_.__args__ if hasattr(type_, '__args__') else tuple()
def _load_basemodel_helper(type_: Any, data: Any, by_alias: bool) -> Any:
    """Helper function for recursively loading a BaseModel.

    Args:
        type_ (Any): The type of the object to load. Typically an instance of `type`, `BaseModel` or `Any`.
        data (Any): The data to load.
        by_alias (bool): Whether to use the attribute name to alias field mapping provided by cls._aliases when reading in data.

    Raises:
        TypeError: If `type_` is unsupported or `data` cannot be loaded into it.
        ValueError: If a required field is missing from `data`.

    Returns:
        Any: The loaded object.
    """
    if isinstance(type_, str):
        # A string annotation indicates PEP 563 lazy annotations or a
        # built-in generic; neither is introspectable here.
        raise TypeError(
            'Please do not use built-in collection types as generics (e.g., list[int]) and do not include the import line `from __future__ import annotations`. Please use the corresponding generic from typing (e.g., List[int]).'
        )
    # catch unsupported types early
    type_or_generic = _get_origin_py37(type_) or type_
    if type_or_generic not in SUPPORTED_TYPES and not _is_basemodel(type_):
        raise TypeError(
            f'Unsupported type: {type_}. Cannot load data into object.')
    # if don't have any helpful type information, return data as is
    if type_ is Any:
        return data
    # if type is NoneType and data is None, return data/None
    if type_ is type(None):
        if data is None:
            return data
        else:
            raise TypeError(
                f'Expected value None for type NoneType. Got: {data}')
    # handle primitives, with typecasting
    if type_ in PRIMITIVE_TYPES:
        return type_(data)
    # simple types are handled, now handle for container types
    origin = _get_origin_py37(type_)
    # if there is an inner type in the generic, use it, else use Any
    args = _get_args_py37(type_) or [Any, Any]
    # recursively load iterable objects
    if origin in ITERABLE_TYPES:
        # Fix: the original looped `for arg in args` but returned on the
        # first iteration; a homogeneous iterable has exactly one element
        # type, so use it directly.
        element_type = args[0]
        return [
            _load_basemodel_helper(element_type, element, by_alias=by_alias)
            for element in data
        ]
    # recursively load mapping objects
    if origin in MAPPING_TYPES:
        if len(args) != 2:
            raise TypeError(
                f'Expected exactly 2 type arguments for mapping type {type_}.')
        key_type, value_type = args
        return {
            _load_basemodel_helper(key_type, k, by_alias=by_alias):
            _load_basemodel_helper(value_type, v, by_alias=by_alias)
            for k, v in data.items()
        }
    # if the type is a Union, try to load the data into each of the types,
    # greedily accepting the first annotation arg that works --> the developer
    # can indicate which types are preferred based on the annotation arg order
    if origin in UNION_TYPES:
        # don't try to cast none if the union type is optional
        if type(None) in args and data is None:
            return None
        # Fix: the original always loaded into args[0] (not `arg`) and
        # returned unconditionally, so later union members were never tried
        # despite the documented greedy-first-success behavior.
        errors = []
        for arg in args:
            try:
                return _load_basemodel_helper(arg, data, by_alias=by_alias)
            except (TypeError, ValueError) as e:
                errors.append(e)
        raise TypeError(
            f'Could not load data: {data} into any member of union type {type_}. Errors: {errors}'
        )
    # finally, handle the cases where the type is an instance of baseclass
    if _is_basemodel(type_):
        fields = dataclasses.fields(type_)
        basemodel_kwargs = {}
        for field in fields:
            attr_name = field.name
            data_field_name = attr_name
            if by_alias and hasattr(type_, '_aliases'):
                data_field_name = type_._aliases.get(attr_name, attr_name)
            value = data.get(data_field_name)
            if value is None:
                # Missing (or explicitly None) fields fall back to the
                # dataclass default or default_factory.
                if field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING:
                    raise ValueError(
                        f'Missing required field: {data_field_name}')
                value = field.default if field.default is not dataclasses.MISSING else field.default_factory(
                )
            else:
                value = _load_basemodel_helper(
                    field.type, value, by_alias=by_alias)
            basemodel_kwargs[attr_name] = value
        return type_(**basemodel_kwargs)
    raise TypeError(
        f'Unknown error when loading data: {data} into type {type_}')

View File

@ -1,461 +0,0 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
from typing import (Any, Dict, List, Mapping, MutableMapping, MutableSequence,
Optional, OrderedDict, Sequence, Set, Tuple, Union)
import unittest
from absl.testing import parameterized
from kfp.components import base_model
class TypeClass(base_model.BaseModel):
    # Fixture covering each category of supported annotation: primitive,
    # iterable, mapping, two- and three-member unions, and Optional.
    a: str
    b: List[int]
    c: Dict[str, int]
    d: Union[int, str]
    e: Union[int, str, bool]
    f: Optional[int]
class TestBaseModel(unittest.TestCase):
    """Tests for BaseModel's dataclass conversion, dict round-tripping,
    aliasing, and transform/validate hooks."""

    def test_is_dataclass(self):
        # __init_subclass__ converts subclasses into dataclasses.
        class Child(base_model.BaseModel):
            x: int

        child = Child(x=1)
        self.assertTrue(dataclasses.is_dataclass(child))

    def test_to_dict_simple(self):
        class Child(base_model.BaseModel):
            i: int
            s: str
            f: float
            l: List[int]

        data = {'i': 1, 's': 's', 'f': 1.0, 'l': [1, 2]}
        child = Child(**data)
        actual = child.to_dict()
        self.assertEqual(actual, data)
        # to_dict/from_dict round-trips.
        self.assertEqual(child, Child.from_dict(actual))

    def test_to_dict_nested(self):
        class InnerChild(base_model.BaseModel):
            a: str

        class Child(base_model.BaseModel):
            b: int
            c: InnerChild

        data = {'b': 1, 'c': InnerChild(a='a')}
        child = Child(**data)
        actual = child.to_dict()
        expected = {'b': 1, 'c': {'a': 'a'}}
        self.assertEqual(actual, expected)
        self.assertEqual(child, Child.from_dict(actual))

    def test_from_dict_no_defaults(self):
        class Child(base_model.BaseModel):
            i: int
            s: str
            f: float
            l: List[int]

        data = {'i': 1, 's': 's', 'f': 1.0, 'l': [1, 2]}
        child = Child.from_dict(data)
        self.assertEqual(child.i, 1)
        self.assertEqual(child.s, 's')
        self.assertEqual(child.f, 1.0)
        self.assertEqual(child.l, [1, 2])
        self.assertEqual(child.to_dict(), data)

    def test_from_dict_with_defaults(self):
        # Fields absent from the input dict fall back to their defaults.
        class Child(base_model.BaseModel):
            s: str
            f: float
            l: List[int]
            i: int = 1

        data = {'s': 's', 'f': 1.0, 'l': [1, 2]}
        child = Child.from_dict(data)
        self.assertEqual(child.i, 1)
        self.assertEqual(child.s, 's')
        self.assertEqual(child.f, 1.0)
        self.assertEqual(child.l, [1, 2])
        self.assertEqual(child.to_dict(), {**data, **{'i': 1}})

    def test_from_dict_nested(self):
        class InnerChild(base_model.BaseModel):
            a: str

        class Child(base_model.BaseModel):
            b: int
            c: InnerChild

        data = {'b': 1, 'c': {'a': 'a'}}
        child = Child.from_dict(data)
        self.assertEqual(child.b, 1)
        self.assertIsInstance(child.c, InnerChild)
        self.assertEqual(child.c.a, 'a')
        self.assertEqual(child.to_dict(), data)

    def test_from_dict_array_nested(self):
        # Nested models inside lists and dicts are loaded recursively.
        class InnerChild(base_model.BaseModel):
            a: str

        class Child(base_model.BaseModel):
            b: int
            c: List[InnerChild]
            d: Dict[str, InnerChild]

        inner_child_data = {'a': 'a'}
        data = {
            'b': 1,
            'c': [inner_child_data, inner_child_data],
            'd': {
                'e': inner_child_data
            }
        }
        child = Child.from_dict(data)
        self.assertEqual(child.b, 1)
        self.assertIsInstance(child.c[0], InnerChild)
        self.assertIsInstance(child.c[1], InnerChild)
        self.assertIsInstance(child.d['e'], InnerChild)
        self.assertEqual(child.c[0].a, 'a')
        self.assertEqual(child.to_dict(), data)

    def test_from_dict_by_alias(self):
        # _aliases maps attribute names to serialized field names at every
        # nesting level.
        class InnerChild(base_model.BaseModel):
            inner_child_field: int
            _aliases = {'inner_child_field': 'inner_child_field_alias'}

        class Child(base_model.BaseModel):
            sub_field: InnerChild
            _aliases = {'sub_field': 'sub_field_alias'}

        data = {'sub_field_alias': {'inner_child_field_alias': 2}}
        child = Child.from_dict(data, by_alias=True)
        self.assertIsInstance(child.sub_field, InnerChild)
        self.assertEqual(child.sub_field.inner_child_field, 2)
        self.assertEqual(child.to_dict(by_alias=True), data)

    def test_to_dict_by_alias(self):
        class InnerChild(base_model.BaseModel):
            inner_child_field: int
            _aliases = {'inner_child_field': 'inner_child_field_alias'}

        class Child(base_model.BaseModel):
            sub_field: InnerChild
            _aliases = {'sub_field': 'sub_field_alias'}

        inner_child = InnerChild(inner_child_field=2)
        child = Child(sub_field=inner_child)
        actual = child.to_dict(by_alias=True)
        expected = {'sub_field_alias': {'inner_child_field_alias': 2}}
        self.assertEqual(actual, expected)
        self.assertEqual(Child.from_dict(actual, by_alias=True), child)

    def test_to_dict_by_alias2(self):
        # Unaliased fields keep their attribute names.
        class MyClass(base_model.BaseModel):
            x: int
            y: List[int]
            z: Dict[str, int]
            _aliases = {'x': 'a', 'z': 'b'}

        res = MyClass(x=1, y=[2], z={'key': 3}).to_dict(by_alias=True)
        self.assertEqual(res, {'a': 1, 'y': [2], 'b': {'key': 3}})

    def test_to_dict_by_alias_nested(self):
        class InnerClass(base_model.BaseModel):
            f: float
            _aliases = {'f': 'a'}

        class MyClass(base_model.BaseModel):
            x: int
            y: List[int]
            z: InnerClass
            _aliases = {'x': 'a', 'z': 'b'}

        res = MyClass(x=1, y=[2], z=InnerClass(f=1.0)).to_dict(by_alias=True)
        self.assertEqual(res, {'a': 1, 'y': [2], 'b': {'a': 1.0}})

    def test_can_create_properties_using_attributes(self):
        # Properties work alongside dataclass fields.
        class Child(base_model.BaseModel):
            x: Optional[int]

            @property
            def prop(self) -> bool:
                return self.x is not None

        child1 = Child(x=None)
        self.assertEqual(child1.prop, False)
        child2 = Child(x=1)
        self.assertEqual(child2.prop, True)

    def test_unsupported_type_success(self):
        # Fields typed as other BaseModel subclasses are supported.
        class OtherClass(base_model.BaseModel):
            x: int

        class MyClass(base_model.BaseModel):
            a: OtherClass

    def test_unsupported_type_failures(self):
        # Unsupported annotations raise at class definition time.
        with self.assertRaisesRegex(TypeError, r'not a supported'):

            class MyClass(base_model.BaseModel):
                a: tuple

        with self.assertRaisesRegex(TypeError, r'not a supported'):

            class MyClass(base_model.BaseModel):
                a: Tuple

        with self.assertRaisesRegex(TypeError, r'not a supported'):

            class MyClass(base_model.BaseModel):
                a: Set

        with self.assertRaisesRegex(TypeError, r'not a supported'):

            class OtherClass:
                pass

            class MyClass(base_model.BaseModel):
                a: OtherClass

    def test_base_model_validation(self):
        # test exception thrown
        class MyClass(base_model.BaseModel):
            x: int

            def _validate_x(self) -> None:
                if self.x < 2:
                    raise ValueError('x must be greater than 2')

        with self.assertRaisesRegex(ValueError, 'x must be greater than 2'):
            mc = MyClass(x=1)

        # test value modified same type
        class MyClass(base_model.BaseModel):
            x: int

            def _validate_x(self) -> None:
                self.x = max(self.x, 2)

        mc = MyClass(x=1)
        self.assertEqual(mc.x, 2)

        # test value modified new type
        class MyClass(base_model.BaseModel):
            x: Optional[List[int]] = None

            def _validate_x(self) -> None:
                if isinstance(self.x, list) and not self.x:
                    self.x = None

        mc = MyClass(x=[])
        self.assertEqual(mc.x, None)

    def test_can_set_field(self):
        class MyClass(base_model.BaseModel):
            x: int

        mc = MyClass(x=2)
        mc.x = 1
        self.assertEqual(mc.x, 1)

    def test_can_use_default_factory(self):
        class MyClass(base_model.BaseModel):
            x: List[int] = dataclasses.field(default_factory=list)

        mc = MyClass()
        self.assertEqual(mc.x, [])
class TestIsBaseModel(unittest.TestCase):
    """Tests for the base_model._is_basemodel predicate."""

    def test_true(self):
        # BaseModel itself qualifies, as does any subclass.
        self.assertTrue(base_model._is_basemodel(base_model.BaseModel))

        class MyClass(base_model.BaseModel):
            pass

        self.assertTrue(base_model._is_basemodel(MyClass))

    def test_false(self):
        # Unrelated classes and non-class objects are rejected.
        for candidate in (int, 1, str):
            self.assertFalse(base_model._is_basemodel(candidate))
class TestLoadBaseModelHelper(parameterized.TestCase):
    """Tests for base_model._load_basemodel_helper across all supported
    annotation categories."""

    def setUp(self):
        # All cases here exercise the non-alias path; bind by_alias once.
        self.no_alias_load_base_model_helper = functools.partial(
            base_model._load_basemodel_helper, by_alias=False)

    def test_load_primitive(self):
        self.assertEqual(self.no_alias_load_base_model_helper(str, 'a'), 'a')
        self.assertEqual(self.no_alias_load_base_model_helper(int, 1), 1)
        self.assertEqual(self.no_alias_load_base_model_helper(float, 1.0), 1.0)
        self.assertEqual(self.no_alias_load_base_model_helper(bool, True), True)
        self.assertEqual(
            self.no_alias_load_base_model_helper(type(None), None), None)

    def test_load_primitive_with_casting(self):
        # Primitives are loaded by calling the target type on the data.
        self.assertEqual(self.no_alias_load_base_model_helper(int, '1'), 1)
        self.assertEqual(self.no_alias_load_base_model_helper(str, 1), '1')
        self.assertEqual(self.no_alias_load_base_model_helper(float, 1), 1.0)
        self.assertEqual(self.no_alias_load_base_model_helper(int, 1.0), 1)
        self.assertEqual(self.no_alias_load_base_model_helper(bool, 1), True)
        self.assertEqual(self.no_alias_load_base_model_helper(bool, 0), False)
        self.assertEqual(self.no_alias_load_base_model_helper(int, True), 1)
        self.assertEqual(self.no_alias_load_base_model_helper(int, False), 0)
        self.assertEqual(
            self.no_alias_load_base_model_helper(bool, None), False)

    def test_load_none(self):
        # NoneType accepts only None.
        self.assertEqual(
            self.no_alias_load_base_model_helper(type(None), None), None)
        with self.assertRaisesRegex(TypeError, ''):
            self.no_alias_load_base_model_helper(type(None), 1)

    @parameterized.parameters(['a', 1, 1.0, True, False, None, ['list']])
    def test_load_any(self, data: Any):
        # Any returns the data unchanged.
        self.assertEqual(self.no_alias_load_base_model_helper(Any, data),
                         data)  # type: ignore

    def test_load_list(self):
        self.assertEqual(
            self.no_alias_load_base_model_helper(List[str], ['a']), ['a'])
        self.assertEqual(
            self.no_alias_load_base_model_helper(List[int], [1, 1]), [1, 1])
        self.assertEqual(
            self.no_alias_load_base_model_helper(List[float], [1.0]), [1.0])
        self.assertEqual(
            self.no_alias_load_base_model_helper(List[bool], [True]), [True])
        self.assertEqual(
            self.no_alias_load_base_model_helper(List[type(None)], [None]),
            [None])

    def test_load_primitive_other_iterables(self):
        # Sequence/MutableSequence annotations load as lists.
        self.assertEqual(
            self.no_alias_load_base_model_helper(Sequence[bool], [True]),
            [True])
        self.assertEqual(
            self.no_alias_load_base_model_helper(MutableSequence[type(None)],
                                                 [None]), [None])
        self.assertEqual(
            self.no_alias_load_base_model_helper(Sequence[str], ['a']), ['a'])

    def test_load_dict(self):
        self.assertEqual(
            self.no_alias_load_base_model_helper(Dict[str, str], {'a': 'a'}),
            {'a': 'a'})
        self.assertEqual(
            self.no_alias_load_base_model_helper(Dict[str, int], {'a': 1}),
            {'a': 1})
        self.assertEqual(
            self.no_alias_load_base_model_helper(Dict[str, float], {'a': 1.0}),
            {'a': 1.0})
        self.assertEqual(
            self.no_alias_load_base_model_helper(Dict[str, bool], {'a': True}),
            {'a': True})
        self.assertEqual(
            self.no_alias_load_base_model_helper(Dict[str, type(None)],
                                                 {'a': None}), {'a': None})

    def test_load_mapping(self):
        # Mapping/MutableMapping/OrderedDict annotations load as dicts.
        self.assertEqual(
            self.no_alias_load_base_model_helper(Mapping[str, float],
                                                 {'a': 1.0}), {'a': 1.0})
        self.assertEqual(
            self.no_alias_load_base_model_helper(MutableMapping[str, bool],
                                                 {'a': True}), {'a': True})
        self.assertEqual(
            self.no_alias_load_base_model_helper(OrderedDict[str,
                                                             type(None)],
                                                 {'a': None}), {'a': None})

    def test_load_union_types(self):
        # The first union member is preferred, so data is cast to it.
        self.assertEqual(
            self.no_alias_load_base_model_helper(Union[str, int], 'a'), 'a')
        self.assertEqual(
            self.no_alias_load_base_model_helper(Union[str, int], 1), '1')
        self.assertEqual(
            self.no_alias_load_base_model_helper(Union[int, str], 1), 1)
        self.assertEqual(
            self.no_alias_load_base_model_helper(Union[int, str], '1'), 1)

    def test_load_optional_types(self):
        # Optional passes None through without casting.
        self.assertEqual(
            self.no_alias_load_base_model_helper(Optional[str], 'a'), 'a')
        self.assertEqual(
            self.no_alias_load_base_model_helper(Optional[str], None), None)

    def test_unsupported_type(self):
        with self.assertRaisesRegex(TypeError, r'Unsupported type:'):
            self.no_alias_load_base_model_helper(Set[int], {1})
class TestGetOriginPy37(parameterized.TestCase):
    """Checks the 3.7 backport against the stdlib implementation."""

    def test_is_same_as_typing_version(self):
        import sys

        # typing.get_origin only exists on Python 3.8+; on 3.7 there is
        # nothing to compare against.
        if sys.version_info.major != 3 or sys.version_info.minor < 8:
            return
        import typing
        for field in dataclasses.fields(TypeClass):
            expected = typing.get_origin(field.type)
            self.assertEqual(
                base_model._get_origin_py37(field.type), expected)
class TestGetArgsPy37(parameterized.TestCase):
    """Checks the 3.7 backport against the stdlib implementation."""

    def test_is_same_as_typing_version(self):
        import sys

        # typing.get_args only exists on Python 3.8+; on 3.7 there is
        # nothing to compare against.
        if sys.version_info.major != 3 or sys.version_info.minor < 8:
            return
        import typing
        for field in dataclasses.fields(TypeClass):
            expected = typing.get_args(field.type)
            self.assertEqual(
                base_model._get_args_py37(field.type), expected)
# Allow running this test module directly.
if __name__ == '__main__':
    unittest.main()

View File

@ -18,12 +18,11 @@ import abc
import json
from typing import Any, Dict, List, Optional, Union
from kfp.components import base_model
from kfp.components import utils
from kfp.components.types import type_utils
class Placeholder(abc.ABC, base_model.BaseModel):
class Placeholder(abc.ABC):
@abc.abstractmethod
def _to_string(self) -> str:

View File

@ -15,6 +15,7 @@
import ast
import collections
import dataclasses
import functools
import itertools
import re
@ -23,7 +24,6 @@ import uuid
from google.protobuf import json_format
import kfp
from kfp.components import base_model
from kfp.components import placeholders
from kfp.components import utils
from kfp.components import v1_components
@ -37,7 +37,8 @@ from kfp.pipeline_spec import pipeline_spec_pb2
import yaml
class InputSpec_(base_model.BaseModel):
@dataclasses.dataclass
class InputSpec_:
"""Component input definitions. (Inner class).
Attributes:
@ -50,7 +51,7 @@ class InputSpec_(base_model.BaseModel):
# Hack to allow access to __init__ arguments for setting _optional value
class InputSpec(InputSpec_, base_model.BaseModel):
class InputSpec(InputSpec_):
"""Component input definitions.
Attributes:
@ -68,6 +69,9 @@ class InputSpec(InputSpec_, base_model.BaseModel):
super().__init__(*args, **kwargs)
self._optional = 'default' in kwargs
def __post_init__(self) -> None:
self._validate_type()
@classmethod
def from_ir_component_inputs_dict(
cls, ir_component_inputs_dict: Dict[str, Any]) -> 'InputSpec':
@ -132,7 +136,8 @@ class InputSpec(InputSpec_, base_model.BaseModel):
type_utils.validate_bundled_artifact_type(self.type)
class OutputSpec(base_model.BaseModel):
@dataclasses.dataclass
class OutputSpec:
"""Component output definitions.
Attributes:
@ -140,6 +145,9 @@ class OutputSpec(base_model.BaseModel):
"""
type: Union[str, dict]
def __post_init__(self) -> None:
self._validate_type()
@classmethod
def from_ir_component_outputs_dict(
cls, ir_component_outputs_dict: Dict[str, Any]) -> 'OutputSpec':
@ -206,7 +214,8 @@ def spec_type_is_parameter(type_: str) -> bool:
return in_memory_type in type_utils.IN_MEMORY_SPEC_TYPE_TO_IR_TYPE or in_memory_type == 'PipelineTaskFinalStatus'
class ResourceSpec(base_model.BaseModel):
@dataclasses.dataclass
class ResourceSpec:
"""The resource requirements of a container execution.
Attributes:
@ -222,7 +231,8 @@ class ResourceSpec(base_model.BaseModel):
accelerator_count: Optional[int] = None
class ContainerSpec(base_model.BaseModel):
@dataclasses.dataclass
class ContainerSpec:
"""Container definition.
This is only used for pipeline authors when constructing a containerized component
@ -259,7 +269,8 @@ class ContainerSpec(base_model.BaseModel):
"""Arguments to the container entrypoint."""
class ContainerSpecImplementation(base_model.BaseModel):
@dataclasses.dataclass
class ContainerSpecImplementation:
"""Container implementation definition."""
image: str
"""Container image."""
@ -276,6 +287,11 @@ class ContainerSpecImplementation(base_model.BaseModel):
resources: Optional[ResourceSpec] = None
"""Specification on the resource requirements."""
def __post_init__(self) -> None:
self._transform_command()
self._transform_args()
self._transform_env()
def _transform_command(self) -> None:
"""Use None instead of empty list for command."""
self.command = None if self.command == [] else self.command
@ -322,7 +338,8 @@ class ContainerSpecImplementation(base_model.BaseModel):
resources=None) # can only be set on tasks
class RetryPolicy(base_model.BaseModel):
@dataclasses.dataclass
class RetryPolicy:
"""The retry policy of a container execution.
Attributes:
@ -355,16 +372,9 @@ class RetryPolicy(base_model.BaseModel):
'backoff_max_duration': backoff_max_duration_seconds,
}, pipeline_spec_pb2.PipelineTaskSpec.RetryPolicy())
@classmethod
def from_proto(
cls, retry_policy_proto: pipeline_spec_pb2.PipelineTaskSpec.RetryPolicy
) -> 'RetryPolicy':
return cls.from_dict(
json_format.MessageToDict(
retry_policy_proto, preserving_proto_field_name=True))
class TaskSpec(base_model.BaseModel):
@dataclasses.dataclass
class TaskSpec:
"""The spec of a pipeline task.
Attributes:
@ -399,7 +409,8 @@ class TaskSpec(base_model.BaseModel):
retry_policy: Optional[RetryPolicy] = None
class ImporterSpec(base_model.BaseModel):
@dataclasses.dataclass
class ImporterSpec:
"""ImporterSpec definition.
Attributes:
@ -417,7 +428,8 @@ class ImporterSpec(base_model.BaseModel):
metadata: Optional[Mapping[str, Any]] = None
class Implementation(base_model.BaseModel):
@dataclasses.dataclass
class Implementation:
"""Implementation definition.
Attributes:
@ -506,7 +518,8 @@ def check_placeholder_references_valid_io_name(
raise TypeError(f'Unexpected argument "{arg}" of type {type(arg)}.')
class ComponentSpec(base_model.BaseModel):
@dataclasses.dataclass
class ComponentSpec:
"""The definition of a component.
Attributes:
@ -523,6 +536,12 @@ class ComponentSpec(base_model.BaseModel):
inputs: Optional[Dict[str, InputSpec]] = None
outputs: Optional[Dict[str, OutputSpec]] = None
def __post_init__(self) -> None:
self._transform_name()
self._transform_inputs()
self._transform_outputs()
self._validate_placeholders()
def _transform_name(self) -> None:
"""Converts the name to a valid name."""
self.name = utils.maybe_rename_for_k8s(self.name)

View File

@ -19,13 +19,11 @@ import textwrap
import unittest
from absl.testing import parameterized
from google.protobuf import json_format
from kfp import compiler
from kfp import dsl
from kfp.components import component_factory
from kfp.components import placeholders
from kfp.components import structures
from kfp.pipeline_spec import pipeline_spec_pb2
V1_YAML_IF_PLACEHOLDER = textwrap.dedent("""\
implementation:
@ -939,23 +937,6 @@ class TestRetryPolicy(unittest.TestCase):
# tests cap
self.assertEqual(retry_policy_proto.backoff_max_duration.seconds, 3600)
def test_from_proto(self):
retry_policy_proto = json_format.ParseDict(
{
'max_retry_count': 3,
'backoff_duration': '5s',
'backoff_factor': 1.0,
'backoff_max_duration': '1s'
}, pipeline_spec_pb2.PipelineTaskSpec.RetryPolicy())
retry_policy_struct = structures.RetryPolicy.from_proto(
retry_policy_proto)
print(retry_policy_struct)
self.assertEqual(retry_policy_struct.max_retry_count, 3)
self.assertEqual(retry_policy_struct.backoff_duration, '5s')
self.assertEqual(retry_policy_struct.backoff_factor, 1.0)
self.assertEqual(retry_policy_struct.backoff_max_duration, '1s')
if __name__ == '__main__':
unittest.main()