feat(sdk): use custom basemodel and remove pydantic (#7639)

* fix discovered bug

* update tests

* implement custom BaseModel

* use basemodel for structures

* remove pydantic dependency

* assorted code cleanup
Connor McCarthy 2022-05-04 12:56:32 -06:00 committed by GitHub
parent 913277a3aa
commit 5da3826bb5
10 changed files with 1338 additions and 434 deletions

View File

@ -601,7 +601,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
v2.compiler.Compiler().compile(
pipeline_func=pipeline_hello_world, package_path=temp_filepath)
with open(temp_filepath, "r") as f:
with open(temp_filepath, 'r') as f:
yaml.load(f)
def test_import_modules(self): # pylint: disable=no-self-use
@ -625,7 +625,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
compiler.Compiler().compile(
pipeline_func=pipeline_hello_world, package_path=temp_filepath)
with open(temp_filepath, "r") as f:
with open(temp_filepath, 'r') as f:
yaml.load(f)
def test_import_object(self): # pylint: disable=no-self-use
@ -650,7 +650,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
Compiler().compile(
pipeline_func=pipeline_hello_world, package_path=temp_filepath)
with open(temp_filepath, "r") as f:
with open(temp_filepath, 'r') as f:
yaml.load(f)
@ -670,8 +670,8 @@ class TestWriteToFileTypes(parameterized.TestCase):
return my_pipeline
@parameterized.parameters(
{"extension": ".yaml"},
{"extension": ".yml"},
{'extension': '.yaml'},
{'extension': '.yml'},
)
def test_can_write_to_yaml(self, extension):
@ -701,7 +701,7 @@ class TestWriteToFileTypes(parameterized.TestCase):
target_file = os.path.join(tmpdir, 'result.json')
with self.assertWarnsRegex(DeprecationWarning,
r"Compiling to JSON is deprecated"):
r'Compiling to JSON is deprecated'):
compiler.Compiler().compile(
pipeline_func=pipeline_spec, package_path=target_file)
with open(target_file) as f:
@ -735,7 +735,7 @@ class TestWriteToFileTypes(parameterized.TestCase):
inputs:
- {name: location, type: String, default: 'us-central1'}
- {name: name, type: Integer, default: 1}
- {name: noDefault, type: String}
- {name: nodefault, type: String}
implementation:
container:
image: gcr.io/my-project/my-image:tag
@ -745,7 +745,7 @@ class TestWriteToFileTypes(parameterized.TestCase):
@dsl.pipeline(name='test-pipeline')
def simple_pipeline():
producer = producer_op(location="1")
producer = producer_op(location='1', nodefault='string')
target_json_file = os.path.join(tmpdir, 'result.json')
compiler.Compiler().compile(

View File

@ -40,11 +40,11 @@ class BaseComponent(metaclass=abc.ABCMeta):
# Arguments typed as PipelineTaskFinalStatus are special arguments that
# do not count as user inputs. Instead, they are reserved for the
# (backend) system to pass a value.
self._component_inputs = set([
self._component_inputs = {
input_name for input_name, input_spec in (
self.component_spec.inputs or {}).items()
if not type_utils.is_task_final_status_type(input_spec.type)
])
}
def __call__(self, *args, **kwargs) -> pipeline_task.PipelineTask:
"""Creates a PipelineTask object.
@ -74,7 +74,7 @@ class BaseComponent(metaclass=abc.ABCMeta):
missing_arguments = [
input_name for input_name, input_spec in (
self.component_spec.inputs or {}).items()
if input_name not in task_inputs and not input_spec.optional and
if input_name not in task_inputs and not input_spec._optional and
not type_utils.is_task_final_status_type(input_spec.type)
]
if missing_arguments:

View File

@ -1,4 +1,4 @@
# Copyright 2021 The Kubeflow Authors
# Copyright 2021-2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,8 +17,8 @@ import unittest
from unittest.mock import patch
from kfp.components import base_component
from kfp.components import structures
from kfp.components import pipeline_task
from kfp.components import structures
class TestComponent(base_component.BaseComponent):
@ -30,18 +30,19 @@ class TestComponent(base_component.BaseComponent):
component_op = TestComponent(
component_spec=structures.ComponentSpec(
name='component_1',
implementation=structures.ContainerSpec(
image='alpine',
command=[
'sh',
'-c',
'set -ex\necho "$0" "$1" "$2" > "$3"',
structures.InputValuePlaceholder(input_name='input1'),
structures.InputValuePlaceholder(input_name='input2'),
structures.InputValuePlaceholder(input_name='input3'),
structures.OutputPathPlaceholder(output_name='output1'),
],
),
implementation=structures.Implementation(
container=structures.ContainerSpec(
image='alpine',
command=[
'sh',
'-c',
'set -ex\necho "$0" "$1" "$2" > "$3"',
structures.InputValuePlaceholder(input_name='input1'),
structures.InputValuePlaceholder(input_name='input2'),
structures.InputValuePlaceholder(input_name='input3'),
structures.OutputPathPlaceholder(output_name='output1'),
],
)),
inputs={
'input1':
structures.InputSpec(type='String'),
@ -53,7 +54,7 @@ component_op = TestComponent(
structures.InputSpec(type='Optional[Float]', default=None),
},
outputs={
'output1': structures.OutputSpec(name='output1', type='String'),
'output1': structures.OutputSpec(type='String'),
},
))
@ -92,25 +93,26 @@ class BaseComponentTest(unittest.TestCase):
with self.assertRaisesRegex(
TypeError,
'Components must be instantiated using keyword arguments.'
' Positional parameters are not allowed \(found 3 such'
' parameters for component "component_1"\).'):
r' Positional parameters are not allowed \(found 3 such'
r' parameters for component "component_1"\).'):
component_op('abc', 1, 2.3)
def test_instantiate_component_with_unexpected_keyword_argument(self):
with self.assertRaisesRegex(
TypeError,
'component_1\(\) got an unexpected keyword argument "input0".'):
r'component_1\(\) got an unexpected keyword argument "input0".'
):
component_op(input1='abc', input2=1, input3=2.3, input0='extra')
def test_instantiate_component_with_missing_arguments(self):
with self.assertRaisesRegex(
TypeError,
'component_1\(\) missing 1 required argument: input1.'):
r'component_1\(\) missing 1 required argument: input1.'):
component_op(input2=1)
with self.assertRaisesRegex(
TypeError,
'component_1\(\) missing 2 required arguments: input1, input2.'
r'component_1\(\) missing 2 required arguments: input1, input2.'
):
component_op()

View File

@ -0,0 +1,424 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import dataclasses
import inspect
import json
import pprint
from collections import abc
from typing import (Any, Dict, ForwardRef, Iterable, Iterator, Mapping,
MutableMapping, MutableSequence, Optional, OrderedDict,
Sequence, Tuple, Type, TypeVar, Union)
PRIMITIVE_TYPES = {int, str, float, bool}
# typing.Optional.__origin__ is typing.Union
UNION_TYPES = {Union}
# do not need typing.List, because __origin__ is list
ITERABLE_TYPES = {
list,
abc.Sequence,
abc.MutableSequence,
Sequence,
MutableSequence,
Iterable,
}
# do not need typing.Dict, because __origin__ is dict
MAPPING_TYPES = {
dict, abc.Mapping, abc.MutableMapping, Mapping, MutableMapping, OrderedDict,
collections.OrderedDict
}
OTHER_SUPPORTED_TYPES = {type(None), Any}
SUPPORTED_TYPES = PRIMITIVE_TYPES | UNION_TYPES | ITERABLE_TYPES | MAPPING_TYPES | OTHER_SUPPORTED_TYPES
BaseModelType = TypeVar('BaseModelType', bound='BaseModel')
class BaseModel:
"""BaseModel for structures. Subclasses are converted to dataclasses at
object construction time, with.
Subclasses are dataclasses with methods to support for converting to
and from dict, type enforcement, and user-defined validation logic.
"""
_aliases: Dict[str, str] = {} # will be used by subclasses
# this quiets the mypy "Unexpected keyword argument..." errors on subclass construction
# TODO: find a way to propagate type info to subclasses
def __init__(self, *args, **kwargs):
pass
def __init_subclass__(cls) -> None:
"""Hook called at subclass definition time.
Converts the subclass to a dataclass and validates that the
field-to-type mapping provided at subclass definition time is
supported by BaseModel.
"""
cls = dataclasses.dataclass(cls)
for field in dataclasses.fields(cls):
cls._recursively_validate_type_is_supported(field.type)
def to_dict(self, by_alias: bool = False) -> Dict[str, Any]:
"""Recursively converts to a dictionary.
Args:
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when converting to dictionary. Defaults to False.
Returns:
Dict[str, Any]: Dictionary representation of the object.
"""
return convert_object_to_dict(self, by_alias=by_alias)
def to_json(self, by_alias: bool = False) -> str:
"""Recursively converts to a JSON string.
Args:
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when converting to JSON. Defaults to False.
Returns:
str: JSON representation of the object.
"""
return json.dumps(self.to_dict(by_alias=by_alias))
@classmethod
def from_dict(cls,
data: Dict[str, Any],
by_alias: bool = False) -> BaseModelType:
"""Recursively loads object from a dictionary.
Args:
data (Dict[str, Any]): Dictionary representation of the object.
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when reading in dictionary. Defaults to False.
Returns:
BaseModelType: Subclass of BaseModel.
"""
return _load_basemodel_helper(cls, data, by_alias=by_alias)
@classmethod
def from_json(cls, text: str, by_alias: bool = False) -> 'BaseModel':
"""Recursively loads object from a JSON string.
Args:
text (str): JSON representation of the object.
by_alias (bool, optional): Whether to use attribute name to alias field mapping provided by cls._aliases when reading in JSON. Defaults to False.
Returns:
BaseModelType: Subclass of BaseModel.
"""
return _load_basemodel_helper(cls, json.loads(text), by_alias=by_alias)
@property
def types(self) -> Dict[str, type]:
"""Dictionary mapping field names to types."""
return {field.name: field.type for field in dataclasses.fields(self)}
@property
def fields(self) -> Tuple[dataclasses.Field, ...]:
"""The dataclass fields."""
return dataclasses.fields(self)
@classmethod
def _recursively_validate_type_is_supported(cls, type_: type) -> None:
"""Walks the type definition (generics and subtypes) and checks if it
is supported by downstream BaseModel operations.
Args:
type_ (type): Type to check.
Raises:
TypeError: If type is unsupported.
"""
if isinstance(type_, ForwardRef):
return
if type_ in SUPPORTED_TYPES or _is_basemodel(type_):
return
if _get_origin_py37(type_) not in SUPPORTED_TYPES:
raise TypeError(
f'Type {type_} is not a supported type for fields on a child class of {BaseModel.__name__}: {cls.__name__}.'
)
args = _get_args_py37(type_) or [Any, Any]
for arg in args:
cls._recursively_validate_type_is_supported(arg)
def __post_init__(self) -> None:
"""Hook called after the object is instantiated.
Transforms and validates data using user-defined logic by
calling all methods prefixed with `transform_`, then all methods
prefixed with `validate_`.
"""
transform_methods = [
method for method in dir(self) if
method.startswith('transform_') and callable(getattr(self, method))
]
for method in transform_methods:
getattr(self, method)()
validate_methods = [
method for method in dir(self) if method.startswith('validate_') and
callable(getattr(self, method))
]
for method in validate_methods:
getattr(self, method)()
def __str__(self) -> str:
"""Returns a readable representation of the BaseModel subclass."""
return base_model_format(self)
def base_model_format(x: BaseModelType) -> str:
"""Formats a BaseModel object for improved readability.
Args:
x (BaseModelType): The subclass of BaseModel to format.
Returns:
str: Readable string representation of the object.
"""
CHARS = 0
def first_level_indent(string: str, chars: int = 1) -> str:
return '\n'.join(' ' * chars + p for p in string.split('\n'))
def body_level_indent(string: str, chars=4) -> str:
a, *b = string.split('\n')
return a + '\n' + first_level_indent(
'\n'.join(b),
chars=chars,
) if b else a
def parts() -> Iterator[str]:
if dataclasses.is_dataclass(x):
yield type(x).__name__ + '('
def fields() -> Iterator[str]:
for field in dataclasses.fields(x):
nindent = CHARS + len(field.name) + 4
value = getattr(x, field.name)
rep_value = base_model_format(value)
yield (' ' * (CHARS + 3) + body_level_indent(
f'{field.name}={rep_value}', chars=nindent))
yield ',\n'.join(fields())
yield ' ' * CHARS + ')'
else:
yield pprint.pformat(x)
return '\n'.join(parts())
def convert_object_to_dict(obj: BaseModelType,
by_alias: bool) -> Dict[str, Any]:
"""Recursion helper function for converting a BaseModel and data structures
therein to a dictionary. Converts all fields that do not start with an
underscore.
Args:
obj (BaseModelType): The object to convert to a dictionary. Initially called with subclass of BaseModel.
by_alias (bool): Whether to use the attribute name to alias field mapping provided by cls._aliases when converting to dictionary.
Raises:
ValueError: If a field is missing a required value. In practice, this should never be raised, but is included to help with debugging.
Returns:
Dict[str, Any]: The dictionary representation of the object.
"""
signature = inspect.signature(obj.__init__)
result = {}
for attr_name in signature.parameters:
if attr_name.startswith('_'):
continue
field_name = attr_name
value = getattr(obj, attr_name)
param = signature.parameters.get(attr_name, None)
if by_alias and hasattr(obj, '_aliases'):
field_name = obj._aliases.get(attr_name, attr_name)
if hasattr(value, 'to_dict'):
result[field_name] = value.to_dict(by_alias=by_alias)
elif isinstance(value, list):
result[field_name] = [
(x.to_dict(by_alias=by_alias) if hasattr(x, 'to_dict') else x)
for x in value
]
elif isinstance(value, dict):
result[field_name] = {
k:
(v.to_dict(by_alias=by_alias) if hasattr(v, 'to_dict') else v)
for k, v in value.items()
}
elif param is not None:
result[field_name] = value
else:
raise ValueError(
f'Cannot serialize {obj}. No value for {attr_name}.')
return result
def _is_basemodel(obj: Any) -> bool:
"""Checks if object is a subclass of BaseModel.
Args:
obj (Any): Any object
Returns:
bool: Is a subclass of BaseModel.
"""
return inspect.isclass(obj) and issubclass(obj, BaseModel)
def _get_origin_py37(type_: Type) -> Optional[Type]:
"""typing.get_origin is introduced in Python 3.8, but we need a get_origin
that is compatible with 3.7.
Args:
type_ (Type): A type.
Returns:
Type: The origin of `type_`.
"""
# uses typing for types
return type_.__origin__ if hasattr(type_, '__origin__') else None
def _get_args_py37(type_: Type) -> Tuple[Type]:
"""typing.get_args is introduced in Python 3.8, but we need a get_args that
is compatible with 3.7.
Args:
type_ (Type): A type.
Returns:
Tuple[Type]: The type arguments of `type_`.
"""
# uses typing for types
return type_.__args__ if hasattr(type_, '__args__') else tuple()
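# For example (illustrative values, not part of this commit), on Python 3.7+:
#   _get_origin_py37(List[int])     # ==> list
#   _get_args_py37(Dict[str, int])  # ==> (str, int)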
def _load_basemodel_helper(type_: Any, data: Any, by_alias: bool) -> Any:
"""Helper function for recursively loading a BaseModel.
Args:
type_ (Any): The type of the object to load. Typically an instance of `type`, `BaseModel` or `Any`.
data (Any): The data to load.
by_alias (bool): Whether to use the attribute name to alias field mapping provided by cls._aliases when loading.
Returns:
Any: The loaded object.
"""
if isinstance(type_, str):
raise TypeError(
'Please do not use built-in collection types as generics (e.g., list[int]) and do not include the import line `from __future__ import annotations`. Please use the corresponding generic from typing (e.g., List[int]).'
)
# catch unsupported types early
type_or_generic = _get_origin_py37(type_) or type_
if type_or_generic not in SUPPORTED_TYPES and not _is_basemodel(type_):
raise TypeError(
f'Unsupported type: {type_}. Cannot load data into object.')
# if we don't have any helpful type information, return data as is
if type_ is Any:
return data
# if type is NoneType and data is None, return data/None
if type_ is type(None):
if data is None:
return data
else:
raise TypeError(
f'Expected value None for type NoneType. Got: {data}')
# handle primitives, with typecasting
if type_ in PRIMITIVE_TYPES:
return type_(data)
# simple types are handled, now handle for container types
origin = _get_origin_py37(type_)
args = _get_args_py37(type_) or [
Any, Any
] # if there is an inner type in the generic, use it, else use Any
# recursively load iterable objects
if origin in ITERABLE_TYPES:
# elements of an iterable share a single type annotation
element_type = args[0]
return [
_load_basemodel_helper(element_type, element, by_alias=by_alias)
for element in data
]
# recursively load mapping objects
if origin in MAPPING_TYPES:
if len(args) != 2:
raise TypeError(
f'Expected exactly 2 type arguments for mapping type {type_}.')
return {
_load_basemodel_helper(args[0], k, by_alias=by_alias):
_load_basemodel_helper(
args[1], # type: ignore
# length check a few lines up ensures index 1 exists
v,
by_alias=by_alias) for k, v in data.items()
}
# if the type is a Union, cast the data to the first type annotation arg
# (None is handled above for Optionals) --> the developer can indicate
# which type is preferred based on the annotation arg order
if origin in UNION_TYPES:
# don't try to cast None if the union type is optional
if type(None) in args and data is None:
return None
return _load_basemodel_helper(args[0], data, by_alias=by_alias)
# finally, handle the case where the type is a subclass of BaseModel
if _is_basemodel(type_):
fields = dataclasses.fields(type_)
basemodel_kwargs = {}
for field in fields:
attr_name = field.name
data_field_name = attr_name
if by_alias and hasattr(type_, '_aliases'):
data_field_name = type_._aliases.get(attr_name, attr_name)
value = data.get(data_field_name)
if value is None:
if field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING:
raise ValueError(
f'Missing required field: {data_field_name}')
value = field.default if field.default is not dataclasses.MISSING else field.default_factory(
)
else:
value = _load_basemodel_helper(
field.type, value, by_alias=by_alias)
basemodel_kwargs[attr_name] = value
return type_(**basemodel_kwargs)
raise TypeError(
f'Unknown error when loading data: {data} into type {type_}')
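A minimal usage sketch tying the pieces above together: the dataclass
conversion, the transform_/validate_ hooks, and the alias-aware dict round
trip. The TrainerInput class and its fields are hypothetical, not part of this
commit:

from typing import List, Optional
from kfp.components import base_model

class TrainerInput(base_model.BaseModel):
    image: str
    args: Optional[List[str]] = None
    _aliases = {'image': 'containerImage'}

    def transform_args(self) -> None:
        # transform_ hooks run first in __post_init__
        self.args = None if self.args == [] else self.args

    def validate_image(self) -> None:
        # validate_ hooks run after transform_ hooks
        if not self.image:
            raise ValueError('image must not be empty.')

trainer = TrainerInput(image='alpine', args=[])
assert trainer.args is None  # normalized by transform_args at construction
assert trainer.to_dict(by_alias=True) == {
    'containerImage': 'alpine',
    'args': None
}
assert TrainerInput.from_dict({'containerImage': 'alpine'},
                              by_alias=True) == trainer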

View File

@ -0,0 +1,462 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
import unittest
from collections import abc
from typing import (Any, Dict, List, Mapping, MutableMapping, MutableSequence,
Optional, OrderedDict, Sequence, Set, Tuple, Union)
from absl.testing import parameterized
from kfp.components import base_model
class TypeClass(base_model.BaseModel):
a: str
b: List[int]
c: Dict[str, int]
d: Union[int, str]
e: Union[int, str, bool]
f: Optional[int]
class TestBaseModel(unittest.TestCase):
def test_is_dataclass(self):
class Child(base_model.BaseModel):
x: int
child = Child(x=1)
self.assertTrue(dataclasses.is_dataclass(child))
def test_to_dict_simple(self):
class Child(base_model.BaseModel):
i: int
s: str
f: float
l: List[int]
data = {'i': 1, 's': 's', 'f': 1.0, 'l': [1, 2]}
child = Child(**data)
actual = child.to_dict()
self.assertEqual(actual, data)
self.assertEqual(child, Child.from_dict(actual))
def test_to_dict_nested(self):
class InnerChild(base_model.BaseModel):
a: str
class Child(base_model.BaseModel):
b: int
c: InnerChild
data = {'b': 1, 'c': InnerChild(a='a')}
child = Child(**data)
actual = child.to_dict()
expected = {'b': 1, 'c': {'a': 'a'}}
self.assertEqual(actual, expected)
self.assertEqual(child, Child.from_dict(actual))
def test_from_dict_no_defaults(self):
class Child(base_model.BaseModel):
i: int
s: str
f: float
l: List[int]
data = {'i': 1, 's': 's', 'f': 1.0, 'l': [1, 2]}
child = Child.from_dict(data)
self.assertEqual(child.i, 1)
self.assertEqual(child.s, 's')
self.assertEqual(child.f, 1.0)
self.assertEqual(child.l, [1, 2])
self.assertEqual(child.to_dict(), data)
def test_from_dict_with_defaults(self):
class Child(base_model.BaseModel):
s: str
f: float
l: List[int]
i: int = 1
data = {'s': 's', 'f': 1.0, 'l': [1, 2]}
child = Child.from_dict(data)
self.assertEqual(child.i, 1)
self.assertEqual(child.s, 's')
self.assertEqual(child.f, 1.0)
self.assertEqual(child.l, [1, 2])
self.assertEqual(child.to_dict(), {**data, **{'i': 1}})
def test_from_dict_nested(self):
class InnerChild(base_model.BaseModel):
a: str
class Child(base_model.BaseModel):
b: int
c: InnerChild
data = {'b': 1, 'c': {'a': 'a'}}
child = Child.from_dict(data)
self.assertEqual(child.b, 1)
self.assertIsInstance(child.c, InnerChild)
self.assertEqual(child.c.a, 'a')
self.assertEqual(child.to_dict(), data)
def test_from_dict_array_nested(self):
class InnerChild(base_model.BaseModel):
a: str
class Child(base_model.BaseModel):
b: int
c: List[InnerChild]
d: Dict[str, InnerChild]
inner_child_data = {'a': 'a'}
data = {
'b': 1,
'c': [inner_child_data, inner_child_data],
'd': {
'e': inner_child_data
}
}
child = Child.from_dict(data)
self.assertEqual(child.b, 1)
self.assertIsInstance(child.c[0], InnerChild)
self.assertIsInstance(child.c[1], InnerChild)
self.assertIsInstance(child.d['e'], InnerChild)
self.assertEqual(child.c[0].a, 'a')
self.assertEqual(child.to_dict(), data)
def test_from_dict_by_alias(self):
class InnerChild(base_model.BaseModel):
inner_child_field: int
_aliases = {'inner_child_field': 'inner_child_field_alias'}
class Child(base_model.BaseModel):
sub_field: InnerChild
_aliases = {'sub_field': 'sub_field_alias'}
data = {'sub_field_alias': {'inner_child_field_alias': 2}}
child = Child.from_dict(data, by_alias=True)
self.assertIsInstance(child.sub_field, InnerChild)
self.assertEqual(child.sub_field.inner_child_field, 2)
self.assertEqual(child.to_dict(by_alias=True), data)
def test_to_dict_by_alias(self):
class InnerChild(base_model.BaseModel):
inner_child_field: int
_aliases = {'inner_child_field': 'inner_child_field_alias'}
class Child(base_model.BaseModel):
sub_field: InnerChild
_aliases = {'sub_field': 'sub_field_alias'}
inner_child = InnerChild(inner_child_field=2)
child = Child(sub_field=inner_child)
actual = child.to_dict(by_alias=True)
expected = {'sub_field_alias': {'inner_child_field_alias': 2}}
self.assertEqual(actual, expected)
self.assertEqual(Child.from_dict(actual, by_alias=True), child)
def test_to_dict_by_alias2(self):
class MyClass(base_model.BaseModel):
x: int
y: List[int]
z: Dict[str, int]
_aliases = {'x': 'a', 'z': 'b'}
res = MyClass(x=1, y=[2], z={'key': 3}).to_dict(by_alias=True)
self.assertEqual(res, {'a': 1, 'y': [2], 'b': {'key': 3}})
def test_to_dict_by_alias_nested(self):
class InnerClass(base_model.BaseModel):
f: float
_aliases = {'f': 'a'}
class MyClass(base_model.BaseModel):
x: int
y: List[int]
z: InnerClass
_aliases = {'x': 'a', 'z': 'b'}
res = MyClass(x=1, y=[2], z=InnerClass(f=1.0)).to_dict(by_alias=True)
self.assertEqual(res, {'a': 1, 'y': [2], 'b': {'a': 1.0}})
def test_can_create_properties_using_attributes(self):
class Child(base_model.BaseModel):
x: Optional[int]
@property
def prop(self) -> bool:
return self.x is not None
child1 = Child(x=None)
self.assertEqual(child1.prop, False)
child2 = Child(x=1)
self.assertEqual(child2.prop, True)
def test_unsupported_type_success(self):
class OtherClass(base_model.BaseModel):
x: int
class MyClass(base_model.BaseModel):
a: OtherClass
def test_unsupported_type_failures(self):
with self.assertRaisesRegex(TypeError, r'not a supported'):
class MyClass(base_model.BaseModel):
a: tuple
with self.assertRaisesRegex(TypeError, r'not a supported'):
class MyClass(base_model.BaseModel):
a: Tuple
with self.assertRaisesRegex(TypeError, r'not a supported'):
class MyClass(base_model.BaseModel):
a: Set
with self.assertRaisesRegex(TypeError, r'not a supported'):
class OtherClass:
pass
class MyClass(base_model.BaseModel):
a: OtherClass
def test_base_model_validation(self):
# test exception thrown
class MyClass(base_model.BaseModel):
x: int
def validate_x(self) -> None:
if self.x < 2:
raise ValueError('x must be greater than 2')
with self.assertRaisesRegex(ValueError, 'x must be greater than 2'):
mc = MyClass(x=1)
# test value modified same type
class MyClass(base_model.BaseModel):
x: int
def validate_x(self) -> None:
self.x = max(self.x, 2)
mc = MyClass(x=1)
self.assertEqual(mc.x, 2)
# test value modified new type
class MyClass(base_model.BaseModel):
x: Optional[List[int]] = None
def validate_x(self) -> None:
if isinstance(self.x, list) and not self.x:
self.x = None
mc = MyClass(x=[])
self.assertEqual(mc.x, None)
def test_can_set_field(self):
class MyClass(base_model.BaseModel):
x: int
mc = MyClass(x=2)
mc.x = 1
self.assertEqual(mc.x, 1)
def test_can_use_default_factory(self):
class MyClass(base_model.BaseModel):
x: List[int] = dataclasses.field(default_factory=list)
mc = MyClass()
self.assertEqual(mc.x, [])
class TestIsBaseModel(unittest.TestCase):
def test_true(self):
self.assertEqual(base_model._is_basemodel(base_model.BaseModel), True)
class MyClass(base_model.BaseModel):
pass
self.assertEqual(base_model._is_basemodel(MyClass), True)
def test_false(self):
self.assertEqual(base_model._is_basemodel(int), False)
self.assertEqual(base_model._is_basemodel(1), False)
self.assertEqual(base_model._is_basemodel(str), False)
class TestLoadBaseModelHelper(parameterized.TestCase):
def setUp(self):
self.no_alias_load_base_model_helper = functools.partial(
base_model._load_basemodel_helper, by_alias=False)
def test_load_primitive(self):
self.assertEqual(self.no_alias_load_base_model_helper(str, 'a'), 'a')
self.assertEqual(self.no_alias_load_base_model_helper(int, 1), 1)
self.assertEqual(self.no_alias_load_base_model_helper(float, 1.0), 1.0)
self.assertEqual(self.no_alias_load_base_model_helper(bool, True), True)
self.assertEqual(
self.no_alias_load_base_model_helper(type(None), None), None)
def test_load_primitive_with_casting(self):
self.assertEqual(self.no_alias_load_base_model_helper(int, '1'), 1)
self.assertEqual(self.no_alias_load_base_model_helper(str, 1), '1')
self.assertEqual(self.no_alias_load_base_model_helper(float, 1), 1.0)
self.assertEqual(self.no_alias_load_base_model_helper(int, 1.0), 1)
self.assertEqual(self.no_alias_load_base_model_helper(bool, 1), True)
self.assertEqual(self.no_alias_load_base_model_helper(bool, 0), False)
self.assertEqual(self.no_alias_load_base_model_helper(int, True), 1)
self.assertEqual(self.no_alias_load_base_model_helper(int, False), 0)
self.assertEqual(
self.no_alias_load_base_model_helper(bool, None), False)
def test_load_none(self):
self.assertEqual(
self.no_alias_load_base_model_helper(type(None), None), None)
with self.assertRaisesRegex(TypeError, ''):
self.no_alias_load_base_model_helper(type(None), 1)
@parameterized.parameters(['a', 1, 1.0, True, False, None, ['list']])
def test_load_any(self, data: Any):
self.assertEqual(self.no_alias_load_base_model_helper(Any, data),
data) # type: ignore
def test_load_list(self):
self.assertEqual(
self.no_alias_load_base_model_helper(List[str], ['a']), ['a'])
self.assertEqual(
self.no_alias_load_base_model_helper(List[int], [1, 1]), [1, 1])
self.assertEqual(
self.no_alias_load_base_model_helper(List[float], [1.0]), [1.0])
self.assertEqual(
self.no_alias_load_base_model_helper(List[bool], [True]), [True])
self.assertEqual(
self.no_alias_load_base_model_helper(List[type(None)], [None]),
[None])
def test_load_primitive_other_iterables(self):
self.assertEqual(
self.no_alias_load_base_model_helper(Sequence[bool], [True]),
[True])
self.assertEqual(
self.no_alias_load_base_model_helper(MutableSequence[type(None)],
[None]), [None])
self.assertEqual(
self.no_alias_load_base_model_helper(Sequence[str], ['a']), ['a'])
def test_load_dict(self):
self.assertEqual(
self.no_alias_load_base_model_helper(Dict[str, str], {'a': 'a'}),
{'a': 'a'})
self.assertEqual(
self.no_alias_load_base_model_helper(Dict[str, int], {'a': 1}),
{'a': 1})
self.assertEqual(
self.no_alias_load_base_model_helper(Dict[str, float], {'a': 1.0}),
{'a': 1.0})
self.assertEqual(
self.no_alias_load_base_model_helper(Dict[str, bool], {'a': True}),
{'a': True})
self.assertEqual(
self.no_alias_load_base_model_helper(Dict[str, type(None)],
{'a': None}), {'a': None})
def test_load_mapping(self):
self.assertEqual(
self.no_alias_load_base_model_helper(Mapping[str, float],
{'a': 1.0}), {'a': 1.0})
self.assertEqual(
self.no_alias_load_base_model_helper(MutableMapping[str, bool],
{'a': True}), {'a': True})
self.assertEqual(
self.no_alias_load_base_model_helper(OrderedDict[str,
type(None)],
{'a': None}), {'a': None})
def test_load_union_types(self):
self.assertEqual(
self.no_alias_load_base_model_helper(Union[str, int], 'a'), 'a')
self.assertEqual(
self.no_alias_load_base_model_helper(Union[str, int], 1), '1')
self.assertEqual(
self.no_alias_load_base_model_helper(Union[int, str], 1), 1)
self.assertEqual(
self.no_alias_load_base_model_helper(Union[int, str], '1'), 1)
def test_load_optional_types(self):
self.assertEqual(
self.no_alias_load_base_model_helper(Optional[str], 'a'), 'a')
self.assertEqual(
self.no_alias_load_base_model_helper(Optional[str], None), None)
def test_unsupported_type(self):
with self.assertRaisesRegex(TypeError, r'Unsupported type:'):
self.no_alias_load_base_model_helper(Set[int], {1})
class TestGetOriginPy37(parameterized.TestCase):
def test_is_same_as_typing_version(self):
import sys
if sys.version_info.major == 3 and sys.version_info.minor >= 8:
import typing
for field in dataclasses.fields(TypeClass):
self.assertEqual(
base_model._get_origin_py37(field.type),
typing.get_origin(field.type))
class TestGetArgsPy37(parameterized.TestCase):
def test_is_same_as_typing_version(self):
import sys
if sys.version_info.major == 3 and sys.version_info.minor >= 8:
import typing
for field in dataclasses.fields(TypeClass):
self.assertEqual(
base_model._get_args_py37(field.type),
typing.get_args(field.type))
if __name__ == '__main__':
unittest.main()

View File

@ -13,13 +13,13 @@
# limitations under the License.
"""Pipeline task class and operations."""
import re
import copy
import re
from typing import Any, List, Mapping, Optional, Union
from kfp.components import constants
from kfp.components import placeholders
from kfp.components import pipeline_channel
from kfp.components import placeholders
from kfp.components import structures
from kfp.components.types import type_utils
@ -116,6 +116,7 @@ class PipelineTask:
self.container_spec = None
if component_spec.implementation.container is not None:
self.container_spec = self._resolve_command_line_and_arguments(
component_spec=component_spec,
args=args,
@ -266,7 +267,7 @@ class PipelineTask:
return input_path
else:
input_spec = inputs_dict[input_name]
if input_spec.optional:
if input_spec._optional:
return None
else:
raise ValueError(
@ -315,7 +316,7 @@ class PipelineTask:
return expanded_result
else:
raise TypeError('Unrecognized argument type: {}'.format(arg))
raise TypeError(f'Unrecognized argument type: {arg}')
def expand_argument_list(argument_list) -> Optional[List[str]]:
if argument_list is None:

View File

@ -13,51 +13,52 @@
# limitations under the License.
"""Definitions for component spec."""
import dataclasses
import ast
import functools
import itertools
from typing import Any, Dict, Mapping, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Union
import pydantic
import yaml
from kfp.components import base_model
from kfp.components import utils
from kfp.components import v1_components
from kfp.components import v1_structures
from kfp.utils import ir_utils
class BaseModel(pydantic.BaseModel):
class InputSpec_(base_model.BaseModel):
"""Component input definitions. (Inner class).
class Config:
allow_population_by_field_name = True
arbitrary_types_allowed = True
Attributes:
type: The type of the input.
default (optional): the default value for the input.
description (optional): the user description of the input.
"""
type: Union[str, dict]
default: Union[Any, None] = None
description: Optional[str] = None
class InputSpec(BaseModel):
# Hack to allow access to __init__ arguments for setting _optional value
class InputSpec(InputSpec_, base_model.BaseModel):
"""Component input definitions.
Attributes:
type: The type of the input.
default: Optional; the default value for the input.
default (optional): the default value for the input.
description (optional): the user description of the input.
optional: Whether the input is optional. An input is optional when it has
an explicit default value.
_optional: Whether the input is optional. An input is optional when it has an explicit default value.
"""
type: Union[str, dict]
default: Optional[Any] = None
description: Optional[str] = None
_optional: bool = pydantic.PrivateAttr()
def __init__(self, **data):
super().__init__(**data)
# An input is optional if a default value is explicitly specified.
self._optional = 'default' in data
@property
def optional(self) -> bool:
return self._optional
@functools.wraps(InputSpec_.__init__)
def __init__(self, *args, **kwargs):
if args:
raise ValueError('InputSpec does not accept positional arguments.')
super().__init__(*args, **kwargs)
self._optional = 'default' in kwargs
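# A short illustration of the resulting semantics (an illustrative sketch, not
# part of this commit): an input is optional only when a default is explicitly
# provided, even if that default is None.
#   InputSpec(type='String')._optional                # ==> False
#   InputSpec(type='String', default=None)._optional  # ==> True
#   InputSpec(type='Integer', default=1)._optional    # ==> True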
class OutputSpec(BaseModel):
class OutputSpec(base_model.BaseModel):
"""Component output definitions.
Attributes:
@ -68,55 +69,54 @@ class OutputSpec(BaseModel):
description: Optional[str] = None
class BasePlaceholder(BaseModel):
"""Base class for placeholders that could appear in container cmd and
args."""
pass
class InputValuePlaceholder(BasePlaceholder):
class InputValuePlaceholder(base_model.BaseModel):
"""Class that holds input value for conditional cases.
Attributes:
input_name: name of the input.
"""
input_name: str = pydantic.Field(alias='inputValue')
input_name: str
_aliases = {'input_name': 'inputValue'}
class InputPathPlaceholder(BasePlaceholder):
class InputPathPlaceholder(base_model.BaseModel):
"""Class that holds input path for conditional cases.
Attributes:
input_name: name of the input.
"""
input_name: str = pydantic.Field(alias='inputPath')
input_name: str
_aliases = {'input_name': 'inputPath'}
class InputUriPlaceholder(BasePlaceholder):
class InputUriPlaceholder(base_model.BaseModel):
"""Class that holds input uri for conditional cases.
Attributes:
input_name: name of the input.
"""
input_name: str = pydantic.Field(alias='inputUri')
input_name: str
_aliases = {'input_name': 'inputUri'}
class OutputPathPlaceholder(BasePlaceholder):
class OutputPathPlaceholder(base_model.BaseModel):
"""Class that holds output path for conditional cases.
Attributes:
output_name: name of the output.
"""
output_name: str = pydantic.Field(alias='outputPath')
output_name: str
_aliases = {'output_name': 'outputPath'}
class OutputUriPlaceholder(BasePlaceholder):
class OutputUriPlaceholder(base_model.BaseModel):
"""Class that holds output uri for conditional cases.
Attributes:
output_name: name of the output.
"""
output_name: str = pydantic.Field(alias='outputUri')
output_name: str
_aliases = {'output_name': 'outputUri'}
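# Sketch of how the _aliases mappings above replace pydantic's Field(alias=...)
# during (de)serialization (illustrative values, not part of this commit):
#   InputValuePlaceholder(input_name='text').to_dict(by_alias=True)
#   # ==> {'inputValue': 'text'}
#   OutputPathPlaceholder.from_dict({'outputPath': 'out'}, by_alias=True)
#   # ==> OutputPathPlaceholder(output_name='out')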
ValidCommandArgs = Union[str, InputValuePlaceholder, InputPathPlaceholder,
@ -125,16 +125,17 @@ ValidCommandArgs = Union[str, InputValuePlaceholder, InputPathPlaceholder,
'ConcatPlaceholder']
class ConcatPlaceholder(BasePlaceholder):
class ConcatPlaceholder(base_model.BaseModel):
"""Class that extends basePlaceholders for concatenation.
Attributes:
items: string or ValidCommandArgs for concatenation.
"""
items: Sequence[ValidCommandArgs] = pydantic.Field(alias='concat')
items: List[ValidCommandArgs]
_aliases = {'items': 'concat'}
class IfPresentPlaceholderStructure(BaseModel):
class IfPresentPlaceholderStructure(base_model.BaseModel):
"""Class that holds structure for conditional cases.
Attributes:
@ -146,43 +147,35 @@ class IfPresentPlaceholderStructure(BaseModel):
the command-line argument will be replaced at run-time by the
expanded value of otherwise.
"""
input_name: str = pydantic.Field(alias='inputName')
then: Sequence[ValidCommandArgs]
otherwise: Optional[Sequence[ValidCommandArgs]] = pydantic.Field(
None, alias='else')
input_name: str
then: List[ValidCommandArgs]
otherwise: Optional[List[ValidCommandArgs]] = None
_aliases = {'input_name': 'inputName', 'otherwise': 'else'}
@pydantic.validator('otherwise', allow_reuse=True)
def empty_otherwise_sequence(cls, v):
if v == []:
return None
return v
def transform_otherwise(self) -> None:
"""Use None instead of empty list for optional."""
self.otherwise = None if self.otherwise == [] else self.otherwise
class IfPresentPlaceholder(BasePlaceholder):
class IfPresentPlaceholder(base_model.BaseModel):
"""Class that extends basePlaceholders for conditional cases.
Attributes:
if_present (ifPresent): holds structure for conditional cases.
"""
if_structure: IfPresentPlaceholderStructure = pydantic.Field(
alias='ifPresent')
if_structure: IfPresentPlaceholderStructure
_aliases = {'if_structure': 'ifPresent'}
IfPresentPlaceholderStructure.update_forward_refs()
IfPresentPlaceholder.update_forward_refs()
ConcatPlaceholder.update_forward_refs()
@dataclasses.dataclass
class ResourceSpec:
class ResourceSpec(base_model.BaseModel):
"""The resource requirements of a container execution.
Attributes:
cpu_limit: Optional; the limit of the number of vCPU cores.
memory_limit: Optional; the memory limit in GB.
accelerator_type: Optional; the type of accelerators attached to the
cpu_limit (optional): the limit of the number of vCPU cores.
memory_limit (optional): the memory limit in GB.
accelerator_type (optional): the type of accelerators attached to the
container.
accelerator_count: Optional; the number of accelerators attached.
accelerator_count (optional): the number of accelerators attached.
"""
cpu_limit: Optional[float] = None
memory_limit: Optional[float] = None
@ -190,36 +183,36 @@ class ResourceSpec:
accelerator_count: Optional[int] = None
class ContainerSpec(BaseModel):
class ContainerSpec(base_model.BaseModel):
"""Container implementation definition.
Attributes:
image: The container image.
command: Optional; the container entrypoint.
args: Optional; the arguments to the container entrypoint.
env: Optional; the environment variables to be passed to the container.
resources: Optional; the specification on the resource requirements.
command (optional): the container entrypoint.
args (optional): the arguments to the container entrypoint.
env (optional): the environment variables to be passed to the container.
resources (optional): the specification on the resource requirements.
"""
image: str
command: Optional[Sequence[ValidCommandArgs]] = None
args: Optional[Sequence[ValidCommandArgs]] = None
command: Optional[List[ValidCommandArgs]] = None
args: Optional[List[ValidCommandArgs]] = None
env: Optional[Mapping[str, ValidCommandArgs]] = None
resources: Optional[ResourceSpec] = None
@pydantic.validator('command', 'args', allow_reuse=True)
def empty_sequence(cls, v):
if v == []:
return None
return v
def transform_command(self) -> None:
"""Use None instead of empty list for command."""
self.command = None if self.command == [] else self.command
@pydantic.validator('env', allow_reuse=True)
def empty_map(cls, v):
if v == {}:
return None
return v
def transform_args(self) -> None:
"""Use None instead of empty list for args."""
self.args = None if self.args == [] else self.args
def transform_env(self) -> None:
"""Use None instead of empty dict for env."""
self.env = None if self.env == {} else self.env
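# Sketch of the transform_ hooks above (illustrative, not part of this commit):
# empty collections are normalized to None when the spec is constructed.
#   spec = ContainerSpec(image='alpine', command=[], args=[], env={})
#   assert spec.command is None and spec.args is None and spec.env is None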
class TaskSpec(BaseModel):
class TaskSpec(base_model.BaseModel):
"""The spec of a pipeline task.
Attributes:
@ -227,23 +220,23 @@ class TaskSpec(BaseModel):
inputs: The sources of task inputs. Constant values or PipelineParams.
dependent_tasks: The list of upstream tasks.
component_ref: The name of a component spec this task is based on.
trigger_condition: Optional; an expression which will be evaluated into
trigger_condition (optional): an expression which will be evaluated into
a boolean value. True to trigger the task to run.
trigger_strategy: Optional; when the task will be ready to be triggered.
trigger_strategy (optional): when the task will be ready to be triggered.
Valid values include: "TRIGGER_STRATEGY_UNSPECIFIED",
"ALL_UPSTREAM_TASKS_SUCCEEDED", and "ALL_UPSTREAM_TASKS_COMPLETED".
iterator_items: Optional; the items to iterate on. A constant value or
iterator_items (optional): the items to iterate on. A constant value or
a PipelineParam.
iterator_item_input: Optional; the name of the input which has the item
iterator_item_input (optional): the name of the input which has the item
from the [items][] collection.
enable_caching: Optional; whether or not to enable caching for the task.
enable_caching (optional): whether or not to enable caching for the task.
Default is True.
display_name: Optional; the display name of the task. If not specified,
display_name (optional): the display name of the task. If not specified,
the task name will be used as the display name.
"""
name: str
inputs: Mapping[str, Any]
dependent_tasks: Sequence[str]
dependent_tasks: List[str]
component_ref: str
trigger_condition: Optional[str] = None
trigger_strategy: Optional[str] = None
@ -253,7 +246,7 @@ class TaskSpec(BaseModel):
display_name: Optional[str] = None
class DagSpec(BaseModel):
class DagSpec(base_model.BaseModel):
"""DAG(graph) implementation definition.
Attributes:
@ -261,11 +254,10 @@ class DagSpec(BaseModel):
outputs: Defines how the outputs of the dag are linked to the sub tasks.
"""
tasks: Mapping[str, TaskSpec]
# TODO(chensun): revisit if we need a DagOutputsSpec class.
outputs: Mapping[str, Any]
class ImporterSpec(BaseModel):
class ImporterSpec(base_model.BaseModel):
"""ImporterSpec definition.
Attributes:
@ -273,7 +265,7 @@ class ImporterSpec(BaseModel):
type_schema: The type of the artifact.
reimport: Whether or not import an artifact regardless it has been
imported before.
metadata: Optional; the properties of the artifact.
metadata (optional): the properties of the artifact.
"""
artifact_uri: str
type_schema: str
@ -281,7 +273,7 @@ class ImporterSpec(BaseModel):
metadata: Optional[Mapping[str, Any]] = None
class Implementation(BaseModel):
class Implementation(base_model.BaseModel):
"""Implementation definition.
Attributes:
@ -294,97 +286,211 @@ class Implementation(BaseModel):
importer: Optional[ImporterSpec] = None
class ComponentSpec(BaseModel):
def try_to_get_dict_from_string(element: str) -> Union[dict, str]:
"""Parses the string as a dict literal if possible, else returns the
original string."""
try:
res = ast.literal_eval(element)
except (ValueError, SyntaxError):
return element
if not isinstance(res, dict):
return element
return res
def convert_str_or_dict_to_placeholder(
element: Union[str, dict,
ValidCommandArgs]) -> Union[str, ValidCommandArgs]:
"""Converts command and args elements to a placholder type based on value
of the key of the placeholder string, else returns the input.
Args:
element (Union[str, dict, ValidCommandArgs]): A ContainerSpec.command or ContainerSpec.args element.
Raises:
TypeError: If `element` is invalid.
Returns:
Union[str, ValidCommandArgs]: Possibly converted placeholder or original input.
"""
if not isinstance(element, (dict, str)):
return element
elif isinstance(element, str):
res = try_to_get_dict_from_string(element)
if not isinstance(res, dict):
return element
elif isinstance(element, dict):
res = element
else:
raise TypeError(
f'Invalid type for arg: {type(element)}. Expected str or dict.')
has_one_entry = len(res) == 1
if not has_one_entry:
raise ValueError(
f'Got unexpected dictionary {res}. Expected a dictionary with one entry.'
)
first_key = list(res.keys())[0]
first_value = list(res.values())[0]
if first_key == 'inputValue':
return InputValuePlaceholder(
input_name=utils.sanitize_input_name(first_value))
elif first_key == 'inputPath':
return InputPathPlaceholder(
input_name=utils.sanitize_input_name(first_value))
elif first_key == 'inputUri':
return InputUriPlaceholder(
input_name=utils.sanitize_input_name(first_value))
elif first_key == 'outputPath':
return OutputPathPlaceholder(
output_name=utils.sanitize_input_name(first_value))
elif first_key == 'outputUri':
return OutputUriPlaceholder(
output_name=utils.sanitize_input_name(first_value))
elif first_key == 'ifPresent':
structure_kwargs = res['ifPresent']
structure_kwargs['input_name'] = structure_kwargs.pop('inputName')
structure_kwargs['otherwise'] = structure_kwargs.pop('else')
structure_kwargs['then'] = [
convert_str_or_dict_to_placeholder(e)
for e in structure_kwargs['then']
]
structure_kwargs['otherwise'] = [
convert_str_or_dict_to_placeholder(e)
for e in structure_kwargs['otherwise']
]
if_structure = IfPresentPlaceholderStructure(**structure_kwargs)
return IfPresentPlaceholder(if_structure=if_structure)
elif first_key == 'concat':
return ConcatPlaceholder(items=[
convert_str_or_dict_to_placeholder(e) for e in res['concat']
])
else:
raise TypeError(
f'Unexpected command/argument type: "{element}" of type "{type(element)}".'
)
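# Sketch of the conversion above (illustrative, not part of this commit):
# placeholder dicts, and strings that parse to placeholder dicts, become typed
# placeholders; everything else passes through unchanged.
#   convert_str_or_dict_to_placeholder('echo')
#   # ==> 'echo'
#   convert_str_or_dict_to_placeholder({'inputValue': 'text'})
#   # ==> InputValuePlaceholder(input_name='text')
#   convert_str_or_dict_to_placeholder("{'outputPath': 'out'}")
#   # ==> OutputPathPlaceholder(output_name='out')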
def _check_valid_placeholder_reference(valid_inputs: List[str],
valid_outputs: List[str],
placeholder: ValidCommandArgs) -> None:
"""Validates input/output placeholders refer to an existing input/output.
Args:
valid_inputs: The existing input names.
valid_outputs: The existing output names.
placeholder: The placeholder argument for checking.
Raises:
ValueError: if any placeholder references a non-existing input or
output.
TypeError: if any argument is neither a str nor a placeholder
instance.
"""
if isinstance(
placeholder,
(InputValuePlaceholder, InputPathPlaceholder, InputUriPlaceholder)):
if placeholder.input_name not in valid_inputs:
raise ValueError(
f'Argument "{placeholder}" references non-existing input.')
elif isinstance(placeholder, (OutputPathPlaceholder, OutputUriPlaceholder)):
if placeholder.output_name not in valid_outputs:
raise ValueError(
f'Argument "{placeholder}" references non-existing output.')
elif isinstance(placeholder, IfPresentPlaceholder):
if placeholder.if_structure.input_name not in valid_inputs:
raise ValueError(
f'Argument "{placeholder}" references non-existing input.')
for placeholder in itertools.chain(
placeholder.if_structure.then or [],
placeholder.if_structure.otherwise or []):
_check_valid_placeholder_reference(valid_inputs, valid_outputs,
placeholder)
elif isinstance(placeholder, ConcatPlaceholder):
for placeholder in placeholder.items:
_check_valid_placeholder_reference(valid_inputs, valid_outputs,
placeholder)
elif not isinstance(placeholder, str):
raise TypeError(
f'Unexpected argument "{placeholder}" of type {type(placeholder)}.')
ValidCommandArgTypes = (str, InputValuePlaceholder, InputPathPlaceholder,
InputUriPlaceholder, OutputPathPlaceholder,
OutputUriPlaceholder, IfPresentPlaceholder,
ConcatPlaceholder)
class ComponentSpec(base_model.BaseModel):
"""The definition of a component.
Attributes:
name: The name of the component.
description: Optional; the description of the component.
inputs: Optional; the input definitions of the component.
outputs: Optional; the output definitions of the component.
description (optional): the description of the component.
inputs (optional): the input definitions of the component.
outputs (optional): the output definitions of the component.
implementation: The implementation of the component. Either an executor
(container, importer) or a DAG consists of other components.
"""
name: str
implementation: Implementation
description: Optional[str] = None
inputs: Optional[Dict[str, InputSpec]] = None
outputs: Optional[Dict[str, OutputSpec]] = None
implementation: Implementation
@pydantic.validator('inputs', 'outputs', allow_reuse=True)
def empty_map(cls, v):
if v == {}:
return None
return v
def transform_inputs(self) -> None:
"""Use None instead of empty list for inputs."""
self.inputs = None if self.inputs == {} else self.inputs
@pydantic.root_validator(allow_reuse=True)
def validate_placeholders(cls, values):
if values.get('implementation').container is None:
return values
containerSpec: ContainerSpec = values.get('implementation').container
def transform_outputs(self) -> None:
"""Use None instead of empty list for outputs."""
self.outputs = None if self.outputs == {} else self.outputs
try:
valid_inputs = values.get('inputs').keys()
except AttributeError:
valid_inputs = []
def transform_command_input_placeholders(self) -> None:
"""Converts command and args elements to a placholder type where
applicable."""
if self.implementation.container is not None:
try:
valid_outputs = values.get('outputs').keys()
except AttributeError:
valid_outputs = []
if self.implementation.container.command is not None:
self.implementation.container.command = [
convert_str_or_dict_to_placeholder(e)
for e in self.implementation.container.command
]
if self.implementation.container.args is not None:
self.implementation.container.args = [
convert_str_or_dict_to_placeholder(e)
for e in self.implementation.container.args
]
def validate_placeholders(self):
"""Validates that input/output placeholders refer to an existing
input/output."""
implementation = self.implementation
if getattr(implementation, 'container', None) is None:
return
container_spec: ContainerSpec = implementation.container
valid_inputs = [] if self.inputs is None else list(self.inputs.keys())
valid_outputs = [] if self.outputs is None else list(
self.outputs.keys())
for arg in itertools.chain((container_spec.command or []),
(container_spec.args or [])):
cls._check_valid_placeholder_reference(valid_inputs, valid_outputs,
arg)
return values
@classmethod
def _check_valid_placeholder_reference(cls, valid_inputs: Sequence[str],
valid_outputs: Sequence[str],
arg: ValidCommandArgs) -> None:
"""Validates placeholder reference existing input/output names.
Args:
valid_inputs: The existing input names.
valid_outputs: The existing output names.
arg: The placeholder argument for checking.
Raises:
ValueError: if any placeholder references a non-existing input or
output.
TypeError: if any argument is neither a str nor a placeholder
instance.
"""
if isinstance(
arg,
(InputValuePlaceholder, InputPathPlaceholder, InputUriPlaceholder)):
if arg.input_name not in valid_inputs:
raise ValueError(
f'Argument "{arg}" references non-existing input.')
elif isinstance(arg, (OutputPathPlaceholder, OutputUriPlaceholder)):
if arg.output_name not in valid_outputs:
raise ValueError(
f'Argument "{arg}" references non-existing output.')
elif isinstance(arg, IfPresentPlaceholder):
if arg.if_structure.input_name not in valid_inputs:
raise ValueError(
f'Argument "{arg}" references non-existing input.')
for placeholder in itertools.chain(arg.if_structure.then or [],
arg.if_structure.otherwise or
[]):
cls._check_valid_placeholder_reference(valid_inputs,
valid_outputs,
placeholder)
elif isinstance(arg, ConcatPlaceholder):
for placeholder in arg.items:
cls._check_valid_placeholder_reference(valid_inputs,
valid_outputs,
placeholder)
elif not isinstance(arg, str):
raise TypeError(f'Unexpected argument "{arg}".')
_check_valid_placeholder_reference(valid_inputs, valid_outputs, arg)
@classmethod
def from_v1_component_spec(
@ -405,27 +511,18 @@ class ComponentSpec(BaseModel):
component_dict = v1_component_spec.to_dict()
if component_dict.get('implementation') is None:
raise ValueError('Implementation field not found')
if 'container' not in component_dict.get('implementation'):
if 'container' not in component_dict.get(
'implementation'): # type: ignore
raise NotImplementedError
def _transform_arg(arg: Union[str, Dict[str, str]]) -> ValidCommandArgs:
def convert_v1_if_present_placeholder_to_v2(
arg: Dict[str, Any]) -> Union[Dict[str, Any], ValidCommandArgs]:
if isinstance(arg, str):
arg = try_to_get_dict_from_string(arg)
if not isinstance(arg, dict):
return arg
if 'inputValue' in arg:
return InputValuePlaceholder(
input_name=utils.sanitize_input_name(arg['inputValue']))
if 'inputPath' in arg:
return InputPathPlaceholder(
input_name=utils.sanitize_input_name(arg['inputPath']))
if 'inputUri' in arg:
return InputUriPlaceholder(
input_name=utils.sanitize_input_name(arg['inputUri']))
if 'outputPath' in arg:
return OutputPathPlaceholder(
output_name=utils.sanitize_input_name(arg['outputPath']))
if 'outputUri' in arg:
return OutputUriPlaceholder(
output_name=utils.sanitize_input_name(arg['outputUri']))
if 'if' in arg:
if_placeholder_values = arg['if']
if_placeholder_values_then = list(if_placeholder_values['then'])
@ -434,62 +531,53 @@ class ComponentSpec(BaseModel):
if_placeholder_values['else'])
except KeyError:
if_placeholder_values_else = []
IfPresentPlaceholderStructure.update_forward_refs()
return IfPresentPlaceholder(
if_structure=IfPresentPlaceholderStructure(
input_name=utils.sanitize_input_name(
if_placeholder_values['cond']['isPresent']),
then=list(
_transform_arg(val)
for val in if_placeholder_values_then),
otherwise=list(
_transform_arg(val)
for val in if_placeholder_values_else)))
if 'concat' in arg:
ConcatPlaceholder.update_forward_refs()
then=[
convert_str_or_dict_to_placeholder(val)
for val in if_placeholder_values_then
],
otherwise=[
convert_str_or_dict_to_placeholder(val)
for val in if_placeholder_values_else
]))
return ConcatPlaceholder(
concat=list(_transform_arg(val) for val in arg['concat']))
raise ValueError(
f'Unexpected command/argument type: "{arg}" of type "{type(arg)}".'
)
elif 'concat' in arg:
return ConcatPlaceholder(items=[
convert_str_or_dict_to_placeholder(val)
for val in arg['concat']
])
elif isinstance(arg, (ValidCommandArgTypes, dict)):
return arg
else:
raise TypeError(
f'Unexpected argument {arg} of type {type(arg)}.')
implementation = component_dict['implementation']['container']
implementation['command'] = [
_transform_arg(command)
convert_v1_if_present_placeholder_to_v2(command)
for command in implementation.pop('command', [])
]
implementation['args'] = [
_transform_arg(command)
convert_v1_if_present_placeholder_to_v2(command)
for command in implementation.pop('args', [])
]
implementation['env'] = {
key: _transform_arg(command)
key: convert_v1_if_present_placeholder_to_v2(command)
for key, command in implementation.pop('env', {}).items()
}
container_spec = ContainerSpec(image=implementation['image'])
# Workaround for https://github.com/samuelcolvin/pydantic/issues/2079
def _copy_model(obj):
if isinstance(obj, BaseModel):
return obj.copy(deep=True)
return obj
# Must assign these after the constructor call, otherwise it won't work.
if implementation['command']:
container_spec.command = [
_copy_model(cmd) for cmd in implementation['command']
]
container_spec.command = implementation['command']
if implementation['args']:
container_spec.args = [
_copy_model(arg) for arg in implementation['args']
]
container_spec.args = implementation['args']
if implementation['env']:
container_spec.env = {
k: _copy_model(v) for k, v in implementation['env']
}
container_spec.env = implementation['env']
return ComponentSpec(
name=component_dict.get('name', 'name'),
@ -507,69 +595,6 @@ class ComponentSpec(BaseModel):
for spec in component_dict.get('outputs', [])
})
def to_v1_component_spec(self) -> v1_structures.ComponentSpec:
"""Converts to v1 ComponentSpec.
Returns:
Component spec in the form of V1 ComponentSpec.
Needed until downstream accept new ComponentSpec.
"""
def _transform_arg(arg: ValidCommandArgs) -> Any:
if isinstance(arg, str):
return arg
if isinstance(arg, InputValuePlaceholder):
return v1_structures.InputValuePlaceholder(arg.input_name)
if isinstance(arg, InputPathPlaceholder):
return v1_structures.InputPathPlaceholder(arg.input_name)
if isinstance(arg, InputUriPlaceholder):
return v1_structures.InputUriPlaceholder(arg.input_name)
if isinstance(arg, OutputPathPlaceholder):
return v1_structures.OutputPathPlaceholder(arg.output_name)
if isinstance(arg, OutputUriPlaceholder):
return v1_structures.OutputUriPlaceholder(arg.output_name)
if isinstance(arg, IfPresentPlaceholder):
return v1_structures.IfPlaceholder(arg.if_structure)
if isinstance(arg, ConcatPlaceholder):
return v1_structures.ConcatPlaceholder(arg.concat)
raise ValueError(
f'Unexpected command/argument type: "{arg}" of type "{type(arg)}".'
)
return v1_structures.ComponentSpec(
name=self.name,
inputs=[
v1_structures.InputSpec(
name=name,
type=input_spec.type,
default=input_spec.default,
) for name, input_spec in self.inputs.items()
],
outputs=[
v1_structures.OutputSpec(
name=name,
type=output_spec.type,
) for name, output_spec in self.outputs.items()
],
implementation=v1_structures.ContainerImplementation(
container=v1_structures.ContainerSpec(
image=self.implementation.container.image,
command=[
_transform_arg(cmd)
for cmd in self.implementation.container.command or []
],
args=[
_transform_arg(arg)
for arg in self.implementation.container.args or []
],
env={
name: _transform_arg(value) for name, value in
(self.implementation.container.env or {}).items()
},
)),
)
@classmethod
def load_from_component_yaml(cls, component_yaml: str) -> 'ComponentSpec':
"""Loads V1 or V2 component yaml into ComponentSpec.
@ -580,19 +605,18 @@ class ComponentSpec(BaseModel):
Returns:
Component spec in the form of V2 ComponentSpec.
"""
json_component = yaml.safe_load(component_yaml)
try:
return ComponentSpec.parse_obj(json_component)
except (pydantic.ValidationError, AttributeError):
return ComponentSpec.from_dict(json_component, by_alias=True)
except AttributeError:
v1_component = v1_components._load_component_spec_from_component_text(
component_yaml)
return cls.from_v1_component_spec(v1_component)
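# Reading of the new fallback (an interpretation, not stated in the diff):
# a v2 document parses cleanly via from_dict; a v1 document makes from_dict
# raise AttributeError (its fields do not line up with the v2 layout), which
# routes parsing through the v1 loader and from_v1_component_spec.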
def save_to_component_yaml(self, output_file: str) -> None:
"""Saves ComponentSpec into yaml file.
"""Saves ComponentSpec into YAML file.
Args:
output_file: File path to store the component yaml.
"""
ir_utils._write_ir_to_file(self.dict(), output_file)
ir_utils._write_ir_to_file(self.to_dict(by_alias=True), output_file)
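For orientation, a minimal sketch of the alias round trip that to_dict(by_alias=True) and from_dict(..., by_alias=True) imply, assuming dataclass-style fields with snake_case/camelCase aliasing. All names here (SketchModel, _to_camel, _to_snake) are hypothetical stand-ins, not the SDK's actual BaseModel:

import dataclasses
import re
from typing import Any, Dict

def _to_camel(snake: str) -> str:
    # 'input_name' -> 'inputName'
    head, *tail = snake.split('_')
    return head + ''.join(part.title() for part in tail)

def _to_snake(camel: str) -> str:
    # 'inputName' -> 'input_name'
    return re.sub(r'([A-Z])', r'_\1', camel).lower()

@dataclasses.dataclass
class SketchModel:
    """Hypothetical stand-in for the SDK's custom BaseModel."""
    input_name: str
    default_value: Any = None

    def to_dict(self, by_alias: bool = False) -> Dict[str, Any]:
        # drop unset fields; optionally rename keys to camelCase aliases
        return {(_to_camel(f.name) if by_alias else f.name):
                getattr(self, f.name)
                for f in dataclasses.fields(self)
                if getattr(self, f.name) is not None}

    @classmethod
    def from_dict(cls, data: Dict[str, Any],
                  by_alias: bool = False) -> 'SketchModel':
        # reverse the aliasing before handing keys to the constructor
        if by_alias:
            data = {_to_snake(key): value for key, value in data.items()}
        return cls(**data)

# round trip: serialize with aliases, then load back
model = SketchModel(input_name='input1')
assert SketchModel.from_dict(
    model.to_dict(by_alias=True), by_alias=True) == model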

View File

@ -1,4 +1,4 @@
# Copyright 2021 The Kubeflow Authors
# Copyright 2021-2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,11 +13,11 @@
# limitations under the License.
"""Tests for kfp.components.structures."""
import os
import tempfile
import textwrap
import unittest
from unittest import mock
import pydantic
from absl.testing import parameterized
from kfp.components import structures
@ -33,7 +33,7 @@ V1_YAML_IF_PLACEHOLDER = textwrap.dedent("""\
- default
then:
- --arg1
- {inputValue: optional_input_1}
- {inputUri: optional_input_1}
image: alpine
inputs:
- {name: optional_input_1, optional: true, type: String}
@ -49,7 +49,7 @@ V2_YAML_IF_PLACEHOLDER = textwrap.dedent("""\
inputName: optional_input_1
then:
- --arg1
- {inputValue: optional_input_1}
- {inputUri: optional_input_1}
image: alpine
inputs:
optional_input_1: {default: null, type: String}
@ -67,7 +67,7 @@ V2_COMPONENT_SPEC_IF_PLACEHOLDER = structures.ComponentSpec(
input_name='optional_input_1',
then=[
'--arg1',
structures.InputValuePlaceholder(
structures.InputUriPlaceholder(
input_name='optional_input_1'),
],
otherwise=[
@ -110,7 +110,7 @@ V2_COMPONENT_SPEC_CONCAT_PLACEHOLDER = structures.ComponentSpec(
container=structures.ContainerSpec(
image='alpine',
args=[
structures.ConcatPlaceholder(concat=[
structures.ConcatPlaceholder(items=[
'--arg1',
structures.InputValuePlaceholder(input_name='input_prefix'),
])
@ -147,7 +147,7 @@ V2_COMPONENT_SPEC_NESTED_PLACEHOLDER = structures.ComponentSpec(
container=structures.ContainerSpec(
image='alpine',
args=[
structures.ConcatPlaceholder(concat=[
structures.ConcatPlaceholder(items=[
'--arg1',
structures.IfPresentPlaceholder(
if_structure=structures.IfPresentPlaceholderStructure(
@ -160,7 +160,7 @@ V2_COMPONENT_SPEC_NESTED_PLACEHOLDER = structures.ComponentSpec(
otherwise=[
'--arg2',
'default',
structures.ConcatPlaceholder(concat=[
structures.ConcatPlaceholder(items=[
'--arg1',
structures.InputValuePlaceholder(
input_name='input_prefix'),
@ -177,8 +177,9 @@ class StructuresTest(parameterized.TestCase):
def test_component_spec_with_placeholder_referencing_nonexisting_input_output(
self):
with self.assertRaisesRegex(
pydantic.ValidationError, 'Argument "input_name=\'input000\'" '
'references non-existing input.'):
ValueError,
r'^Argument \"InputValuePlaceholder[\s\S]*\'input000\'[\s\S]*references non-existing input.'
):
structures.ComponentSpec(
name='component_1',
implementation=structures.Implementation(
@ -199,9 +200,9 @@ class StructuresTest(parameterized.TestCase):
)
with self.assertRaisesRegex(
pydantic.ValidationError,
'Argument "output_name=\'output000\'" '
'references non-existing output.'):
ValueError,
r'^Argument \"OutputPathPlaceholder[\s\S]*\'output000\'[\s\S]*references non-existing output.'
):
structures.ComponentSpec(
name='component_1',
implementation=structures.Implementation(
@ -222,28 +223,10 @@ class StructuresTest(parameterized.TestCase):
)
def test_simple_component_spec_save_to_component_yaml(self):
open_mock = mock.mock_open()
expected_yaml = textwrap.dedent("""\
implementation:
container:
command:
- sh
- -c
- 'set -ex
echo "$0" > "$1"'
- {inputValue: input1}
- {outputPath: output1}
image: alpine
inputs:
input1: {type: String}
name: component_1
outputs:
output1: {type: String}
""")
with mock.patch("builtins.open", open_mock, create=True):
structures.ComponentSpec(
# tests writing in the old style (less verbose) and reading back in the new style (more verbose)
with tempfile.TemporaryDirectory() as tempdir:
output_path = os.path.join(tempdir, 'component.yaml')
original_component_spec = structures.ComponentSpec(
name='component_1',
implementation=structures.Implementation(
container=structures.ContainerSpec(
@ -258,64 +241,17 @@ class StructuresTest(parameterized.TestCase):
output_name='output1'),
],
)),
inputs={
'input1': structures.InputSpec(type='String')
},
outputs={
'output1': structures.OutputSpec(type='String')
},
).save_to_component_yaml('test_save_file.yaml')
inputs={'input1': structures.InputSpec(type='String')},
outputs={'output1': structures.OutputSpec(type='String')},
)
original_component_spec.save_to_component_yaml(output_path)
open_mock.assert_called_once_with('test_save_file.yaml', 'w')
# test that it can be read back correctly
with open(output_path, 'r') as f:
new_component_spec = structures.ComponentSpec.load_from_component_yaml(
f.read())
def test_simple_component_spec_save_to_component_yaml(self):
open_mock = mock.mock_open()
expected_yaml = textwrap.dedent("""\
implementation:
container:
command:
- sh
- -c
- 'set -ex
echo "$0" > "$1"'
- {inputValue: input1}
- {outputPath: output1}
image: alpine
inputs:
input1: {type: String}
name: component_1
outputs:
output1: {type: String}
""")
with mock.patch(
"builtins.open", open_mock, create=True), self.assertWarnsRegex(
DeprecationWarning, r"Compiling to JSON is deprecated"):
structures.ComponentSpec(
name='component_1',
implementation=structures.Implementation(
container=structures.ContainerSpec(
image='alpine',
command=[
'sh',
'-c',
'set -ex\necho "$0" > "$1"',
structures.InputValuePlaceholder(
input_name='input1'),
structures.OutputPathPlaceholder(
output_name='output1'),
],
)),
inputs={
'input1': structures.InputSpec(type='String')
},
outputs={
'output1': structures.OutputSpec(type='String')
},
).save_to_component_yaml('test_save_file.json')
open_mock.assert_called_once_with('test_save_file.json', 'w')
self.assertEqual(original_component_spec, new_component_spec)
@parameterized.parameters(
{
@ -333,12 +269,16 @@ class StructuresTest(parameterized.TestCase):
)
def test_component_spec_placeholder_save_to_component_yaml(
self, expected_yaml, component):
open_mock = mock.mock_open()
with tempfile.TemporaryDirectory() as tempdir:
output_path = os.path.join(tempdir, 'component.yaml')
component.save_to_component_yaml(output_path)
with open(output_path, 'r') as f:
contents = f.read()
with mock.patch("builtins.open", open_mock, create=True):
component.save_to_component_yaml('test_save_file.yaml')
open_mock.assert_called_once_with('test_save_file.yaml', 'w')
# test that what was written can be reloaded correctly
new_component_spec = structures.ComponentSpec.load_from_component_yaml(
contents)
self.assertEqual(new_component_spec, component)
def test_simple_component_spec_load_from_v2_component_yaml(self):
component_yaml_v2 = textwrap.dedent("""\
@ -469,9 +409,56 @@ class StructuresTest(parameterized.TestCase):
'output_1': structures.OutputSpec(type='Artifact'),
'output_2': structures.OutputSpec(type='Artifact'),
})
self.assertEqual(generated_spec, expected_spec)
class TestValidators(unittest.TestCase):
def test_IfPresentPlaceholderStructure_otherwise(self):
obj = structures.IfPresentPlaceholderStructure(
then='then', input_name='input_name', otherwise=['something'])
self.assertEqual(obj.otherwise, ['something'])
obj = structures.IfPresentPlaceholderStructure(
then='then', input_name='input_name', otherwise=[])
self.assertEqual(obj.otherwise, None)
def test_ContainerSpec_command_and_args(self):
obj = structures.ContainerSpec(
image='image', command=['command'], args=['args'])
self.assertEqual(obj.command, ['command'])
self.assertEqual(obj.args, ['args'])
obj = structures.ContainerSpec(image='image', command=[], args=[])
self.assertEqual(obj.command, None)
self.assertEqual(obj.args, None)
def test_ContainerSpec_env(self):
obj = structures.ContainerSpec(
image='image',
command=['command'],
args=['args'],
env={'env': 'env'})
self.assertEqual(obj.env, {'env': 'env'})
obj = structures.ContainerSpec(
image='image', command=[], args=[], env={})
self.assertEqual(obj.env, None)
def test_ComponentSpec_inputs(self):
obj = structures.ComponentSpec(
name='name',
implementation=structures.Implementation(container=None),
inputs={})
self.assertEqual(obj.inputs, None)
def test_ComponentSpec_outputs(self):
obj = structures.ComponentSpec(
name='name',
implementation=structures.Implementation(container=None),
outputs={})
self.assertEqual(obj.outputs, None)
if __name__ == '__main__':
unittest.main()
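The TestValidators cases above pin down one behavior of the new structures: empty containers passed for optional fields normalize to None. A minimal sketch of how such a normalization hook could look, assuming dataclass-style fields; SketchContainerSpec is a hypothetical stand-in, not the SDK's actual ContainerSpec:

import dataclasses
from typing import Any, List, Optional

@dataclasses.dataclass
class SketchContainerSpec:
    """Hypothetical stand-in mirroring the tested normalization."""
    image: str
    command: Optional[List[Any]] = None
    args: Optional[List[Any]] = None

    def __post_init__(self):
        # empty containers normalize to None, matching the assertions above
        if not self.command:
            self.command = None
        if not self.args:
            self.args = None

assert SketchContainerSpec(image='image', command=[], args=[]).command is None
assert SketchContainerSpec(image='image', command=['cmd']).command == ['cmd']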

View File

@ -24,7 +24,6 @@ kfp-pipeline-spec>=0.1.14,<0.2.0
kfp-server-api>=2.0.0a0,<3.0.0
kubernetes>=8.0.0,<19
protobuf>=3.13.0,<4
pydantic>=1.8.2,<2
PyYAML>=5.3,<6
requests-toolbelt>=0.8.0,<1

View File

@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile
# This file is autogenerated by pip-compile with python 3.7
# To update, run:
#
# pip-compile --no-emit-index-url requirements.in
@ -53,6 +53,10 @@ googleapis-common-protos==1.53.0
# via google-api-core
idna==3.2
# via requests
importlib-metadata==4.11.3
# via
# click
# jsonschema
jsonschema==3.2.0
# via -r requirements.in
kfp-pipeline-spec==0.1.14
@ -70,14 +74,12 @@ protobuf==3.17.3
# google-cloud-storage
# googleapis-common-protos
# kfp-pipeline-spec
pyasn1-modules==0.2.8
# via google-auth
pyasn1==0.4.8
# via
# pyasn1-modules
# rsa
pydantic==1.8.2
# via -r requirements.in
pyasn1-modules==0.2.8
# via google-auth
pyrsistent==0.18.0
# via jsonschema
python-dateutil==2.8.2
@ -88,10 +90,6 @@ pyyaml==5.4.1
# via
# -r requirements.in
# kubernetes
requests-oauthlib==1.3.0
# via kubernetes
requests-toolbelt==0.9.1
# via -r requirements.in
requests==2.26.0
# via
# google-api-core
@ -99,6 +97,10 @@ requests==2.26.0
# kubernetes
# requests-oauthlib
# requests-toolbelt
requests-oauthlib==1.3.0
# via kubernetes
requests-toolbelt==0.9.1
# via -r requirements.in
rsa==4.7.2
# via google-auth
six==1.16.0
@ -106,7 +108,6 @@ six==1.16.0
# absl-py
# fire
# google-auth
# google-cloud-storage
# jsonschema
# kfp-server-api
# kubernetes
@ -120,8 +121,10 @@ termcolor==1.1.0
# via fire
typer==0.4.0
# via -r requirements.in
typing-extensions==3.10.0.2
# via pydantic
typing-extensions==3.10.0.2 ; python_version < "3.9"
# via
# -r requirements.in
# importlib-metadata
uritemplate==3.0.1
# via -r requirements.in
urllib3==1.26.7
@ -135,6 +138,8 @@ wheel==0.37.0
# via strip-hints
wrapt==1.13.1
# via deprecated
zipp==3.8.0
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools