chore(sdk.v2): Implement LoopArgument and LoopArgumentVariable v2 (#6755)
* for_loop v2 * release note * address review comments
This commit is contained in:
parent
484c61931a
commit
7f1e34f17d
|
|
@ -16,6 +16,7 @@
|
|||
[\#6731](https://github.com/kubeflow/pipelines/pull/6731)
|
||||
* Try to use `apt-get python3-pip` when pip does not exist in containers used by
|
||||
v2 lightweight components [\#6737](https://github.com/kubeflow/pipelines/pull/6737)
|
||||
* Implement LoopArgument and LoopArgumentVariable v2. [\#6755](https://github.com/kubeflow/pipelines/pull/6755)
|
||||
|
||||
## Documentation Updates
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,16 @@
|
|||
# Copyright 2021 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
|
@ -18,7 +31,7 @@ class LoopArguments(dsl.PipelineParam):
|
|||
LOOP_ITEM_PARAM_NAME_BASE = 'loop-item-param'
|
||||
# number of characters in the code which is passed to the constructor
|
||||
NUM_CODE_CHARS = 8
|
||||
LEGAL_SUBVAR_NAME_REGEX = re.compile(r'[a-zA-Z_][0-9a-zA-Z_]*')
|
||||
LEGAL_SUBVAR_NAME_REGEX = re.compile(r'^[a-zA-Z_][0-9a-zA-Z_]*$')
|
||||
|
||||
@classmethod
|
||||
def _subvar_name_is_legal(cls, proposed_variable_name: str):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2021 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
# Copyright 2021 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Classes and methods that supports argument for ParallelFor."""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, get_type_hints
|
||||
|
||||
from kfp.v2.components.experimental import pipeline_channel
|
||||
|
||||
ItemList = List[Union[int, float, str, Dict[str, Any]]]
|
||||
|
||||
|
||||
def _get_loop_item_type(type_name: str) -> Optional[str]:
|
||||
"""Extracts the loop item type.
|
||||
|
||||
This method is used for extract the item type from a collection type.
|
||||
For example:
|
||||
|
||||
List[str] -> str
|
||||
typing.List[int] -> int
|
||||
typing.Sequence[str] -> str
|
||||
List -> None
|
||||
str -> None
|
||||
|
||||
Args:
|
||||
type_name: The collection type name, like `List`, Sequence`, etc.
|
||||
|
||||
Returns:
|
||||
The collection item type or None if no match found.
|
||||
"""
|
||||
match = re.match('(typing\.)?(?:\w+)(?:\[(?P<item_type>.+)\])', type_name)
|
||||
if match:
|
||||
return match.group('item_type').lstrip().rstrip()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _get_subvar_type(type_name: str) -> Optional[str]:
|
||||
"""Extracts the subvar type.
|
||||
|
||||
This method is used for extract the value type from a dictionary type.
|
||||
For example:
|
||||
|
||||
Dict[str, int] -> int
|
||||
typing.Mapping[str, float] -> float
|
||||
|
||||
Args:
|
||||
type_name: The dictionary type.
|
||||
|
||||
Returns:
|
||||
The dictionary value type or None if no match found.
|
||||
"""
|
||||
match = re.match(
|
||||
'(typing\.)?(?:\w+)(?:\[\s*(?:\w+)\s*,\s*(?P<value_type>.+)\])',
|
||||
type_name)
|
||||
if match:
|
||||
return match.group('value_type').lstrip().rstrip()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class LoopArgument(pipeline_channel.PipelineChannel):
|
||||
"""Represents the argument that are looped over in a ParallelFor loop.
|
||||
|
||||
The class shouldn't be instantiated by the end user, rather it is
|
||||
created automatically by a ParallelFor ops group.
|
||||
|
||||
To create a LoopArgument instance, use one of its factory methods::
|
||||
|
||||
LoopArgument.from_pipeline_channel(...)
|
||||
LoopArgument.from_raw_items(...)
|
||||
|
||||
|
||||
Attributes:
|
||||
items_or_pipeline_channel: The raw items or the PipelineChannel object
|
||||
this LoopArgument is associated to.
|
||||
"""
|
||||
LOOP_ITEM_NAME_BASE = 'loop-item'
|
||||
LOOP_ITEM_PARAM_NAME_BASE = 'loop-item-param'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: Union[ItemList, pipeline_channel.PipelineChannel],
|
||||
name_code: Optional[str] = None,
|
||||
name_override: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes a LoopArguments object.
|
||||
|
||||
Args:
|
||||
items: List of items to loop over. If a list of dicts then, all
|
||||
dicts must have the same keys and every key must be a legal
|
||||
Python variable name.
|
||||
name_code: A unique code used to identify these loop arguments.
|
||||
Should match the code for the ParallelFor ops_group which created
|
||||
these LoopArguments. This prevents parameter name collisions.
|
||||
name_override: The override name for PipelineChannel.
|
||||
**kwargs: Any other keyword arguments passed down to PipelineChannel.
|
||||
"""
|
||||
if (name_code is None) == (name_override is None):
|
||||
raise ValueError(
|
||||
'Expect one and only one of `name_code` and `name_override` to '
|
||||
'be specified.')
|
||||
|
||||
if name_override is None:
|
||||
super().__init__(name=self._make_name(name_code), **kwargs)
|
||||
else:
|
||||
super().__init__(name=name_override, **kwargs)
|
||||
|
||||
if not isinstance(items,
|
||||
(list, tuple, pipeline_channel.PipelineChannel)):
|
||||
raise TypeError(
|
||||
f'Expected list, tuple, or PipelineChannel, got {items}.')
|
||||
|
||||
if isinstance(items, tuple):
|
||||
items = list(items)
|
||||
|
||||
self.items_or_pipeline_channel = items
|
||||
self._referenced_subvars: Dict[str, LoopArgumentVariable] = {}
|
||||
|
||||
if isinstance(items, list) and isinstance(items[0], dict):
|
||||
subvar_names = set(items[0].keys())
|
||||
# then this block creates loop_arg.variable_a and loop_arg.variable_b
|
||||
for subvar_name in subvar_names:
|
||||
loop_arg_var = LoopArgumentVariable(
|
||||
loop_argument=self,
|
||||
subvar_name=subvar_name,
|
||||
)
|
||||
self._referenced_subvars[subvar_name] = loop_arg_var
|
||||
setattr(self, subvar_name, loop_arg_var)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
# this is being overridden so that we can access subvariables of the
|
||||
# LoopArgument (i.e.: item.a) without knowing the subvariable names ahead
|
||||
# of time.
|
||||
|
||||
return self._referenced_subvars.setdefault(
|
||||
name, LoopArgumentVariable(
|
||||
loop_argument=self,
|
||||
subvar_name=name,
|
||||
))
|
||||
|
||||
def _make_name(self, code: str):
|
||||
"""Makes a name for this loop argument from a unique code."""
|
||||
return '{}-{}'.format(self.LOOP_ITEM_PARAM_NAME_BASE, code)
|
||||
|
||||
@classmethod
|
||||
def from_pipeline_channel(
|
||||
cls,
|
||||
channel: pipeline_channel.PipelineChannel,
|
||||
) -> 'LoopArgument':
|
||||
"""Creates a LoopArgument object from a PipelineChannel object."""
|
||||
return LoopArgument(
|
||||
items=channel,
|
||||
name_override=channel.name + '-' + cls.LOOP_ITEM_NAME_BASE,
|
||||
task_name=channel.task_name,
|
||||
channel_type=_get_loop_item_type(channel.channel_type),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_raw_items(
|
||||
cls,
|
||||
raw_items: ItemList,
|
||||
name_code: str,
|
||||
) -> 'LoopArgument':
|
||||
"""Creates a LoopArgument object from raw item list."""
|
||||
if len(raw_items) == 0:
|
||||
raise ValueError('Got an empty item list for loop argument.')
|
||||
|
||||
return LoopArgument(
|
||||
items=raw_items,
|
||||
name_code=name_code,
|
||||
channel_type=type(raw_items[0]).__name__,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def name_is_loop_argument(cls, name: str) -> bool:
|
||||
"""Returns True if the given channel name looks like a loop argument.
|
||||
|
||||
Either it came from a withItems loop item or withParams loop
|
||||
item.
|
||||
"""
|
||||
return ('-' + cls.LOOP_ITEM_NAME_BASE) in name \
|
||||
or (cls.LOOP_ITEM_PARAM_NAME_BASE + '-') in name
|
||||
|
||||
|
||||
class LoopArgumentVariable(pipeline_channel.PipelineChannel):
|
||||
"""Represents a subvariable for a loop argument.
|
||||
|
||||
This is used for cases where we're looping over maps, each of which contains
|
||||
several variables. If the user ran:
|
||||
|
||||
with dsl.ParallelFor([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]) as item:
|
||||
...
|
||||
|
||||
Then there's one LoopArgumentVariable for 'a' and another for 'b'.
|
||||
|
||||
Attributes:
|
||||
loop_argument: The original LoopArgument object this subvariable is
|
||||
attached to.
|
||||
subvar_name: The subvariable name.
|
||||
"""
|
||||
SUBVAR_NAME_DELIMITER = '-subvar-'
|
||||
LEGAL_SUBVAR_NAME_REGEX = re.compile(r'^[a-zA-Z_][0-9a-zA-Z_]*$')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop_argument: LoopArgument,
|
||||
subvar_name: str,
|
||||
):
|
||||
"""Initializes a LoopArgumentVariable instance.
|
||||
|
||||
Args:
|
||||
loop_argument: The LoopArgument object this subvariable is based on
|
||||
a subvariable to.
|
||||
subvar_name: The name of this subvariable, which is the name of the
|
||||
dict key that spawned this subvariable.
|
||||
|
||||
Raises:
|
||||
ValueError is subvar name is illegal.
|
||||
"""
|
||||
if not self._subvar_name_is_legal(subvar_name):
|
||||
raise ValueError(
|
||||
f'Tried to create subvariable named {subvar_name}, but that is '
|
||||
'not a legal Python variable name.')
|
||||
|
||||
self.subvar_name = subvar_name
|
||||
self.loop_argument = loop_argument
|
||||
|
||||
super().__init__(
|
||||
name=self._get_name_override(
|
||||
loop_arg_name=loop_argument.name,
|
||||
subvar_name=subvar_name,
|
||||
),
|
||||
task_name=loop_argument.task_name,
|
||||
channel_type=_get_subvar_type(loop_argument.channel_type),
|
||||
)
|
||||
|
||||
def _subvar_name_is_legal(self, proposed_variable_name: str) -> bool:
|
||||
"""Returns True if the subvar name is legal."""
|
||||
return re.match(self.LEGAL_SUBVAR_NAME_REGEX,
|
||||
proposed_variable_name) is not None
|
||||
|
||||
def _get_name_override(self, loop_arg_name: str, subvar_name: str) -> str:
|
||||
"""Gets the name.
|
||||
|
||||
Args:
|
||||
loop_arg_name: the name of the loop argument parameter that this
|
||||
LoopArgumentVariable is attached to.
|
||||
subvar_name: The name of this subvariable.
|
||||
|
||||
Returns:
|
||||
The name of this loop arg variable.
|
||||
"""
|
||||
return f'{loop_arg_name}{self.SUBVAR_NAME_DELIMITER}{subvar_name}'
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
# Copyright 2021 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for kfp.v2.dsl.experimental.for_loop."""
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from kfp.v2.components.experimental import pipeline_channel
|
||||
from kfp.v2.dsl.experimental import for_loop
|
||||
|
||||
|
||||
class ForLoopTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'collection_type': 'List[int]',
|
||||
'item_type': 'int',
|
||||
},
|
||||
{
|
||||
'collection_type': 'typing.List[str]',
|
||||
'item_type': 'str',
|
||||
},
|
||||
{
|
||||
'collection_type': 'typing.Tuple[ float ]',
|
||||
'item_type': 'float',
|
||||
},
|
||||
{
|
||||
'collection_type': 'typing.Sequence[Dict[str, str]]',
|
||||
'item_type': 'Dict[str, str]',
|
||||
},
|
||||
{
|
||||
'collection_type': 'List',
|
||||
'item_type': None,
|
||||
},
|
||||
)
|
||||
def test_get_loop_item_type(self, collection_type, item_type):
|
||||
self.assertEqual(
|
||||
for_loop._get_loop_item_type(collection_type), item_type)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'dict_type': 'Dict[str, int]',
|
||||
'value_type': 'int',
|
||||
},
|
||||
{
|
||||
'dict_type': 'typing.Mapping[str,float]',
|
||||
'value_type': 'float',
|
||||
},
|
||||
{
|
||||
'dict_type': 'typing.Mapping[str, Dict[str, str] ]',
|
||||
'value_type': 'Dict[str, str]',
|
||||
},
|
||||
{
|
||||
'dict_type': 'dict',
|
||||
'value_type': None,
|
||||
},
|
||||
)
|
||||
def test_get_subvar_type(self, dict_type, value_type):
|
||||
self.assertEqual(for_loop._get_subvar_type(dict_type), value_type)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'channel':
|
||||
pipeline_channel.PipelineParameterChannel(
|
||||
name='param1',
|
||||
channel_type='List[str]',
|
||||
),
|
||||
'expected_serialization_value':
|
||||
'{{channel:task=;name=param1-loop-item;type=str;}}',
|
||||
},
|
||||
{
|
||||
'channel':
|
||||
pipeline_channel.PipelineParameterChannel(
|
||||
name='output1',
|
||||
channel_type='List[Dict[str, str]]',
|
||||
task_name='task1',
|
||||
),
|
||||
'expected_serialization_value':
|
||||
'{{channel:task=task1;name=output1-loop-item;type=Dict[str, str];}}',
|
||||
},
|
||||
)
|
||||
def test_loop_argument_from_pipeline_channel(self, channel,
|
||||
expected_serialization_value):
|
||||
loop_argument = for_loop.LoopArgument.from_pipeline_channel(channel)
|
||||
self.assertEqual(loop_argument.items_or_pipeline_channel, channel)
|
||||
self.assertEqual(str(loop_argument), expected_serialization_value)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'raw_items': ['a', 'b', 'c'],
|
||||
'name_code':
|
||||
'1',
|
||||
'expected_serialization_value':
|
||||
'{{channel:task=;name=loop-item-param-1;type=str;}}',
|
||||
},
|
||||
{
|
||||
'raw_items': [
|
||||
{
|
||||
'A_a': 1
|
||||
},
|
||||
{
|
||||
'A_a': 2
|
||||
},
|
||||
],
|
||||
'name_code':
|
||||
'2',
|
||||
'expected_serialization_value':
|
||||
'{{channel:task=;name=loop-item-param-2;type=dict;}}',
|
||||
},
|
||||
)
|
||||
def test_loop_argument_from_raw_items(self, raw_items, name_code,
|
||||
expected_serialization_value):
|
||||
loop_argument = for_loop.LoopArgument.from_raw_items(
|
||||
raw_items, name_code)
|
||||
self.assertEqual(loop_argument.items_or_pipeline_channel, raw_items)
|
||||
self.assertEqual(str(loop_argument), expected_serialization_value)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'name': 'abc-loop-item',
|
||||
'expected_result': True
|
||||
},
|
||||
{
|
||||
'name': 'abc-loop-item-subvar-a',
|
||||
'expected_result': True
|
||||
},
|
||||
{
|
||||
'name': 'loop-item-param-1',
|
||||
'expected_result': True
|
||||
},
|
||||
{
|
||||
'name': 'loop-item-param-1-subvar-a',
|
||||
'expected_result': True
|
||||
},
|
||||
{
|
||||
'name': 'param1',
|
||||
'expected_result': False
|
||||
},
|
||||
)
|
||||
def test_name_is_loop_argument(self, name, expected_result):
|
||||
self.assertEqual(
|
||||
for_loop.LoopArgument.name_is_loop_argument(name), expected_result)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
'subvar_name': 'a',
|
||||
'valid': True
|
||||
},
|
||||
{
|
||||
'subvar_name': 'A_a',
|
||||
'valid': True
|
||||
},
|
||||
{
|
||||
'subvar_name': 'a0',
|
||||
'valid': True
|
||||
},
|
||||
{
|
||||
'subvar_name': 'a-b',
|
||||
'valid': False
|
||||
},
|
||||
{
|
||||
'subvar_name': '0',
|
||||
'valid': False
|
||||
},
|
||||
{
|
||||
'subvar_name': 'a#',
|
||||
'valid': False
|
||||
},
|
||||
)
|
||||
def test_create_loop_argument_varaible(self, subvar_name, valid):
|
||||
loop_argument = for_loop.LoopArgument.from_pipeline_channel(
|
||||
pipeline_channel.PipelineParameterChannel(
|
||||
name='param1',
|
||||
channel_type='List[Dict[str, str]]',
|
||||
))
|
||||
if valid:
|
||||
loop_arg_var = for_loop.LoopArgumentVariable(
|
||||
loop_argument=loop_argument,
|
||||
subvar_name=subvar_name,
|
||||
)
|
||||
self.assertEqual(loop_arg_var.loop_argument, loop_argument)
|
||||
self.assertEqual(loop_arg_var.subvar_name, subvar_name)
|
||||
else:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'Tried to create subvariable'):
|
||||
for_loop.LoopArgumentVariable(
|
||||
loop_argument=loop_argument,
|
||||
subvar_name=subvar_name,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue