feat(sdk): support more than one exit handler per pipeline (#8088)

* add compiler test pipeline with multiple exit handlers

* remove blocker of multiple exit handlers

* move exit handler builder logic to pipeline_spec_builder

* build all exit handlers per pipeline

* add compiler test with IR inspection

* prevent usage of cross-pipeline after

* test cross-pipeline after is prevented

* update existing task dependency logic and tests

* add v2 sample test

* remove cross-pipeline .after

* prevent cross-dag data dependency for dsl features

* add compiler test pipeline with nested exit handlers

* add support for nested exit handlers

* clean up pipeline with nested exit handlers

* remove sample with multiple exit handlers

* remove compiler test with nested exit handlers

* add compilation guard against nested exit handlers in subdag

* update release notes
Connor McCarthy 2022-08-10 10:56:21 -06:00 committed by GitHub
parent e728d0871b
commit bdff332ac6
13 changed files with 887 additions and 130 deletions
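
In short, this change lets a pipeline define several sibling ExitHandler groups at the top level of the pipeline function. A minimal sketch, mirroring the sample pipeline and compiler test added in this PR:

from kfp import dsl

@dsl.component
def print_op(message: str):
    print(message)

@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
    first_exit_task = print_op(message='First exit task.')
    with dsl.ExitHandler(first_exit_task):
        print_op(message='Inside first exit handler.')

    second_exit_task = print_op(message='Second exit task.')
    with dsl.ExitHandler(second_exit_task):
        print_op(message='Inside second exit handler.')

Each exit task is compiled into the root DAG with trigger strategy ALL_UPSTREAM_TASKS_COMPLETED, so it runs once its own handler group finishes; nesting an ExitHandler inside any other group remains a compile-time error (see the guard added in pipeline_spec_builder below).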

View File

@@ -1,4 +1,5 @@
import os
from kfp import dsl
# In tests, we install a KFP package from the PR under test. Users should not
@@ -32,5 +33,4 @@ def my_pipeline(
generate_task = generate_op()
with dsl.ParallelFor(generate_task.output) as item:
concat_task = concat_op(a=item.a, b=item.b)
concat_task.after(print_task)
print_task_2 = print_op(text=concat_task.output)

View File

@@ -1,7 +1,8 @@
import os
from kfp import dsl
from typing import List
from kfp import dsl
# In tests, we install a KFP package from the PR under test. Users should not
# normally need to specify `kfp_package_path` in their component definitions.
_KFP_PACKAGE_PATH = os.getenv('KFP_PACKAGE_PATH')
@@ -21,12 +22,10 @@ def concat_op(a: str, b: str) -> str:
@dsl.pipeline(name='pipeline-with-loop-static')
def my_pipeline(
greeting: str = 'this is a test for looping through parameters',
):
greeting: str = 'this is a test for looping through parameters',):
print_task = print_op(text=greeting)
static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]
with dsl.ParallelFor(static_loop_arguments) as item:
concat_task = concat_op(a=item.a, b=item.b)
concat_task.after(print_task)
print_task_2 = print_op(text=concat_task.output)
print_task_2 = print_op(text=concat_task.output)

View File

@@ -5,7 +5,6 @@
## Breaking Changes
### For Pipeline Authors
* Add support for task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
### For Component Authors
@@ -14,6 +13,8 @@
## Bug Fixes and Other Changes
* Enable overriding caching options at submission time [\#7912](https://github.com/kubeflow/pipelines/pull/7912)
* Allow artifact inputs in pipeline definition. [\#8044](https://github.com/kubeflow/pipelines/pull/8044)
* Support task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
* Support multiple exit handlers per pipeline [\#8088](https://github.com/kubeflow/pipelines/pull/8088)
## Documentation Updates

View File

@@ -44,6 +44,7 @@ CONFIG = {
'component_with_pip_index_urls',
'container_component_with_no_inputs',
'two_step_pipeline_containerized',
'pipeline_with_multiple_exit_handlers',
],
'test_data_dir': 'sdk/python/kfp/compiler/test_data/pipelines',
'config': {

View File

@@ -148,8 +148,6 @@ class Compiler:
if not dsl_pipeline.tasks:
raise ValueError('Task is missing from pipeline.')
self._validate_exit_handler(dsl_pipeline)
pipeline_inputs = pipeline_meta.inputs or {}
# Verify that pipeline_parameters_override contains only input names
@@ -186,45 +184,6 @@ class Compiler:
return pipeline_spec
def _validate_exit_handler(self,
pipeline: pipeline_context.Pipeline) -> None:
"""Makes sure there is only one global exit handler.
This is temporary to be compatible with KFP v1.
Raises:
ValueError: If there is more than one exit handler.
"""
def _validate_exit_handler_helper(
group: tasks_group.TasksGroup,
exiting_task_names: List[str],
handler_exists: bool,
) -> None:
if isinstance(group, dsl.ExitHandler):
if handler_exists or len(exiting_task_names) > 1:
raise ValueError(
'Only one global exit_handler is allowed and all ops need to be included.'
)
handler_exists = True
if group.tasks:
exiting_task_names.extend([x.name for x in group.tasks])
for group in group.groups:
_validate_exit_handler_helper(
group=group,
exiting_task_names=exiting_task_names,
handler_exists=handler_exists,
)
_validate_exit_handler_helper(
group=pipeline.groups[0],
exiting_task_names=[],
handler_exists=False,
)
def _create_pipeline_spec(
self,
pipeline_args: List[pipeline_channel.PipelineChannel],
@@ -301,49 +260,11 @@ class Compiler:
name_to_for_loop_group=name_to_for_loop_group,
)
# TODO: refactor to support multiple exit handlers per pipeline.
if pipeline.groups[0].groups:
first_group = pipeline.groups[0].groups[0]
if isinstance(first_group, dsl.ExitHandler):
exit_task = first_group.exit_task
exit_task_name = component_utils.sanitize_task_name(
exit_task.name)
exit_handler_group_task_name = component_utils.sanitize_task_name(
first_group.name)
input_parameters_in_current_dag = [
input_name for input_name in
pipeline_spec.root.input_definitions.parameters
]
exit_task_task_spec = builder.build_task_spec_for_exit_task(
task=exit_task,
dependent_task=exit_handler_group_task_name,
pipeline_inputs=pipeline_spec.root.input_definitions,
)
exit_task_component_spec = builder.build_component_spec_for_exit_task(
task=exit_task)
exit_task_container_spec = builder.build_container_spec_for_task(
task=exit_task)
# Add exit task task spec
pipeline_spec.root.dag.tasks[exit_task_name].CopyFrom(
exit_task_task_spec)
# Add exit task component spec if it does not exist.
component_name = exit_task_task_spec.component_ref.name
if component_name not in pipeline_spec.components:
pipeline_spec.components[component_name].CopyFrom(
exit_task_component_spec)
# Add exit task container spec if it does not exist.
executor_label = exit_task_component_spec.executor_label
if executor_label not in deployment_config.executors:
deployment_config.executors[
executor_label].container.CopyFrom(
exit_task_container_spec)
pipeline_spec.deployment_spec.update(
json_format.MessageToDict(deployment_config))
builder.build_exit_handler_groups_recursively(
parent_group=root_group,
pipeline_spec=pipeline_spec,
deployment_config=deployment_config,
)
return pipeline_spec
@@ -705,14 +626,12 @@ class Compiler:
task2=task,
)
# If a task depends on a condition group or a loop group, it
# must explicitly depend on a task inside the group. This
# should not be allowed, because it leads to ambiguous
# expectations for runtime behaviors.
# A task cannot depend on a task created in a for loop group, since individual PipelineTask variables are reassigned after each loop iteration.
dependent_group = group_name_to_group.get(
upstream_groups[0], None)
if isinstance(dependent_group,
(tasks_group.Condition, tasks_group.ParallelFor)):
(tasks_group.ParallelFor, tasks_group.Condition,
tasks_group.ExitHandler)):
raise RuntimeError(
f'Task {task.name} cannot depend on any task inside'
f' the group: {upstream_groups[0]}.')
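
A hedged sketch of what this guard rejects, reusing the producer_op/dummy_op pattern from the compiler tests below: a task outside a ParallelFor (or Condition, or ExitHandler) cannot consume an output produced inside it.

with dsl.ParallelFor(['a', 'b']) as item:
    producer_task = producer_op()  # rebound on every loop iteration
dummy_op(msg=producer_task.output)  # RuntimeError at compile time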

View File

@@ -440,7 +440,7 @@ class TestCompilePipeline(parameterized.TestCase):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.yaml')
def test_invalid_after_dependency(self):
def test_invalid_data_dependency_loop(self):
@dsl.component
def producer_op() -> str:
@@ -451,30 +451,7 @@ class TestCompilePipeline(parameterized.TestCase):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(text: str):
with dsl.Condition(text == 'a'):
producer_task = producer_op()
dummy_op().after(producer_task)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot depend on any task inside the group:'):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.yaml')
def test_invalid_data_dependency(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(text: bool):
def my_pipeline(val: bool):
with dsl.ParallelFor(['a, b']):
producer_task = producer_op()
@@ -483,8 +460,125 @@ class TestCompilePipeline(parameterized.TestCase):
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot depend on any task inside the group:'):
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_valid_data_dependency_loop(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
with dsl.ParallelFor(['a, b']):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.yaml')
pipeline_func=my_pipeline, package_path=package_path)
def test_invalid_data_dependency_condition(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
with dsl.Condition(val == False):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot depend on any task inside the group:'):
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_valid_data_dependency_condition(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
with dsl.Condition(val == False):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_invalid_data_dependency_exit_handler(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
first_producer = producer_op()
with dsl.ExitHandler(first_producer):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot depend on any task inside the group:'):
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_valid_data_dependency_exit_handler(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
first_producer = producer_op()
with dsl.ExitHandler(first_producer):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_use_task_final_status_in_non_exit_op(self):
@@ -527,7 +621,6 @@ implementation:
pipeline_func=my_pipeline, package_path='result.yaml')
# pylint: disable=import-outside-toplevel,unused-import,import-error,redefined-outer-name,reimported
class V2NamespaceAliasTest(unittest.TestCase):
"""Test that imports of both modules and objects are aliased (e.g. all
import path variants work)."""
@@ -536,7 +629,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
# the kfp.v2 module is loaded. Due to the way we run tests in CI/CD, we cannot ensure that the kfp.v2 module will first be loaded in these tests,
# so we do not test for the DeprecationWarning here.
def test_import_namespace(self): # pylint: disable=no-self-use
def test_import_namespace(self):
from kfp import v2
@v2.dsl.component
@@ -560,7 +653,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
with open(temp_filepath, 'r') as f:
yaml.load(f)
def test_import_modules(self): # pylint: disable=no-self-use
def test_import_modules(self):
from kfp.v2 import compiler
from kfp.v2 import dsl
@@ -584,7 +677,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
with open(temp_filepath, 'r') as f:
yaml.load(f)
def test_import_object(self): # pylint: disable=no-self-use
def test_import_object(self):
from kfp.v2.compiler import Compiler
from kfp.v2.dsl import component
from kfp.v2.dsl import pipeline
@@ -1125,5 +1218,82 @@ class TestSetRetryCompilation(unittest.TestCase):
self.assertEqual(retry_policy.backoff_max_duration.seconds, 3600)
from google.protobuf import json_format
class TestMultipleExitHandlerCompilation(unittest.TestCase):
def test_basic(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
print_op(message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.')
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
pipeline_spec = pipeline_spec_from_file(package_path)
# check that the exit handler dags exist
self.assertEqual(
pipeline_spec.components['comp-exit-handler-1'].dag
.tasks['print-op-2'].inputs.parameters['message'].runtime_value
.constant.string_value, 'Inside first exit handler.')
self.assertEqual(
pipeline_spec.components['comp-exit-handler-2'].dag
.tasks['print-op-4'].inputs.parameters['message'].runtime_value
.constant.string_value, 'Inside second exit handler.')
# check that the exit handler dags are in the root dag
self.assertIn('exit-handler-1', pipeline_spec.root.dag.tasks)
self.assertIn('exit-handler-2', pipeline_spec.root.dag.tasks)
# check that the exit tasks are in the root dag
self.assertIn('print-op', pipeline_spec.root.dag.tasks)
self.assertEqual(
pipeline_spec.root.dag.tasks['print-op'].inputs
.parameters['message'].runtime_value.constant.string_value,
'First exit task.')
self.assertIn('print-op-3', pipeline_spec.root.dag.tasks)
self.assertEqual(
pipeline_spec.root.dag.tasks['print-op-3'].inputs
.parameters['message'].runtime_value.constant.string_value,
'Second exit task.')
def test_nested_unsupported(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
print_op(message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.')
with self.assertRaisesRegex(
ValueError,
r'ExitHandler can only be used within the outermost scope of a pipeline function definition\.'
):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='output.yaml')
if __name__ == '__main__':
unittest.main()

View File

@@ -23,6 +23,7 @@ from kfp import dsl
from kfp.compiler import pipeline_spec_builder as builder
from kfp.components import for_loop
from kfp.components import pipeline_channel
from kfp.components import pipeline_context
from kfp.components import pipeline_task
from kfp.components import placeholders
from kfp.components import structures
@@ -34,6 +35,13 @@ from kfp.components.types import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2
GroupOrTaskType = Union[tasks_group.TasksGroup, pipeline_task.PipelineTask]
# must be defined here to avoid circular imports
group_type_to_dsl_class = {
tasks_group.TasksGroupType.PIPELINE: pipeline_context.Pipeline,
tasks_group.TasksGroupType.CONDITION: tasks_group.Condition,
tasks_group.TasksGroupType.FOR_LOOP: tasks_group.ParallelFor,
tasks_group.TasksGroupType.EXIT_HANDLER: tasks_group.ExitHandler,
}
def _additional_input_name_for_pipeline_channel(
@@ -772,7 +780,7 @@ def build_task_spec_for_exit_task(
pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
.ALL_UPSTREAM_TASKS_COMPLETED)
for input_name, input_spec in task.component_spec.inputs.items():
for input_name, input_spec in (task.component_spec.inputs or {}).items():
if type_utils.is_task_final_status_type(input_spec.type):
pipeline_task_spec.inputs.parameters[
input_name].task_final_status.producer_task = dependent_task
@@ -1184,6 +1192,61 @@ def build_spec_by_group(
)
def build_exit_handler_groups_recursively(
parent_group: tasks_group.TasksGroup,
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
):
if not parent_group.groups:
return
for group in parent_group.groups:
if isinstance(group, dsl.ExitHandler):
exit_task = group.exit_task
exit_task_name = utils.sanitize_task_name(exit_task.name)
exit_handler_group_task_name = utils.sanitize_task_name(group.name)
exit_task_task_spec = builder.build_task_spec_for_exit_task(
task=exit_task,
dependent_task=exit_handler_group_task_name,
pipeline_inputs=pipeline_spec.root.input_definitions,
)
exit_task_component_spec = builder.build_component_spec_for_exit_task(
task=exit_task)
exit_task_container_spec = builder.build_container_spec_for_task(
task=exit_task)
# Remove this if-block to support nested exit handlers.
if not parent_group.is_root:
raise ValueError(
f'{dsl.ExitHandler.__name__} can only be used within the outermost scope of a pipeline function definition. Using an {dsl.ExitHandler.__name__} within {group_type_to_dsl_class[parent_group.group_type].__name__} {parent_group.name} is not allowed.'
)
parent_dag = pipeline_spec.root.dag if parent_group.is_root else pipeline_spec.components[
utils.sanitize_component_name(parent_group.name)].dag
parent_dag.tasks[exit_task_name].CopyFrom(exit_task_task_spec)
# Add exit task component spec if it does not exist.
component_name = exit_task_task_spec.component_ref.name
if component_name not in pipeline_spec.components:
pipeline_spec.components[component_name].CopyFrom(
exit_task_component_spec)
# Add exit task container spec if it does not exist.
executor_label = exit_task_component_spec.executor_label
if executor_label not in deployment_config.executors:
deployment_config.executors[executor_label].container.CopyFrom(
exit_task_container_spec)
pipeline_spec.deployment_spec.update(
json_format.MessageToDict(deployment_config))
build_exit_handler_groups_recursively(
parent_group=group,
pipeline_spec=pipeline_spec,
deployment_config=deployment_config)
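
Because of the is_root guard above, only handlers whose parent group is the pipeline root are built; the recursion exists to find (and reject) handlers nested deeper. A hedged sketch of the shape that now fails compilation, mirrored by test_nested_unsupported:

first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
    second_exit_task = print_op(message='Second exit task.')
    with dsl.ExitHandler(second_exit_task):  # parent group is exit-handler-1, not root
        print_op(message='Inside nested handler.')  # ValueError at compile time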
def get_parent_groups(
root_group: tasks_group.TasksGroup,
) -> Tuple[Mapping[str, List[GroupOrTaskType]], Mapping[str,

View File

@@ -0,0 +1,59 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline using multiple ExitHandlers."""
from kfp import compiler
from kfp import dsl
from kfp.dsl import component
@component
def print_op(message: str):
"""Prints a message."""
print(message)
@component
def fail_op(message: str):
"""Fails."""
import sys
print(message)
sys.exit(1)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline(message: str = 'Hello World!'):
first_exit_task = print_op(message='First exit handler has worked!')
with dsl.ExitHandler(first_exit_task):
first_exit_print_task = print_op(message=message)
print(first_exit_print_task.outputs)
fail_op(message='Task failed.')
second_exit_task = print_op(message='Second exit handler has worked!')
with dsl.ExitHandler(second_exit_task):
print_op(message=message)
third_exit_task = print_op(message='Third exit handler has worked!')
with dsl.ExitHandler(third_exit_task):
print_op(message=message)
if __name__ == '__main__':
compiler.Compiler().compile(
pipeline_func=my_pipeline,
package_path=__file__.replace('.py', '.yaml'))

View File

@@ -0,0 +1,387 @@
components:
comp-exit-handler-1:
dag:
tasks:
fail-op:
cachingOptions:
enableCache: true
componentRef:
name: comp-fail-op
inputs:
parameters:
message:
runtimeValue:
constant: Task failed.
taskInfo:
name: fail-op
print-op-2:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-2
inputs:
parameters:
message:
componentInputParameter: pipelinechannel--message
taskInfo:
name: print-op-2
inputDefinitions:
parameters:
pipelinechannel--message:
parameterType: STRING
comp-exit-handler-2:
dag:
tasks:
print-op-4:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-4
inputs:
parameters:
message:
componentInputParameter: pipelinechannel--message
taskInfo:
name: print-op-4
inputDefinitions:
parameters:
pipelinechannel--message:
parameterType: STRING
comp-exit-handler-3:
dag:
tasks:
print-op-6:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-6
inputs:
parameters:
message:
componentInputParameter: pipelinechannel--message
taskInfo:
name: print-op-6
inputDefinitions:
parameters:
pipelinechannel--message:
parameterType: STRING
comp-fail-op:
executorLabel: exec-fail-op
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op:
executorLabel: exec-print-op
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-2:
executorLabel: exec-print-op-2
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-3:
executorLabel: exec-print-op-3
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-4:
executorLabel: exec-print-op-4
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-5:
executorLabel: exec-print-op-5
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-6:
executorLabel: exec-print-op-6
inputDefinitions:
parameters:
message:
parameterType: STRING
deploymentSpec:
executors:
exec-fail-op:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- fail_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef fail_op(message: str):\n \"\"\"Fails.\"\"\"\n import sys\n\
\ print(message)\n sys.exit(1)\n\n"
image: python:3.7
exec-print-op:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-2:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-3:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-4:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-5:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-6:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
pipelineInfo:
name: pipeline-with-multiple-exit-handlers
root:
dag:
tasks:
exit-handler-1:
componentRef:
name: comp-exit-handler-1
inputs:
parameters:
pipelinechannel--message:
componentInputParameter: message
taskInfo:
name: exit-handler-1
exit-handler-2:
componentRef:
name: comp-exit-handler-2
inputs:
parameters:
pipelinechannel--message:
componentInputParameter: message
taskInfo:
name: exit-handler-2
exit-handler-3:
componentRef:
name: comp-exit-handler-3
inputs:
parameters:
pipelinechannel--message:
componentInputParameter: message
taskInfo:
name: exit-handler-3
print-op:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op
dependentTasks:
- exit-handler-1
inputs:
parameters:
message:
runtimeValue:
constant: First exit handler has worked!
taskInfo:
name: print-op
triggerPolicy:
strategy: ALL_UPSTREAM_TASKS_COMPLETED
print-op-3:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-3
dependentTasks:
- exit-handler-2
inputs:
parameters:
message:
runtimeValue:
constant: Second exit handler has worked!
taskInfo:
name: print-op-3
triggerPolicy:
strategy: ALL_UPSTREAM_TASKS_COMPLETED
print-op-5:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-5
dependentTasks:
- exit-handler-3
inputs:
parameters:
message:
runtimeValue:
constant: Third exit handler has worked!
taskInfo:
name: print-op-5
triggerPolicy:
strategy: ALL_UPSTREAM_TASKS_COMPLETED
inputDefinitions:
parameters:
message:
defaultValue: Hello World!
parameterType: STRING
schemaVersion: 2.1.0
sdkVersion: kfp-2.0.0-beta.1

View File

@@ -118,7 +118,9 @@ class Pipeline:
# Add the root group.
self.groups = [
tasks_group.TasksGroup(
group_type=tasks_group.TasksGroupType.PIPELINE, name=name)
group_type=tasks_group.TasksGroupType.PIPELINE,
name=name,
is_root=True)
]
self._group_id = 0
@@ -174,6 +176,7 @@ class Pipeline:
self.tasks[task_name] = task
if add_to_group:
task.parent_task_group = self.groups[-1]
self.groups[-1].tasks.append(task)
return task_name

View File

@@ -69,6 +69,10 @@ class PipelineTask:
args: Mapping[str, Any],
):
"""Initilizes a PipelineTask instance."""
# import within __init__ to avoid circular import
from kfp.components.tasks_group import TasksGroup
self.parent_task_group: Union[None, TasksGroup] = None
args = args or {}
for input_name, argument_value in args.items():
@@ -558,5 +562,9 @@ class PipelineTask:
task2 = my_component(text='2nd task').after(task1)
"""
for task in tasks:
if task.parent_task_group is not self.parent_task_group:
raise ValueError(
f'Cannot use .after() across inner pipelines or DSL control flow features. Tried to set {self.name} after {task.name}, but these tasks do not belong to the same pipeline or are not enclosed in the same control flow context manager.'
)
self._task_spec.dependent_tasks.append(task.name)
return self
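
Since Pipeline.add_task (above) stamps each task with the innermost open TasksGroup, this identity check is what confines .after() to a single control flow scope. A hedged sketch:

with dsl.ExitHandler(first_exit_task):
    a = print_op(message='a')
    b = print_op(message='b').after(a)  # OK: same parent_task_group
with dsl.ExitHandler(second_exit_task):
    print_op(message='c').after(a)  # ValueError: tasks are in different groups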

View File

@@ -13,10 +13,14 @@
# limitations under the License.
"""Tests for kfp.components.pipeline_task."""
import os
import tempfile
import textwrap
import unittest
from absl.testing import parameterized
from kfp import compiler
from kfp import dsl
from kfp.components import pipeline_task
from kfp.components import placeholders
from kfp.components import structures
@@ -301,5 +305,136 @@ class PipelineTaskTest(parameterized.TestCase):
self.assertEqual('test_name', task._task_spec.display_name)
class TestCannotUseAfterCrossDAG(unittest.TestCase):
def test_inner_task_prevented(self):
with self.assertRaisesRegex(ValueError,
r'Cannot use \.after\(\) across'):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
first_print_op = print_op(
message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.').after(
first_print_op)
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_exit_handler_task_prevented(self):
with self.assertRaisesRegex(ValueError,
r'Cannot use \.after\(\) across'):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
first_print_op = print_op(
message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
x = print_op(message='Inside second exit handler.')
x.after(first_exit_task)
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_within_same_exit_handler_permitted(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
first_print_op = print_op(
message='First task inside first exit handler.')
second_print_op = print_op(
message='Second task inside first exit handler.').after(
first_print_op)
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.')
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_outside_of_condition_blocked(self):
with self.assertRaisesRegex(ValueError,
r'Cannot use \.after\(\) across'):
@dsl.component
def print_op(message: str):
print(message)
@dsl.component
def return_1() -> int:
return 1
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
return_1_task = return_1()
with dsl.Condition(return_1_task.output == 1):
one = print_op(message='1')
two = print_op(message='2')
three = print_op(message='3').after(one)
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_inside_of_condition_permitted(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.component
def return_1() -> int:
return 1
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
return_1_task = return_1()
with dsl.Condition(return_1_task.output == '1'):
one = print_op(message='1')
two = print_op(message='2').after(one)
three = print_op(message='3')
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
if __name__ == '__main__':
unittest.main()

View File

@@ -44,12 +44,14 @@ class TasksGroup:
groups: A list of TasksGroups in this group.
display_name: The optional user given name of the group.
dependencies: A list of tasks or groups this group depends on.
is_root: Whether this is the root TasksGroup of the pipeline.
"""
def __init__(
self,
group_type: TasksGroupType,
name: Optional[str] = None,
is_root: bool = False,
):
"""Create a new instance of TasksGroup.
@@ -62,6 +64,7 @@ class TasksGroup:
self.groups = list()
self.display_name = name
self.dependencies = []
self.is_root = is_root
def __enter__(self):
if not pipeline_context.Pipeline.get_default_pipeline():
@@ -116,7 +119,11 @@ class ExitHandler(TasksGroup):
name: Optional[str] = None,
):
"""Initializes a Condition task group."""
super().__init__(group_type=TasksGroupType.EXIT_HANDLER, name=name)
super().__init__(
group_type=TasksGroupType.EXIT_HANDLER,
name=name,
is_root=False,
)
if exit_task.dependent_tasks:
raise ValueError('exit_task cannot depend on any other tasks.')
@@ -151,6 +158,7 @@ class Condition(TasksGroup):
self,
condition: pipeline_channel.ConditionOperator,
name: Optional[str] = None,
is_root=False,
):
"""Initializes a conditional task group."""
super().__init__(group_type=TasksGroupType.CONDITION, name=name)
@@ -182,7 +190,11 @@ class ParallelFor(TasksGroup):
name: Optional[str] = None,
):
"""Initializes a for loop task group."""
super().__init__(group_type=TasksGroupType.FOR_LOOP, name=name)
super().__init__(
group_type=TasksGroupType.FOR_LOOP,
name=name,
is_root=False,
)
if isinstance(items, pipeline_channel.PipelineChannel):
self.loop_argument = for_loop.LoopArgument.from_pipeline_channel(