feat(sdk): Support Parallelism in ParallelFor in KFP SDK V2 (#8146)
* feat(sdk): add support for ParallelFor parallelism setting * add parallelism value check * Add unit tests * Adding a compiler test * used None as default parallelism input signature, other minor fixes on format and tests * fix import statements * add release.md message * update ParallelFor docstring * fixed docstring comments * removed 'optional' in docstrings
This commit is contained in:
parent
fe66d20b98
commit
51bea09833
|
|
@ -1,6 +1,7 @@
|
|||
# Current Version (Still in Development)
|
||||
|
||||
## Major Features and Improvements
|
||||
* Support parallelism setting in ParallelFor [\#8146](https://github.com/kubeflow/pipelines/pull/8146)
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ CONFIG = {
|
|||
'container_component_with_no_inputs',
|
||||
'two_step_pipeline_containerized',
|
||||
'pipeline_with_multiple_exit_handlers',
|
||||
'pipeline_with_parallelfor_parallelism',
|
||||
],
|
||||
'test_data_dir': 'sdk/python/kfp/compiler/test_data/pipelines',
|
||||
'config': {
|
||||
|
|
|
|||
|
|
@ -620,6 +620,70 @@ implementation:
|
|||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path='result.yaml')
|
||||
|
||||
def test_compile_parallel_for_with_valid_parallelism(self):
    """Checks a positive ``parallelism`` shows up as ``parallelismLimit``."""

    @dsl.component
    def producer_op(item: str) -> str:
        return item

    @dsl.pipeline(name='test-parallel-for-with-parallelism')
    def my_pipeline(text: bool):
        with dsl.ParallelFor(items=['a', 'b'], parallelism=2) as loop_item:
            producer_op(item=loop_item)

    # Compile into a scratch directory and read the resulting spec back in.
    with tempfile.TemporaryDirectory() as scratch_dir:
        spec_path = os.path.join(scratch_dir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=my_pipeline, package_path=spec_path)
        with open(spec_path, 'r') as spec_file:
            compiled_spec = yaml.safe_load(spec_file)

    loop_task = compiled_spec['root']['dag']['tasks']['for-loop-2']
    self.assertEqual(loop_task['iteratorPolicy']['parallelismLimit'], 2)
|
||||
|
||||
def test_compile_parallel_for_with_invalid_parallelism(self):
    """Expects a ValueError when a loop is given a negative ``parallelism``."""

    @dsl.component
    def producer_op(item: str) -> str:
        return item

    @dsl.pipeline(name='test-parallel-for-with-parallelism')
    def my_pipeline(text: bool):
        with dsl.ParallelFor(items=['a', 'b'], parallelism=-2) as loop_item:
            producer_op(item=loop_item)

    # The pattern is matched as a regex against the raised error's message.
    expected_message = 'ParallelFor parallelism must be >= 0.'
    with self.assertRaisesRegex(ValueError, expected_message):
        compiler.Compiler().compile(
            pipeline_func=my_pipeline, package_path='result.yaml')
|
||||
|
||||
def test_compile_parallel_for_with_zero_parallelism(self):
    """Checks that unconstrained loops compile without an ``iteratorPolicy``.

    Covers both an explicit ``parallelism=0`` and an omitted ``parallelism``
    argument.
    """

    @dsl.component
    def producer_op(item: str) -> str:
        return item

    @dsl.pipeline(name='test-parallel-for-with-parallelism')
    def my_pipeline(text: bool):
        with dsl.ParallelFor(items=['a', 'b'], parallelism=0) as item:
            producer_op(item=item)

        with dsl.ParallelFor(items=['a', 'b']) as item:
            producer_op(item=item)

    with tempfile.TemporaryDirectory() as tempdir:
        output_yaml = os.path.join(tempdir, 'result.yaml')
        compiler.Compiler().compile(
            pipeline_func=my_pipeline, package_path=output_yaml)
        with open(output_yaml, 'r') as f:
            pipeline_spec = yaml.safe_load(f)

    tasks = pipeline_spec['root']['dag']['tasks']
    for_loop_2 = tasks['for-loop-2']
    for_loop_4 = tasks['for-loop-4']

    # Assert the key's absence directly instead of provoking a KeyError with
    # a bare subscript expression; a failure then reports the offending task.
    self.assertNotIn('iteratorPolicy', for_loop_2)
    self.assertNotIn('iteratorPolicy', for_loop_4)
|
||||
|
||||
|
||||
class V2NamespaceAliasTest(unittest.TestCase):
|
||||
"""Test that imports of both modules and objects are aliased (e.g. all
|
||||
|
|
|
|||
|
|
@ -606,6 +606,10 @@ def _update_task_spec_for_loop_group(
|
|||
pipeline_task_spec.parameter_iterator.item_input = (
|
||||
input_parameter_name)
|
||||
|
||||
if (group.parallelism_limit > 0):
|
||||
pipeline_task_spec.iterator_policy.parallelism_limit = (
|
||||
group.parallelism_limit)
|
||||
|
||||
_pop_input_from_task_spec(
|
||||
task_spec=pipeline_task_spec,
|
||||
input_name=pipeline_task_spec.parameter_iterator.item_input)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
from kfp import compiler
|
||||
from kfp import dsl
|
||||
from kfp.dsl import component
|
||||
|
||||
|
||||
# Minimal component that prints its input message.
# NOTE: the compiled YAML spec for this pipeline embeds this function's source
# text verbatim, so regenerate that spec whenever the function changes.
@component
def print_text(msg: str):
    print(msg)
|
||||
|
||||
|
||||
# Sample pipeline exercising the ParallelFor ``parallelism`` setting on outer
# and nested loops, using both a pipeline input and static item lists.
@dsl.pipeline(name='pipeline-with-loops')
def my_pipeline(loop_parameter: List[str]):

    # Loop argument is from a pipeline input
    # Outer loop: at most 2 iterations scheduled concurrently.
    with dsl.ParallelFor(items=loop_parameter, parallelism=2) as item:
        print_text(msg=item)

        # Nested loop without a parallelism argument (unconstrained).
        with dsl.ParallelFor(items=loop_parameter) as nested_item:
            print_text(msg=nested_item)

    # Loop argument is a static value known at compile time
    loop_args = [{'A_a': '1', 'B_b': '2'}, {'A_a': '10', 'B_b': '20'}]
    # parallelism=0 means unconstrained, same as omitting the argument.
    with dsl.ParallelFor(items=loop_args, parallelism=0) as item:
        print_text(msg=item.A_a)
        print_text(msg=item.B_b)

        nested_loop_args = [{
            'A_a': '10',
            'B_b': '20'
        }, {
            'A_a': '100',
            'B_b': '200'
        }]
        # Nested loop limited to one iteration at a time.
        with dsl.ParallelFor(
                items=nested_loop_args, parallelism=1) as nested_item:
            print_text(msg=nested_item.A_a)
            print_text(msg=nested_item.B_b)
|
||||
|
||||
|
||||
# When run as a script, compile the pipeline to a YAML file whose path is this
# file's path with '.py' swapped for '.yaml'.
# NOTE(review): str.replace substitutes every '.py' occurrence in the path,
# not only the trailing extension.
if __name__ == '__main__':
    compiler.Compiler().compile(
        pipeline_func=my_pipeline,
        package_path=__file__.replace('.py', '.yaml'))
|
||||
|
|
@ -0,0 +1,356 @@
|
|||
components:
|
||||
comp-for-loop-1:
|
||||
dag:
|
||||
tasks:
|
||||
for-loop-2:
|
||||
componentRef:
|
||||
name: comp-for-loop-2
|
||||
inputs:
|
||||
parameters:
|
||||
pipelinechannel--loop_parameter:
|
||||
componentInputParameter: pipelinechannel--loop_parameter
|
||||
parameterIterator:
|
||||
itemInput: pipelinechannel--loop_parameter-loop-item
|
||||
items:
|
||||
inputParameter: pipelinechannel--loop_parameter
|
||||
taskInfo:
|
||||
name: for-loop-2
|
||||
print-text:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-text
|
||||
inputs:
|
||||
parameters:
|
||||
msg:
|
||||
componentInputParameter: pipelinechannel--loop_parameter-loop-item
|
||||
taskInfo:
|
||||
name: print-text
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
pipelinechannel--loop_parameter:
|
||||
parameterType: LIST
|
||||
pipelinechannel--loop_parameter-loop-item:
|
||||
parameterType: STRING
|
||||
comp-for-loop-2:
|
||||
dag:
|
||||
tasks:
|
||||
print-text-2:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-text-2
|
||||
inputs:
|
||||
parameters:
|
||||
msg:
|
||||
componentInputParameter: pipelinechannel--loop_parameter-loop-item
|
||||
taskInfo:
|
||||
name: print-text-2
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
pipelinechannel--loop_parameter:
|
||||
parameterType: LIST
|
||||
pipelinechannel--loop_parameter-loop-item:
|
||||
parameterType: STRING
|
||||
comp-for-loop-4:
|
||||
dag:
|
||||
tasks:
|
||||
for-loop-6:
|
||||
componentRef:
|
||||
name: comp-for-loop-6
|
||||
iteratorPolicy:
|
||||
parallelismLimit: 1
|
||||
parameterIterator:
|
||||
itemInput: pipelinechannel--loop-item-param-5
|
||||
items:
|
||||
raw: '[{"A_a": "10", "B_b": "20"}, {"A_a": "100", "B_b": "200"}]'
|
||||
taskInfo:
|
||||
name: for-loop-6
|
||||
print-text-3:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-text-3
|
||||
inputs:
|
||||
parameters:
|
||||
msg:
|
||||
componentInputParameter: pipelinechannel--loop-item-param-3
|
||||
parameterExpressionSelector: parseJson(string_value)["A_a"]
|
||||
taskInfo:
|
||||
name: print-text-3
|
||||
print-text-4:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-text-4
|
||||
inputs:
|
||||
parameters:
|
||||
msg:
|
||||
componentInputParameter: pipelinechannel--loop-item-param-3
|
||||
parameterExpressionSelector: parseJson(string_value)["B_b"]
|
||||
taskInfo:
|
||||
name: print-text-4
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
pipelinechannel--loop-item-param-3:
|
||||
parameterType: STRUCT
|
||||
comp-for-loop-6:
|
||||
dag:
|
||||
tasks:
|
||||
print-text-5:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-text-5
|
||||
inputs:
|
||||
parameters:
|
||||
msg:
|
||||
componentInputParameter: pipelinechannel--loop-item-param-5
|
||||
parameterExpressionSelector: parseJson(string_value)["A_a"]
|
||||
taskInfo:
|
||||
name: print-text-5
|
||||
print-text-6:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-text-6
|
||||
inputs:
|
||||
parameters:
|
||||
msg:
|
||||
componentInputParameter: pipelinechannel--loop-item-param-5
|
||||
parameterExpressionSelector: parseJson(string_value)["B_b"]
|
||||
taskInfo:
|
||||
name: print-text-6
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
pipelinechannel--loop-item-param-5:
|
||||
parameterType: STRUCT
|
||||
comp-print-text:
|
||||
executorLabel: exec-print-text
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
msg:
|
||||
parameterType: STRING
|
||||
comp-print-text-2:
|
||||
executorLabel: exec-print-text-2
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
msg:
|
||||
parameterType: STRING
|
||||
comp-print-text-3:
|
||||
executorLabel: exec-print-text-3
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
msg:
|
||||
parameterType: STRING
|
||||
comp-print-text-4:
|
||||
executorLabel: exec-print-text-4
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
msg:
|
||||
parameterType: STRING
|
||||
comp-print-text-5:
|
||||
executorLabel: exec-print-text-5
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
msg:
|
||||
parameterType: STRING
|
||||
comp-print-text-6:
|
||||
executorLabel: exec-print-text-6
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
msg:
|
||||
parameterType: STRING
|
||||
deploymentSpec:
|
||||
executors:
|
||||
exec-print-text:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_text
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.2'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_text(msg: str):\n print(msg)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-text-2:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_text
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.2'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_text(msg: str):\n print(msg)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-text-3:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_text
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.2'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_text(msg: str):\n print(msg)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-text-4:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_text
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.2'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_text(msg: str):\n print(msg)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-text-5:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_text
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.2'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_text(msg: str):\n print(msg)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-text-6:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_text
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.2'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_text(msg: str):\n print(msg)\n\n"
|
||||
image: python:3.7
|
||||
pipelineInfo:
|
||||
name: pipeline-with-loops
|
||||
root:
|
||||
dag:
|
||||
tasks:
|
||||
for-loop-1:
|
||||
componentRef:
|
||||
name: comp-for-loop-1
|
||||
inputs:
|
||||
parameters:
|
||||
pipelinechannel--loop_parameter:
|
||||
componentInputParameter: loop_parameter
|
||||
iteratorPolicy:
|
||||
parallelismLimit: 2
|
||||
parameterIterator:
|
||||
itemInput: pipelinechannel--loop_parameter-loop-item
|
||||
items:
|
||||
inputParameter: pipelinechannel--loop_parameter
|
||||
taskInfo:
|
||||
name: for-loop-1
|
||||
for-loop-4:
|
||||
componentRef:
|
||||
name: comp-for-loop-4
|
||||
parameterIterator:
|
||||
itemInput: pipelinechannel--loop-item-param-3
|
||||
items:
|
||||
raw: '[{"A_a": "1", "B_b": "2"}, {"A_a": "10", "B_b": "20"}]'
|
||||
taskInfo:
|
||||
name: for-loop-4
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
loop_parameter:
|
||||
parameterType: LIST
|
||||
schemaVersion: 2.1.0
|
||||
sdkVersion: kfp-2.0.0-beta.2
|
||||
|
|
@ -57,7 +57,7 @@ class TasksGroup:
|
|||
|
||||
Args:
|
||||
group_type: The type of the group.
|
||||
name: Optional; the name of the group. Used as display name in UI.
|
||||
name: The name of the group. Used as display name in UI.
|
||||
"""
|
||||
self.group_type = group_type
|
||||
self.tasks = list()
|
||||
|
|
@ -172,24 +172,36 @@ class ParallelFor(TasksGroup):
|
|||
Args:
|
||||
items: The items to loop over. It can be either a constant Python list or a list output from an upstream task.
|
||||
name: The name of the for loop group.
|
||||
parallelism: The maximum number of concurrent iterations that can be scheduled for execution. A value of 0 represents unconstrained parallelism (default is unconstrained).
|
||||
|
||||
Example:
|
||||
::
|
||||
|
||||
with dsl.ParallelFor([{'a': 1, 'b': 10}, {'a': 2, 'b': 20}]) as item:
|
||||
with dsl.ParallelFor(
|
||||
items=[{'a': 1, 'b': 10}, {'a': 2, 'b': 20}],
|
||||
parallelism=1
|
||||
) as item:
|
||||
task1 = MyComponent(..., item.a)
|
||||
task2 = MyComponent(..., item.b)
|
||||
|
||||
In the example, ``task1`` would be executed twice, once with case
|
||||
``args=['echo 1']`` and once with case ``args=['echo 2']``.
|
||||
In the example, the group of tasks containing ``task1`` and ``task2`` would
|
||||
be executed twice, once with case ``args=[{'a': 1, 'b': 10}]`` and once with
|
||||
case ``args=[{'a': 2, 'b': 20}]``. The ``parallelism=1`` setting causes only
|
||||
1 execution to be scheduled at a time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: Union[for_loop.ItemList, pipeline_channel.PipelineChannel],
|
||||
name: Optional[str] = None,
|
||||
parallelism: Optional[int] = None,
|
||||
):
|
||||
"""Initializes a for loop task group."""
|
||||
parallelism = parallelism or 0
|
||||
if parallelism < 0:
|
||||
raise ValueError(
|
||||
f'ParallelFor parallelism must be >= 0. Got: {parallelism}.')
|
||||
|
||||
super().__init__(
|
||||
group_type=TasksGroupType.FOR_LOOP,
|
||||
name=name,
|
||||
|
|
@ -208,6 +220,8 @@ class ParallelFor(TasksGroup):
|
|||
)
|
||||
self.items_is_pipeline_channel = False
|
||||
|
||||
self.parallelism_limit = parallelism
|
||||
|
||||
def __enter__(self) -> for_loop.LoopArgument:
    """Enters the group via the parent class and returns the loop argument.

    Returns:
        The group's ``loop_argument``; this is the object bound by the
        ``as`` clause of ``with dsl.ParallelFor(...) as item``.
    """
    super().__enter__()
    return self.loop_argument
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import parameterized
|
||||
from kfp.components import for_loop
|
||||
from kfp.components import pipeline_context
|
||||
from kfp.components import tasks_group
|
||||
|
||||
|
||||
class ParallelForTest(parameterized.TestCase):
    """Tests for how ``tasks_group.ParallelFor`` stores its settings.

    ``ParallelFor.__enter__`` returns the group's loop argument rather than
    the group itself, so these tests keep their own reference to the group
    and make their assertions on that object.
    """

    def _check_group(self, expected_limit, **parallel_for_kwargs):
        """Builds a ParallelFor inside a pipeline context and verifies it.

        Args:
            expected_limit: The value ``parallelism_limit`` is expected to
                hold on the constructed group.
            **parallel_for_kwargs: Extra keyword arguments forwarded to
                ``ParallelFor`` (for example ``parallelism``).
        """
        loop_items = ['pizza', 'hotdog', 'pasta']
        with pipeline_context.Pipeline('pipeline'):
            group = tasks_group.ParallelFor(
                items=loop_items, **parallel_for_kwargs)
            with group as loop_argument:
                pass

        self.assertEqual(group.group_type,
                         tasks_group.TasksGroupType.FOR_LOOP)
        self.assertEqual(group.parallelism_limit, expected_limit)
        # Identity and type checks do not depend on how LoopArgument
        # implements ``==``, which is not visible from this test module.
        self.assertIsInstance(loop_argument, for_loop.LoopArgument)
        self.assertIs(loop_argument, group.loop_argument)

    def test_basic(self):
        """Omitting ``parallelism`` stores a limit of 0 (unconstrained)."""
        self._check_group(expected_limit=0)

    def test_parallelfor_valid_parallelism(self):
        """A positive ``parallelism`` is stored unchanged."""
        self._check_group(expected_limit=3, parallelism=3)

    def test_parallelfor_zero_parallelism(self):
        """An explicit ``parallelism=0`` is stored as 0."""
        self._check_group(expected_limit=0, parallelism=0)

    def test_parallelfor_invalid_parallelism(self):
        """A negative ``parallelism`` is rejected with a ValueError."""
        loop_items = ['pizza', 'hotdog', 'pasta']
        with self.assertRaisesRegex(ValueError,
                                    'ParallelFor parallelism must be >= 0.'):
            with pipeline_context.Pipeline('pipeline'):
                tasks_group.ParallelFor(items=loop_items, parallelism=-1)
|
||||
Loading…
Reference in New Issue