refactor(sdk): remove dead BaseModel code (#8415)
* remove dead basemodel code * remove unused deserialization logic * remove more dead code * remove basemodel entirely
This commit is contained in:
parent
6e95e8988b
commit
08a8b1df97
|
@ -1,425 +0,0 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
from collections import abc
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
import pprint
|
||||
from typing import (Any, Dict, ForwardRef, Iterable, Iterator, Mapping,
|
||||
MutableMapping, MutableSequence, Optional, OrderedDict,
|
||||
Sequence, Tuple, Type, TypeVar, Union)
|
||||
|
||||
PRIMITIVE_TYPES = {int, str, float, bool}
|
||||
|
||||
# typing.Optional.__origin__ is typing.Union
|
||||
UNION_TYPES = {Union}
|
||||
|
||||
# do not need typing.List, because __origin__ is list
|
||||
ITERABLE_TYPES = {
|
||||
list,
|
||||
abc.Sequence,
|
||||
abc.MutableSequence,
|
||||
Sequence,
|
||||
MutableSequence,
|
||||
Iterable,
|
||||
}
|
||||
|
||||
# do not need typing.Dict, because __origin__ is dict
|
||||
MAPPING_TYPES = {
|
||||
dict, abc.Mapping, abc.MutableMapping, Mapping, MutableMapping, OrderedDict,
|
||||
collections.OrderedDict
|
||||
}
|
||||
|
||||
OTHER_SUPPORTED_TYPES = {type(None), Any}
|
||||
SUPPORTED_TYPES = PRIMITIVE_TYPES | UNION_TYPES | ITERABLE_TYPES | MAPPING_TYPES | OTHER_SUPPORTED_TYPES
|
||||
|
||||
BaseModelType = TypeVar('BaseModelType', bound='BaseModel')
|
||||
|
||||
|
||||
class BaseModel:
|
||||
"""BaseModel for structures. Subclasses are converted to dataclasses at
|
||||
object construction time, with.
|
||||
|
||||
Subclasses are dataclasses with methods to support for converting to
|
||||
and from dict, type enforcement, and user-defined validation logic.
|
||||
"""
|
||||
|
||||
_aliases: Dict[str, str] = {} # will be used by subclasses
|
||||
|
||||
# this quiets the mypy "Unexpected keyword argument..." errors on subclass construction
|
||||
# TODO: find a way to propogate type info to subclasses
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
"""Hook called at subclass definition time and at instance construction
|
||||
time.
|
||||
|
||||
Validates that the field to type mapping provided at subclass
|
||||
definition time are supported by BaseModel.
|
||||
"""
|
||||
cls = dataclasses.dataclass(cls)
|
||||
# print(inspect.signature(cls.__init__))
|
||||
for field in dataclasses.fields(cls):
|
||||
cls._recursively_validate_type_is_supported(field.type)
|
||||
|
||||
def to_dict(self, by_alias: bool = False) -> Dict[str, Any]:
|
||||
"""Recursively converts to a dictionary.
|
||||
|
||||
Args:
|
||||
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when converting to dictionary. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary representation of the object.
|
||||
"""
|
||||
return convert_object_to_dict(self, by_alias=by_alias)
|
||||
|
||||
def to_json(self, by_alias: bool = False) -> str:
|
||||
"""Recursively converts to a JSON string.
|
||||
|
||||
Args:
|
||||
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when converting to JSON. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: JSON representation of the object.
|
||||
"""
|
||||
return json.dumps(self.to_dict(by_alias=by_alias))
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls,
|
||||
data: Dict[str, Any],
|
||||
by_alias: bool = False) -> BaseModelType:
|
||||
"""Recursively loads object from a dictionary.
|
||||
|
||||
Args:
|
||||
data (Dict[str, Any]): Dictionary representation of the object.
|
||||
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when reading in dictionary. Defaults to False.
|
||||
|
||||
Returns:
|
||||
BaseModelType: Subclass of BaseModel.
|
||||
"""
|
||||
return _load_basemodel_helper(cls, data, by_alias=by_alias)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, text: str, by_alias: bool = False) -> 'BaseModel':
|
||||
"""Recursively loads object from a JSON string.
|
||||
|
||||
Args:
|
||||
text (str): JSON representation of the object.
|
||||
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when reading in JSON. Defaults to False.
|
||||
|
||||
Returns:
|
||||
BaseModelType: Subclass of BaseModel.
|
||||
"""
|
||||
return _load_basemodel_helper(cls, json.loads(text), by_alias=by_alias)
|
||||
|
||||
@property
|
||||
def types(self) -> Dict[str, type]:
|
||||
"""Dictionary mapping field names to types."""
|
||||
return {field.name: field.type for field in dataclasses.fields(self)}
|
||||
|
||||
@property
|
||||
def fields(self) -> Tuple[dataclasses.Field, ...]:
|
||||
"""The dataclass fields."""
|
||||
return dataclasses.fields(self)
|
||||
|
||||
@classmethod
|
||||
def _recursively_validate_type_is_supported(cls, type_: type) -> None:
|
||||
"""Walks the type definition (generics and subtypes) and checks if it
|
||||
is supported by downstream BaseModel operations.
|
||||
|
||||
Args:
|
||||
type_ (type): Type to check.
|
||||
|
||||
Raises:
|
||||
TypeError: If type is unsupported.
|
||||
"""
|
||||
if isinstance(type_, ForwardRef):
|
||||
return
|
||||
|
||||
if type_ in SUPPORTED_TYPES or _is_basemodel(type_):
|
||||
return
|
||||
|
||||
if _get_origin_py37(type_) not in SUPPORTED_TYPES:
|
||||
raise TypeError(
|
||||
f'Type {type_} is not a supported type fields on child class of {BaseModel.__name__}: {cls.__name__}.'
|
||||
)
|
||||
|
||||
args = _get_args_py37(type_) or [Any, Any]
|
||||
for arg in args:
|
||||
cls._recursively_validate_type_is_supported(arg)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Hook called after object is instantiated from BaseModel.
|
||||
|
||||
Transforms data and validates data using user-defined logic by
|
||||
calling all methods prefixed with `_transform_`, then all
|
||||
methods prefixed with `_validate_`.
|
||||
"""
|
||||
validate_methods = [
|
||||
method for method in dir(self)
|
||||
if method.startswith('_transform_') and
|
||||
callable(getattr(self, method))
|
||||
]
|
||||
for method in validate_methods:
|
||||
getattr(self, method)()
|
||||
|
||||
validate_methods = [
|
||||
method for method in dir(self) if
|
||||
method.startswith('_validate_') and callable(getattr(self, method))
|
||||
]
|
||||
for method in validate_methods:
|
||||
getattr(self, method)()
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Returns a readable representation of the BaseModel subclass."""
|
||||
return base_model_format(self)
|
||||
|
||||
|
||||
def base_model_format(x: BaseModelType) -> str:
|
||||
"""Formats a BaseModel object for improved readability.
|
||||
|
||||
Args:
|
||||
x (BaseModelType): The subclass of BaseModel to format.
|
||||
chars (int, optional): Indentation size. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
str: Readable string representation of the object.
|
||||
"""
|
||||
|
||||
CHARS = 0
|
||||
|
||||
def first_level_indent(string: str, chars: int = 1) -> str:
|
||||
return '\n'.join(' ' * chars + p for p in string.split('\n'))
|
||||
|
||||
def body_level_indent(string: str, chars=4) -> str:
|
||||
a, *b = string.split('\n')
|
||||
return a + '\n' + first_level_indent(
|
||||
'\n'.join(b),
|
||||
chars=chars,
|
||||
) if b else a
|
||||
|
||||
def parts() -> Iterator[str]:
|
||||
if dataclasses.is_dataclass(x):
|
||||
yield type(x).__name__ + '('
|
||||
|
||||
def fields() -> Iterator[str]:
|
||||
for field in dataclasses.fields(x):
|
||||
nindent = CHARS + len(field.name) + 4
|
||||
value = getattr(x, field.name)
|
||||
rep_value = base_model_format(value)
|
||||
yield (' ' * (CHARS + 3) + body_level_indent(
|
||||
f'{field.name}={rep_value}', chars=nindent))
|
||||
|
||||
yield ',\n'.join(fields())
|
||||
yield ' ' * CHARS + ')'
|
||||
else:
|
||||
yield pprint.pformat(x)
|
||||
|
||||
return '\n'.join(parts())
|
||||
|
||||
|
||||
def convert_object_to_dict(obj: BaseModelType,
|
||||
by_alias: bool) -> Dict[str, Any]:
|
||||
"""Recursion helper function for converting a BaseModel and data structures
|
||||
therein to a dictionary. Converts all fields that do not start with an
|
||||
underscore.
|
||||
|
||||
Args:
|
||||
obj (BaseModelType): The object to convert to a dictionary. Initially called with subclass of BaseModel.
|
||||
by_alias (bool): Whether to use the attribute name to alias field mapping provided by cls._aliases when converting to dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If a field is missing a required value. In pracice, this should never be raised, but is included to help with debugging.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The dictionary representation of the object.
|
||||
"""
|
||||
signature = inspect.signature(obj.__init__)
|
||||
|
||||
result = {}
|
||||
for attr_name in signature.parameters:
|
||||
if attr_name.startswith('_'):
|
||||
continue
|
||||
|
||||
field_name = attr_name
|
||||
value = getattr(obj, attr_name)
|
||||
param = signature.parameters.get(attr_name, None)
|
||||
|
||||
if by_alias and hasattr(obj, '_aliases'):
|
||||
field_name = obj._aliases.get(attr_name, attr_name)
|
||||
|
||||
if hasattr(value, 'to_dict'):
|
||||
result[field_name] = value.to_dict(by_alias=by_alias)
|
||||
|
||||
elif isinstance(value, list):
|
||||
result[field_name] = [
|
||||
(x.to_dict(by_alias=by_alias) if hasattr(x, 'to_dict') else x)
|
||||
for x in value
|
||||
]
|
||||
elif isinstance(value, dict):
|
||||
result[field_name] = {
|
||||
k:
|
||||
(v.to_dict(by_alias=by_alias) if hasattr(v, 'to_dict') else v)
|
||||
for k, v in value.items()
|
||||
}
|
||||
elif (param is not None):
|
||||
result[
|
||||
field_name] = value if value != param.default else param.default
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Cannot serialize {obj}. No value for {attr_name}.')
|
||||
return result
|
||||
|
||||
|
||||
def _is_basemodel(obj: Any) -> bool:
|
||||
"""Checks if object is a subclass of BaseModel.
|
||||
|
||||
Args:
|
||||
obj (Any): Any object
|
||||
|
||||
Returns:
|
||||
bool: Is a subclass of BaseModel.
|
||||
"""
|
||||
return inspect.isclass(obj) and issubclass(obj, BaseModel)
|
||||
|
||||
|
||||
def _get_origin_py37(type_: Type) -> Optional[Type]:
|
||||
"""typing.get_origin is introduced in Python 3.8, but we need a get_origin
|
||||
that is compatible with 3.7.
|
||||
|
||||
Args:
|
||||
type_ (Type): A type.
|
||||
|
||||
Returns:
|
||||
Type: The origin of `type_`.
|
||||
"""
|
||||
# uses typing for types
|
||||
return type_.__origin__ if hasattr(type_, '__origin__') else None
|
||||
|
||||
|
||||
def _get_args_py37(type_: Type) -> Tuple[Type]:
|
||||
"""typing.get_args is introduced in Python 3.8, but we need a get_args that
|
||||
is compatible with 3.7.
|
||||
|
||||
Args:
|
||||
type_ (Type): A type.
|
||||
|
||||
Returns:
|
||||
Tuple[Type]: The type arguments of `type_`.
|
||||
"""
|
||||
# uses typing for types
|
||||
return type_.__args__ if hasattr(type_, '__args__') else tuple()
|
||||
|
||||
|
||||
def _load_basemodel_helper(type_: Any, data: Any, by_alias: bool) -> Any:
|
||||
"""Helper function for recursively loading a BaseModel.
|
||||
|
||||
Args:
|
||||
type_ (Any): The type of the object to load. Typically an instance of `type`, `BaseModel` or `Any`.
|
||||
data (Any): The data to load.
|
||||
|
||||
Returns:
|
||||
Any: The loaded object.
|
||||
"""
|
||||
if isinstance(type_, str):
|
||||
raise TypeError(
|
||||
'Please do not use built-in collection types as generics (e.g., list[int]) and do not include the import line `from __future__ import annotations`. Please use the corresponding generic from typing (e.g., List[int]).'
|
||||
)
|
||||
|
||||
# catch unsupported types early
|
||||
type_or_generic = _get_origin_py37(type_) or type_
|
||||
if type_or_generic not in SUPPORTED_TYPES and not _is_basemodel(type_):
|
||||
raise TypeError(
|
||||
f'Unsupported type: {type_}. Cannot load data into object.')
|
||||
|
||||
# if don't have any helpful type information, return data as is
|
||||
if type_ is Any:
|
||||
return data
|
||||
|
||||
# if type is NoneType and data is None, return data/None
|
||||
if type_ is type(None):
|
||||
if data is None:
|
||||
return data
|
||||
else:
|
||||
raise TypeError(
|
||||
f'Expected value None for type NoneType. Got: {data}')
|
||||
|
||||
# handle primitives, with typecasting
|
||||
if type_ in PRIMITIVE_TYPES:
|
||||
return type_(data)
|
||||
|
||||
# simple types are handled, now handle for container types
|
||||
origin = _get_origin_py37(type_)
|
||||
args = _get_args_py37(type_) or [
|
||||
Any, Any
|
||||
] # if there is an inner type in the generic, use it, else use Any
|
||||
# recursively load iterable objects
|
||||
if origin in ITERABLE_TYPES:
|
||||
for arg in args: # TODO handle errors
|
||||
return [
|
||||
_load_basemodel_helper(arg, element, by_alias=by_alias)
|
||||
for element in data
|
||||
]
|
||||
|
||||
# recursively load mapping objects
|
||||
if origin in MAPPING_TYPES:
|
||||
if len(args) != 2:
|
||||
raise TypeError(
|
||||
f'Expected exactly 2 type arguments for mapping type {type_}.')
|
||||
return {
|
||||
_load_basemodel_helper(args[0], k, by_alias=by_alias):
|
||||
_load_basemodel_helper(
|
||||
args[1], # type: ignore
|
||||
# length check a few lines up ensures index 1 exists
|
||||
v,
|
||||
by_alias=by_alias) for k, v in data.items()
|
||||
}
|
||||
|
||||
# if the type is a Union, try to load the data into each of the types,
|
||||
# greedily accepting the first annotation arg that works --> the developer
|
||||
# can indicate which types are preferred based on the annotation arg order
|
||||
if origin in UNION_TYPES:
|
||||
# don't try to cast none if the union type is optional
|
||||
if type(None) in args and data is None:
|
||||
return None
|
||||
for arg in args:
|
||||
return _load_basemodel_helper(args[0], data, by_alias=by_alias)
|
||||
|
||||
# finally, handle the cases where the type is an instance of baseclass
|
||||
if _is_basemodel(type_):
|
||||
fields = dataclasses.fields(type_)
|
||||
basemodel_kwargs = {}
|
||||
for field in fields:
|
||||
attr_name = field.name
|
||||
data_field_name = attr_name
|
||||
if by_alias and hasattr(type_, '_aliases'):
|
||||
data_field_name = type_._aliases.get(attr_name, attr_name)
|
||||
value = data.get(data_field_name)
|
||||
if value is None:
|
||||
if field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING:
|
||||
raise ValueError(
|
||||
f'Missing required field: {data_field_name}')
|
||||
value = field.default if field.default is not dataclasses.MISSING else field.default_factory(
|
||||
)
|
||||
else:
|
||||
value = _load_basemodel_helper(
|
||||
field.type, value, by_alias=by_alias)
|
||||
basemodel_kwargs[attr_name] = value
|
||||
return type_(**basemodel_kwargs)
|
||||
|
||||
raise TypeError(
|
||||
f'Unknown error when loading data: {data} into type {type_}')
|
|
@ -1,461 +0,0 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import (Any, Dict, List, Mapping, MutableMapping, MutableSequence,
|
||||
Optional, OrderedDict, Sequence, Set, Tuple, Union)
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from kfp.components import base_model
|
||||
|
||||
|
||||
class TypeClass(base_model.BaseModel):
|
||||
a: str
|
||||
b: List[int]
|
||||
c: Dict[str, int]
|
||||
d: Union[int, str]
|
||||
e: Union[int, str, bool]
|
||||
f: Optional[int]
|
||||
|
||||
|
||||
class TestBaseModel(unittest.TestCase):
|
||||
|
||||
def test_is_dataclass(self):
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
x: int
|
||||
|
||||
child = Child(x=1)
|
||||
self.assertTrue(dataclasses.is_dataclass(child))
|
||||
|
||||
def test_to_dict_simple(self):
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
i: int
|
||||
s: str
|
||||
f: float
|
||||
l: List[int]
|
||||
|
||||
data = {'i': 1, 's': 's', 'f': 1.0, 'l': [1, 2]}
|
||||
child = Child(**data)
|
||||
actual = child.to_dict()
|
||||
self.assertEqual(actual, data)
|
||||
self.assertEqual(child, Child.from_dict(actual))
|
||||
|
||||
def test_to_dict_nested(self):
|
||||
|
||||
class InnerChild(base_model.BaseModel):
|
||||
a: str
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
b: int
|
||||
c: InnerChild
|
||||
|
||||
data = {'b': 1, 'c': InnerChild(a='a')}
|
||||
child = Child(**data)
|
||||
actual = child.to_dict()
|
||||
expected = {'b': 1, 'c': {'a': 'a'}}
|
||||
self.assertEqual(actual, expected)
|
||||
self.assertEqual(child, Child.from_dict(actual))
|
||||
|
||||
def test_from_dict_no_defaults(self):
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
i: int
|
||||
s: str
|
||||
f: float
|
||||
l: List[int]
|
||||
|
||||
data = {'i': 1, 's': 's', 'f': 1.0, 'l': [1, 2]}
|
||||
child = Child.from_dict(data)
|
||||
self.assertEqual(child.i, 1)
|
||||
self.assertEqual(child.s, 's')
|
||||
self.assertEqual(child.f, 1.0)
|
||||
self.assertEqual(child.l, [1, 2])
|
||||
self.assertEqual(child.to_dict(), data)
|
||||
|
||||
def test_from_dict_with_defaults(self):
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
s: str
|
||||
f: float
|
||||
l: List[int]
|
||||
i: int = 1
|
||||
|
||||
data = {'s': 's', 'f': 1.0, 'l': [1, 2]}
|
||||
child = Child.from_dict(data)
|
||||
self.assertEqual(child.i, 1)
|
||||
self.assertEqual(child.s, 's')
|
||||
self.assertEqual(child.f, 1.0)
|
||||
self.assertEqual(child.l, [1, 2])
|
||||
self.assertEqual(child.to_dict(), {**data, **{'i': 1}})
|
||||
|
||||
def test_from_dict_nested(self):
|
||||
|
||||
class InnerChild(base_model.BaseModel):
|
||||
a: str
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
b: int
|
||||
c: InnerChild
|
||||
|
||||
data = {'b': 1, 'c': {'a': 'a'}}
|
||||
|
||||
child = Child.from_dict(data)
|
||||
self.assertEqual(child.b, 1)
|
||||
self.assertIsInstance(child.c, InnerChild)
|
||||
self.assertEqual(child.c.a, 'a')
|
||||
self.assertEqual(child.to_dict(), data)
|
||||
|
||||
def test_from_dict_array_nested(self):
|
||||
|
||||
class InnerChild(base_model.BaseModel):
|
||||
a: str
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
b: int
|
||||
c: List[InnerChild]
|
||||
d: Dict[str, InnerChild]
|
||||
|
||||
inner_child_data = {'a': 'a'}
|
||||
data = {
|
||||
'b': 1,
|
||||
'c': [inner_child_data, inner_child_data],
|
||||
'd': {
|
||||
'e': inner_child_data
|
||||
}
|
||||
}
|
||||
|
||||
child = Child.from_dict(data)
|
||||
self.assertEqual(child.b, 1)
|
||||
self.assertIsInstance(child.c[0], InnerChild)
|
||||
self.assertIsInstance(child.c[1], InnerChild)
|
||||
self.assertIsInstance(child.d['e'], InnerChild)
|
||||
self.assertEqual(child.c[0].a, 'a')
|
||||
self.assertEqual(child.to_dict(), data)
|
||||
|
||||
def test_from_dict_by_alias(self):
|
||||
|
||||
class InnerChild(base_model.BaseModel):
|
||||
inner_child_field: int
|
||||
_aliases = {'inner_child_field': 'inner_child_field_alias'}
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
sub_field: InnerChild
|
||||
_aliases = {'sub_field': 'sub_field_alias'}
|
||||
|
||||
data = {'sub_field_alias': {'inner_child_field_alias': 2}}
|
||||
|
||||
child = Child.from_dict(data, by_alias=True)
|
||||
self.assertIsInstance(child.sub_field, InnerChild)
|
||||
self.assertEqual(child.sub_field.inner_child_field, 2)
|
||||
self.assertEqual(child.to_dict(by_alias=True), data)
|
||||
|
||||
def test_to_dict_by_alias(self):
|
||||
|
||||
class InnerChild(base_model.BaseModel):
|
||||
inner_child_field: int
|
||||
_aliases = {'inner_child_field': 'inner_child_field_alias'}
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
sub_field: InnerChild
|
||||
_aliases = {'sub_field': 'sub_field_alias'}
|
||||
|
||||
inner_child = InnerChild(inner_child_field=2)
|
||||
child = Child(sub_field=inner_child)
|
||||
actual = child.to_dict(by_alias=True)
|
||||
expected = {'sub_field_alias': {'inner_child_field_alias': 2}}
|
||||
self.assertEqual(actual, expected)
|
||||
self.assertEqual(Child.from_dict(actual, by_alias=True), child)
|
||||
|
||||
def test_to_dict_by_alias2(self):
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
x: int
|
||||
y: List[int]
|
||||
z: Dict[str, int]
|
||||
_aliases = {'x': 'a', 'z': 'b'}
|
||||
|
||||
res = MyClass(x=1, y=[2], z={'key': 3}).to_dict(by_alias=True)
|
||||
self.assertEqual(res, {'a': 1, 'y': [2], 'b': {'key': 3}})
|
||||
|
||||
def test_to_dict_by_alias_nested(self):
|
||||
|
||||
class InnerClass(base_model.BaseModel):
|
||||
f: float
|
||||
_aliases = {'f': 'a'}
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
x: int
|
||||
y: List[int]
|
||||
z: InnerClass
|
||||
_aliases = {'x': 'a', 'z': 'b'}
|
||||
|
||||
res = MyClass(x=1, y=[2], z=InnerClass(f=1.0)).to_dict(by_alias=True)
|
||||
self.assertEqual(res, {'a': 1, 'y': [2], 'b': {'a': 1.0}})
|
||||
|
||||
def test_can_create_properties_using_attributes(self):
|
||||
|
||||
class Child(base_model.BaseModel):
|
||||
x: Optional[int]
|
||||
|
||||
@property
|
||||
def prop(self) -> bool:
|
||||
return self.x is not None
|
||||
|
||||
child1 = Child(x=None)
|
||||
self.assertEqual(child1.prop, False)
|
||||
|
||||
child2 = Child(x=1)
|
||||
self.assertEqual(child2.prop, True)
|
||||
|
||||
def test_unsupported_type_success(self):
|
||||
|
||||
class OtherClass(base_model.BaseModel):
|
||||
x: int
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
a: OtherClass
|
||||
|
||||
def test_unsupported_type_failures(self):
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r'not a supported'):
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
a: tuple
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r'not a supported'):
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
a: Tuple
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r'not a supported'):
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
a: Set
|
||||
|
||||
with self.assertRaisesRegex(TypeError, r'not a supported'):
|
||||
|
||||
class OtherClass:
|
||||
pass
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
a: OtherClass
|
||||
|
||||
def test_base_model_validation(self):
|
||||
|
||||
# test exception thrown
|
||||
class MyClass(base_model.BaseModel):
|
||||
x: int
|
||||
|
||||
def _validate_x(self) -> None:
|
||||
if self.x < 2:
|
||||
raise ValueError('x must be greater than 2')
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'x must be greater than 2'):
|
||||
mc = MyClass(x=1)
|
||||
|
||||
# test value modified same type
|
||||
class MyClass(base_model.BaseModel):
|
||||
x: int
|
||||
|
||||
def _validate_x(self) -> None:
|
||||
self.x = max(self.x, 2)
|
||||
|
||||
mc = MyClass(x=1)
|
||||
self.assertEqual(mc.x, 2)
|
||||
|
||||
# test value modified new type
|
||||
class MyClass(base_model.BaseModel):
|
||||
x: Optional[List[int]] = None
|
||||
|
||||
def _validate_x(self) -> None:
|
||||
if isinstance(self.x, list) and not self.x:
|
||||
self.x = None
|
||||
|
||||
mc = MyClass(x=[])
|
||||
self.assertEqual(mc.x, None)
|
||||
|
||||
def test_can_set_field(self):
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
x: int
|
||||
|
||||
mc = MyClass(x=2)
|
||||
mc.x = 1
|
||||
self.assertEqual(mc.x, 1)
|
||||
|
||||
def test_can_use_default_factory(self):
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
x: List[int] = dataclasses.field(default_factory=list)
|
||||
|
||||
mc = MyClass()
|
||||
self.assertEqual(mc.x, [])
|
||||
|
||||
|
||||
class TestIsBaseModel(unittest.TestCase):
|
||||
|
||||
def test_true(self):
|
||||
self.assertEqual(base_model._is_basemodel(base_model.BaseModel), True)
|
||||
|
||||
class MyClass(base_model.BaseModel):
|
||||
pass
|
||||
|
||||
self.assertEqual(base_model._is_basemodel(MyClass), True)
|
||||
|
||||
def test_false(self):
|
||||
self.assertEqual(base_model._is_basemodel(int), False)
|
||||
self.assertEqual(base_model._is_basemodel(1), False)
|
||||
self.assertEqual(base_model._is_basemodel(str), False)
|
||||
|
||||
|
||||
class TestLoadBaseModelHelper(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.no_alias_load_base_model_helper = functools.partial(
|
||||
base_model._load_basemodel_helper, by_alias=False)
|
||||
|
||||
def test_load_primitive(self):
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(str, 'a'), 'a')
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(int, 1), 1)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(float, 1.0), 1.0)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(bool, True), True)
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(type(None), None), None)
|
||||
|
||||
def test_load_primitive_with_casting(self):
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(int, '1'), 1)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(str, 1), '1')
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(float, 1), 1.0)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(int, 1.0), 1)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(bool, 1), True)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(bool, 0), False)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(int, True), 1)
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(int, False), 0)
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(bool, None), False)
|
||||
|
||||
def test_load_none(self):
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(type(None), None), None)
|
||||
with self.assertRaisesRegex(TypeError, ''):
|
||||
self.no_alias_load_base_model_helper(type(None), 1)
|
||||
|
||||
@parameterized.parameters(['a', 1, 1.0, True, False, None, ['list']])
|
||||
def test_load_any(self, data: Any):
|
||||
self.assertEqual(self.no_alias_load_base_model_helper(Any, data),
|
||||
data) # type: ignore
|
||||
|
||||
def test_load_list(self):
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(List[str], ['a']), ['a'])
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(List[int], [1, 1]), [1, 1])
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(List[float], [1.0]), [1.0])
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(List[bool], [True]), [True])
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(List[type(None)], [None]),
|
||||
[None])
|
||||
|
||||
def test_load_primitive_other_iterables(self):
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Sequence[bool], [True]),
|
||||
[True])
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(MutableSequence[type(None)],
|
||||
[None]), [None])
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Sequence[str], ['a']), ['a'])
|
||||
|
||||
def test_load_dict(self):
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Dict[str, str], {'a': 'a'}),
|
||||
{'a': 'a'})
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Dict[str, int], {'a': 1}),
|
||||
{'a': 1})
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Dict[str, float], {'a': 1.0}),
|
||||
{'a': 1.0})
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Dict[str, bool], {'a': True}),
|
||||
{'a': True})
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Dict[str, type(None)],
|
||||
{'a': None}), {'a': None})
|
||||
|
||||
def test_load_mapping(self):
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Mapping[str, float],
|
||||
{'a': 1.0}), {'a': 1.0})
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(MutableMapping[str, bool],
|
||||
{'a': True}), {'a': True})
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(OrderedDict[str,
|
||||
type(None)],
|
||||
{'a': None}), {'a': None})
|
||||
|
||||
def test_load_union_types(self):
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Union[str, int], 'a'), 'a')
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Union[str, int], 1), '1')
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Union[int, str], 1), 1)
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Union[int, str], '1'), 1)
|
||||
|
||||
def test_load_optional_types(self):
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Optional[str], 'a'), 'a')
|
||||
self.assertEqual(
|
||||
self.no_alias_load_base_model_helper(Optional[str], None), None)
|
||||
|
||||
def test_unsupported_type(self):
|
||||
with self.assertRaisesRegex(TypeError, r'Unsupported type:'):
|
||||
self.no_alias_load_base_model_helper(Set[int], {1})
|
||||
|
||||
|
||||
class TestGetOriginPy37(parameterized.TestCase):
|
||||
|
||||
def test_is_same_as_typing_version(self):
|
||||
import sys
|
||||
if sys.version_info.major == 3 and sys.version_info.minor >= 8:
|
||||
import typing
|
||||
for field in dataclasses.fields(TypeClass):
|
||||
self.assertEqual(
|
||||
base_model._get_origin_py37(field.type),
|
||||
typing.get_origin(field.type))
|
||||
|
||||
|
||||
class TestGetArgsPy37(parameterized.TestCase):
|
||||
|
||||
def test_is_same_as_typing_version(self):
|
||||
import sys
|
||||
if sys.version_info.major == 3 and sys.version_info.minor >= 8:
|
||||
import typing
|
||||
for field in dataclasses.fields(TypeClass):
|
||||
self.assertEqual(
|
||||
base_model._get_args_py37(field.type),
|
||||
typing.get_args(field.type))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -18,12 +18,11 @@ import abc
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from kfp.components import base_model
|
||||
from kfp.components import utils
|
||||
from kfp.components.types import type_utils
|
||||
|
||||
|
||||
class Placeholder(abc.ABC, base_model.BaseModel):
|
||||
class Placeholder(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def _to_string(self) -> str:
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
import ast
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import re
|
||||
|
@ -23,7 +24,6 @@ import uuid
|
|||
|
||||
from google.protobuf import json_format
|
||||
import kfp
|
||||
from kfp.components import base_model
|
||||
from kfp.components import placeholders
|
||||
from kfp.components import utils
|
||||
from kfp.components import v1_components
|
||||
|
@ -37,7 +37,8 @@ from kfp.pipeline_spec import pipeline_spec_pb2
|
|||
import yaml
|
||||
|
||||
|
||||
class InputSpec_(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class InputSpec_:
|
||||
"""Component input definitions. (Inner class).
|
||||
|
||||
Attributes:
|
||||
|
@ -50,7 +51,7 @@ class InputSpec_(base_model.BaseModel):
|
|||
|
||||
|
||||
# Hack to allow access to __init__ arguments for setting _optional value
|
||||
class InputSpec(InputSpec_, base_model.BaseModel):
|
||||
class InputSpec(InputSpec_):
|
||||
"""Component input definitions.
|
||||
|
||||
Attributes:
|
||||
|
@ -68,6 +69,9 @@ class InputSpec(InputSpec_, base_model.BaseModel):
|
|||
super().__init__(*args, **kwargs)
|
||||
self._optional = 'default' in kwargs
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate_type()
|
||||
|
||||
@classmethod
|
||||
def from_ir_component_inputs_dict(
|
||||
cls, ir_component_inputs_dict: Dict[str, Any]) -> 'InputSpec':
|
||||
|
@ -132,7 +136,8 @@ class InputSpec(InputSpec_, base_model.BaseModel):
|
|||
type_utils.validate_bundled_artifact_type(self.type)
|
||||
|
||||
|
||||
class OutputSpec(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class OutputSpec:
|
||||
"""Component output definitions.
|
||||
|
||||
Attributes:
|
||||
|
@ -140,6 +145,9 @@ class OutputSpec(base_model.BaseModel):
|
|||
"""
|
||||
type: Union[str, dict]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate_type()
|
||||
|
||||
@classmethod
|
||||
def from_ir_component_outputs_dict(
|
||||
cls, ir_component_outputs_dict: Dict[str, Any]) -> 'OutputSpec':
|
||||
|
@ -206,7 +214,8 @@ def spec_type_is_parameter(type_: str) -> bool:
|
|||
return in_memory_type in type_utils.IN_MEMORY_SPEC_TYPE_TO_IR_TYPE or in_memory_type == 'PipelineTaskFinalStatus'
|
||||
|
||||
|
||||
class ResourceSpec(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class ResourceSpec:
|
||||
"""The resource requirements of a container execution.
|
||||
|
||||
Attributes:
|
||||
|
@ -222,7 +231,8 @@ class ResourceSpec(base_model.BaseModel):
|
|||
accelerator_count: Optional[int] = None
|
||||
|
||||
|
||||
class ContainerSpec(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class ContainerSpec:
|
||||
"""Container definition.
|
||||
|
||||
This is only used for pipeline authors when constructing a containerized component
|
||||
|
@ -259,7 +269,8 @@ class ContainerSpec(base_model.BaseModel):
|
|||
"""Arguments to the container entrypoint."""
|
||||
|
||||
|
||||
class ContainerSpecImplementation(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class ContainerSpecImplementation:
|
||||
"""Container implementation definition."""
|
||||
image: str
|
||||
"""Container image."""
|
||||
|
@ -276,6 +287,11 @@ class ContainerSpecImplementation(base_model.BaseModel):
|
|||
resources: Optional[ResourceSpec] = None
|
||||
"""Specification on the resource requirements."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._transform_command()
|
||||
self._transform_args()
|
||||
self._transform_env()
|
||||
|
||||
def _transform_command(self) -> None:
|
||||
"""Use None instead of empty list for command."""
|
||||
self.command = None if self.command == [] else self.command
|
||||
|
@ -322,7 +338,8 @@ class ContainerSpecImplementation(base_model.BaseModel):
|
|||
resources=None) # can only be set on tasks
|
||||
|
||||
|
||||
class RetryPolicy(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class RetryPolicy:
|
||||
"""The retry policy of a container execution.
|
||||
|
||||
Attributes:
|
||||
|
@ -355,16 +372,9 @@ class RetryPolicy(base_model.BaseModel):
|
|||
'backoff_max_duration': backoff_max_duration_seconds,
|
||||
}, pipeline_spec_pb2.PipelineTaskSpec.RetryPolicy())
|
||||
|
||||
@classmethod
|
||||
def from_proto(
|
||||
cls, retry_policy_proto: pipeline_spec_pb2.PipelineTaskSpec.RetryPolicy
|
||||
) -> 'RetryPolicy':
|
||||
return cls.from_dict(
|
||||
json_format.MessageToDict(
|
||||
retry_policy_proto, preserving_proto_field_name=True))
|
||||
|
||||
|
||||
class TaskSpec(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class TaskSpec:
|
||||
"""The spec of a pipeline task.
|
||||
|
||||
Attributes:
|
||||
|
@ -399,7 +409,8 @@ class TaskSpec(base_model.BaseModel):
|
|||
retry_policy: Optional[RetryPolicy] = None
|
||||
|
||||
|
||||
class ImporterSpec(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class ImporterSpec:
|
||||
"""ImporterSpec definition.
|
||||
|
||||
Attributes:
|
||||
|
@ -417,7 +428,8 @@ class ImporterSpec(base_model.BaseModel):
|
|||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class Implementation(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class Implementation:
|
||||
"""Implementation definition.
|
||||
|
||||
Attributes:
|
||||
|
@ -506,7 +518,8 @@ def check_placeholder_references_valid_io_name(
|
|||
raise TypeError(f'Unexpected argument "{arg}" of type {type(arg)}.')
|
||||
|
||||
|
||||
class ComponentSpec(base_model.BaseModel):
|
||||
@dataclasses.dataclass
|
||||
class ComponentSpec:
|
||||
"""The definition of a component.
|
||||
|
||||
Attributes:
|
||||
|
@ -523,6 +536,12 @@ class ComponentSpec(base_model.BaseModel):
|
|||
inputs: Optional[Dict[str, InputSpec]] = None
|
||||
outputs: Optional[Dict[str, OutputSpec]] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._transform_name()
|
||||
self._transform_inputs()
|
||||
self._transform_outputs()
|
||||
self._validate_placeholders()
|
||||
|
||||
def _transform_name(self) -> None:
|
||||
"""Converts the name to a valid name."""
|
||||
self.name = utils.maybe_rename_for_k8s(self.name)
|
||||
|
|
|
@ -19,13 +19,11 @@ import textwrap
|
|||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from google.protobuf import json_format
|
||||
from kfp import compiler
|
||||
from kfp import dsl
|
||||
from kfp.components import component_factory
|
||||
from kfp.components import placeholders
|
||||
from kfp.components import structures
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
V1_YAML_IF_PLACEHOLDER = textwrap.dedent("""\
|
||||
implementation:
|
||||
|
@ -939,23 +937,6 @@ class TestRetryPolicy(unittest.TestCase):
|
|||
# tests cap
|
||||
self.assertEqual(retry_policy_proto.backoff_max_duration.seconds, 3600)
|
||||
|
||||
def test_from_proto(self):
|
||||
retry_policy_proto = json_format.ParseDict(
|
||||
{
|
||||
'max_retry_count': 3,
|
||||
'backoff_duration': '5s',
|
||||
'backoff_factor': 1.0,
|
||||
'backoff_max_duration': '1s'
|
||||
}, pipeline_spec_pb2.PipelineTaskSpec.RetryPolicy())
|
||||
retry_policy_struct = structures.RetryPolicy.from_proto(
|
||||
retry_policy_proto)
|
||||
print(retry_policy_struct)
|
||||
|
||||
self.assertEqual(retry_policy_struct.max_retry_count, 3)
|
||||
self.assertEqual(retry_policy_struct.backoff_duration, '5s')
|
||||
self.assertEqual(retry_policy_struct.backoff_factor, 1.0)
|
||||
self.assertEqual(retry_policy_struct.backoff_max_duration, '1s')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue