feat(sdk): support more than one exit handler per pipeline (#8088)

* add compiler test pipeline with multiple exit handlers

* remove blocker of multiple exit handlers

* move exit handler builder logic to pipeline_spec_builder

* build all exit handlers per pipeline

* add compiler test with IR inspection

* prevent usage of cross-pipeline after

* test cross-pipeline after is prevented

* update existing task dependency logic and tests

* add v2 sample test

* remove cross-pipeline .after

* prevent cross-dag data dependency for dsl features

* add compiler test pipeline with nested exit handlers

* add support for nested exit handlers

* clean up pipeline with nested exit handlers

* remove sample with multiple exit handlers

* remove compiler test with nested exit handlers

* add compilation guard against nested exit handlers in subdag

* update release notes
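As a minimal sketch of the DSL usage this PR enables (illustrative only; the pipeline and component names below are not taken from this diff), a pipeline may now define several dsl.ExitHandler groups at the top level, each with its own exit task:

from kfp import dsl

@dsl.component
def print_op(message: str):
    print(message)

@dsl.pipeline(name='two-exit-handlers')  # illustrative name
def my_pipeline():
    first_exit_task = print_op(message='First cleanup.')
    with dsl.ExitHandler(first_exit_task):
        print_op(message='Work guarded by the first handler.')

    second_exit_task = print_op(message='Second cleanup.')
    with dsl.ExitHandler(second_exit_task):
        print_op(message='Work guarded by the second handler.')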
Connor McCarthy 2022-08-10 10:56:21 -06:00 committed by GitHub
parent e728d0871b
commit bdff332ac6
13 changed files with 887 additions and 130 deletions

View File

@ -1,4 +1,5 @@
import os
from kfp import dsl
# In tests, we install a KFP package from the PR under test. Users should not
@ -32,5 +33,4 @@ def my_pipeline(
generate_task = generate_op()
with dsl.ParallelFor(generate_task.output) as item:
concat_task = concat_op(a=item.a, b=item.b)
concat_task.after(print_task)
print_task_2 = print_op(text=concat_task.output)

View File

@ -1,7 +1,8 @@
import os
from kfp import dsl
from typing import List
from kfp import dsl
# In tests, we install a KFP package from the PR under test. Users should not
# normally need to specify `kfp_package_path` in their component definitions.
_KFP_PACKAGE_PATH = os.getenv('KFP_PACKAGE_PATH')
@ -21,12 +22,10 @@ def concat_op(a: str, b: str) -> str:
@dsl.pipeline(name='pipeline-with-loop-static')
def my_pipeline(
greeting: str = 'this is a test for looping through parameters',
):
greeting: str = 'this is a test for looping through parameters',):
print_task = print_op(text=greeting)
static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]
with dsl.ParallelFor(static_loop_arguments) as item:
concat_task = concat_op(a=item.a, b=item.b)
concat_task.after(print_task)
print_task_2 = print_op(text=concat_task.output)

View File

@ -5,7 +5,6 @@
## Breaking Changes
### For Pipeline Authors
* Add support for task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
### For Component Authors
@ -14,6 +13,8 @@
## Bug Fixes and Other Changes
* Enable overriding caching options at submission time [\#7912](https://github.com/kubeflow/pipelines/pull/7912)
* Allow artifact inputs in pipeline definition. [\#8044](https://github.com/kubeflow/pipelines/pull/8044)
* Support task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
* Support multiple exit handlers per pipeline [\#8088](https://github.com/kubeflow/pipelines/pull/8088)
## Documentation Updates

View File

@ -44,6 +44,7 @@ CONFIG = {
'component_with_pip_index_urls',
'container_component_with_no_inputs',
'two_step_pipeline_containerized',
'pipeline_with_multiple_exit_handlers',
],
'test_data_dir': 'sdk/python/kfp/compiler/test_data/pipelines',
'config': {

View File

@ -148,8 +148,6 @@ class Compiler:
if not dsl_pipeline.tasks:
raise ValueError('Task is missing from pipeline.')
self._validate_exit_handler(dsl_pipeline)
pipeline_inputs = pipeline_meta.inputs or {}
# Verify that pipeline_parameters_override contains only input names
@ -186,45 +184,6 @@ class Compiler:
return pipeline_spec
def _validate_exit_handler(self,
pipeline: pipeline_context.Pipeline) -> None:
"""Makes sure there is only one global exit handler.
This is temporary to be compatible with KFP v1.
Raises:
ValueError if there are more than one exit handler.
"""
def _validate_exit_handler_helper(
group: tasks_group.TasksGroup,
exiting_task_names: List[str],
handler_exists: bool,
) -> None:
if isinstance(group, dsl.ExitHandler):
if handler_exists or len(exiting_task_names) > 1:
raise ValueError(
'Only one global exit_handler is allowed and all ops need to be included.'
)
handler_exists = True
if group.tasks:
exiting_task_names.extend([x.name for x in group.tasks])
for group in group.groups:
_validate_exit_handler_helper(
group=group,
exiting_task_names=exiting_task_names,
handler_exists=handler_exists,
)
_validate_exit_handler_helper(
group=pipeline.groups[0],
exiting_task_names=[],
handler_exists=False,
)
def _create_pipeline_spec(
self,
pipeline_args: List[pipeline_channel.PipelineChannel],
@ -301,49 +260,11 @@ class Compiler:
name_to_for_loop_group=name_to_for_loop_group,
)
# TODO: refactor to support multiple exit handler per pipeline.
if pipeline.groups[0].groups:
first_group = pipeline.groups[0].groups[0]
if isinstance(first_group, dsl.ExitHandler):
exit_task = first_group.exit_task
exit_task_name = component_utils.sanitize_task_name(
exit_task.name)
exit_handler_group_task_name = component_utils.sanitize_task_name(
first_group.name)
input_parameters_in_current_dag = [
input_name for input_name in
pipeline_spec.root.input_definitions.parameters
]
exit_task_task_spec = builder.build_task_spec_for_exit_task(
task=exit_task,
dependent_task=exit_handler_group_task_name,
pipeline_inputs=pipeline_spec.root.input_definitions,
)
exit_task_component_spec = builder.build_component_spec_for_exit_task(
task=exit_task)
exit_task_container_spec = builder.build_container_spec_for_task(
task=exit_task)
# Add exit task task spec
pipeline_spec.root.dag.tasks[exit_task_name].CopyFrom(
exit_task_task_spec)
# Add exit task component spec if it does not exist.
component_name = exit_task_task_spec.component_ref.name
if component_name not in pipeline_spec.components:
pipeline_spec.components[component_name].CopyFrom(
exit_task_component_spec)
# Add exit task container spec if it does not exist.
executor_label = exit_task_component_spec.executor_label
if executor_label not in deployment_config.executors:
deployment_config.executors[
executor_label].container.CopyFrom(
exit_task_container_spec)
pipeline_spec.deployment_spec.update(
json_format.MessageToDict(deployment_config))
builder.build_exit_handler_groups_recursively(
parent_group=root_group,
pipeline_spec=pipeline_spec,
deployment_config=deployment_config,
)
return pipeline_spec
@ -705,14 +626,12 @@ class Compiler:
task2=task,
)
# If a task depends on a condition group or a loop group, it
# must explicitly dependent on a task inside the group. This
# should not be allowed, because it leads to ambiguous
# expectations for runtime behaviors.
# a task cannot depend on a task created in a for loop group since individual PipelineTask variables are reassigned after each loop iteration
dependent_group = group_name_to_group.get(
upstream_groups[0], None)
if isinstance(dependent_group,
(tasks_group.Condition, tasks_group.ParallelFor)):
(tasks_group.ParallelFor, tasks_group.Condition,
tasks_group.ExitHandler)):
raise RuntimeError(
f'Task {task.name} cannot dependent on any task inside'
f' the group: {upstream_groups[0]}.')
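For illustration, here is a hypothetical pipeline (not part of this diff, mirroring the tests added in the next file) that the validation above now rejects: a task outside a dsl.ParallelFor group consumes an output produced inside it, so compilation fails with the RuntimeError raised here. The same applies to dsl.Condition and dsl.ExitHandler groups.

from kfp import compiler, dsl

@dsl.component
def producer_op() -> str:
    return 'a'

@dsl.component
def consumer_op(msg: str):
    print(msg)

@dsl.pipeline(name='invalid-cross-group-dependency')  # illustrative name
def my_pipeline():
    with dsl.ParallelFor(['a', 'b']):
        producer_task = producer_op()
    # producer_task was created inside the loop group, so compiling this
    # pipeline raises:
    # RuntimeError: Task consumer-op cannot dependent on any task inside the group: ...
    consumer_op(msg=producer_task.output)

# compiler.Compiler().compile(pipeline_func=my_pipeline, package_path='pipeline.yaml')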

View File

@ -440,7 +440,7 @@ class TestCompilePipeline(parameterized.TestCase):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.yaml')
def test_invalid_after_dependency(self):
def test_invalid_data_dependency_loop(self):
@dsl.component
def producer_op() -> str:
@ -451,30 +451,7 @@ class TestCompilePipeline(parameterized.TestCase):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(text: str):
def my_pipeline(val: bool):
with dsl.Condition(text == 'a'):
producer_task = producer_op()
dummy_op().after(producer_task)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot dependent on any task inside the group:'):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.yaml')
def test_invalid_data_dependency(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(text: bool):
with dsl.ParallelFor(['a, b']):
producer_task = producer_op()
@ -483,8 +460,125 @@ class TestCompilePipeline(parameterized.TestCase):
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot dependent on any task inside the group:'):
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_valid_data_dependency_loop(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
with dsl.ParallelFor(['a, b']):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_invalid_data_dependency_condition(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
with dsl.Condition(val == False):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot dependent on any task inside the group:'):
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_valid_data_dependency_condition(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
with dsl.Condition(val == False):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_invalid_data_dependency_exit_handler(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
first_producer = producer_op()
with dsl.ExitHandler(first_producer):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot dependent on any task inside the group:'):
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_valid_data_dependency_exit_handler(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(val: bool):
first_producer = producer_op()
with dsl.ExitHandler(first_producer):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_use_task_final_status_in_non_exit_op(self):
@ -527,7 +621,6 @@ implementation:
pipeline_func=my_pipeline, package_path='result.yaml')
# pylint: disable=import-outside-toplevel,unused-import,import-error,redefined-outer-name,reimported
class V2NamespaceAliasTest(unittest.TestCase):
"""Test that imports of both modules and objects are aliased (e.g. all
import path variants work)."""
@ -536,7 +629,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
# the kfp.v2 module is loaded. Due to the way we run tests in CI/CD, we cannot ensure that the kfp.v2 module will first be loaded in these tests,
# so we do not test for the DeprecationWarning here.
def test_import_namespace(self): # pylint: disable=no-self-use
def test_import_namespace(self):
from kfp import v2
@v2.dsl.component
@ -560,7 +653,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
with open(temp_filepath, 'r') as f:
yaml.load(f)
def test_import_modules(self): # pylint: disable=no-self-use
def test_import_modules(self):
from kfp.v2 import compiler
from kfp.v2 import dsl
@ -584,7 +677,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
with open(temp_filepath, 'r') as f:
yaml.load(f)
def test_import_object(self): # pylint: disable=no-self-use
def test_import_object(self):
from kfp.v2.compiler import Compiler
from kfp.v2.dsl import component
from kfp.v2.dsl import pipeline
@ -1125,5 +1218,82 @@ class TestSetRetryCompilation(unittest.TestCase):
self.assertEqual(retry_policy.backoff_max_duration.seconds, 3600)
from google.protobuf import json_format
class TestMultipleExitHandlerCompilation(unittest.TestCase):
def test_basic(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
print_op(message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.')
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
pipeline_spec = pipeline_spec_from_file(package_path)
# check that the exit handler dags exist
self.assertEqual(
pipeline_spec.components['comp-exit-handler-1'].dag
.tasks['print-op-2'].inputs.parameters['message'].runtime_value
.constant.string_value, 'Inside first exit handler.')
self.assertEqual(
pipeline_spec.components['comp-exit-handler-2'].dag
.tasks['print-op-4'].inputs.parameters['message'].runtime_value
.constant.string_value, 'Inside second exit handler.')
# check that the exit handler dags are in the root dag
self.assertIn('exit-handler-1', pipeline_spec.root.dag.tasks)
self.assertIn('exit-handler-2', pipeline_spec.root.dag.tasks)
# check that the exit tasks are in the root dag
self.assertIn('print-op', pipeline_spec.root.dag.tasks)
self.assertEqual(
pipeline_spec.root.dag.tasks['print-op'].inputs
.parameters['message'].runtime_value.constant.string_value,
'First exit task.')
self.assertIn('print-op-3', pipeline_spec.root.dag.tasks)
self.assertEqual(
pipeline_spec.root.dag.tasks['print-op-3'].inputs
.parameters['message'].runtime_value.constant.string_value,
'Second exit task.')
def test_nested_unsupported(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
print_op(message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.')
with self.assertRaisesRegex(
ValueError,
r'ExitHandler can only be used within the outermost scope of a pipeline function definition\.'
):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='output.yaml')
if __name__ == '__main__':
unittest.main()

View File

@ -23,6 +23,7 @@ from kfp import dsl
from kfp.compiler import pipeline_spec_builder as builder
from kfp.components import for_loop
from kfp.components import pipeline_channel
from kfp.components import pipeline_context
from kfp.components import pipeline_task
from kfp.components import placeholders
from kfp.components import structures
@ -34,6 +35,13 @@ from kfp.components.types import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2
GroupOrTaskType = Union[tasks_group.TasksGroup, pipeline_task.PipelineTask]
# must be defined here to avoid circular imports
group_type_to_dsl_class = {
tasks_group.TasksGroupType.PIPELINE: pipeline_context.Pipeline,
tasks_group.TasksGroupType.CONDITION: tasks_group.Condition,
tasks_group.TasksGroupType.FOR_LOOP: tasks_group.ParallelFor,
tasks_group.TasksGroupType.EXIT_HANDLER: tasks_group.ExitHandler,
}
def _additional_input_name_for_pipeline_channel(
@ -772,7 +780,7 @@ def build_task_spec_for_exit_task(
pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
.ALL_UPSTREAM_TASKS_COMPLETED)
for input_name, input_spec in task.component_spec.inputs.items():
for input_name, input_spec in (task.component_spec.inputs or {}).items():
if type_utils.is_task_final_status_type(input_spec.type):
pipeline_task_spec.inputs.parameters[
input_name].task_final_status.producer_task = dependent_task
@ -1184,6 +1192,61 @@ def build_spec_by_group(
)
def build_exit_handler_groups_recursively(
parent_group: tasks_group.TasksGroup,
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
):
if not parent_group.groups:
return
for group in parent_group.groups:
if isinstance(group, dsl.ExitHandler):
exit_task = group.exit_task
exit_task_name = utils.sanitize_task_name(exit_task.name)
exit_handler_group_task_name = utils.sanitize_task_name(group.name)
exit_task_task_spec = builder.build_task_spec_for_exit_task(
task=exit_task,
dependent_task=exit_handler_group_task_name,
pipeline_inputs=pipeline_spec.root.input_definitions,
)
exit_task_component_spec = builder.build_component_spec_for_exit_task(
task=exit_task)
exit_task_container_spec = builder.build_container_spec_for_task(
task=exit_task)
# remove this if block to support nested exit handlers
if not parent_group.is_root:
raise ValueError(
f'{dsl.ExitHandler.__name__} can only be used within the outermost scope of a pipeline function definition. Using an {dsl.ExitHandler.__name__} within {group_type_to_dsl_class[parent_group.group_type].__name__} {parent_group.name} is not allowed.'
)
parent_dag = pipeline_spec.root.dag if parent_group.is_root else pipeline_spec.components[
utils.sanitize_component_name(parent_group.name)].dag
parent_dag.tasks[exit_task_name].CopyFrom(exit_task_task_spec)
# Add exit task component spec if it does not exist.
component_name = exit_task_task_spec.component_ref.name
if component_name not in pipeline_spec.components:
pipeline_spec.components[component_name].CopyFrom(
exit_task_component_spec)
# Add exit task container spec if it does not exist.
executor_label = exit_task_component_spec.executor_label
if executor_label not in deployment_config.executors:
deployment_config.executors[executor_label].container.CopyFrom(
exit_task_container_spec)
pipeline_spec.deployment_spec.update(
json_format.MessageToDict(deployment_config))
build_exit_handler_groups_recursively(
parent_group=group,
pipeline_spec=pipeline_spec,
deployment_config=deployment_config)
def get_parent_groups(
root_group: tasks_group.TasksGroup,
) -> Tuple[Mapping[str, List[GroupOrTaskType]], Mapping[str,
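To make the guard inside build_exit_handler_groups_recursively concrete, here is a hypothetical pipeline (not from this diff; names are illustrative) that nests one ExitHandler inside another. Compiling it raises the ValueError added above:

from kfp import dsl

@dsl.component
def print_op(message: str):
    print(message)

@dsl.pipeline(name='nested-exit-handlers')  # illustrative name
def my_pipeline():
    outer_exit = print_op(message='Outer cleanup.')
    with dsl.ExitHandler(outer_exit):
        inner_exit = print_op(message='Inner cleanup.')
        # Nesting a second ExitHandler inside the first fails at compile time:
        # ValueError: ExitHandler can only be used within the outermost scope
        # of a pipeline function definition.
        with dsl.ExitHandler(inner_exit):
            print_op(message='Nested work.')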

View File

@ -0,0 +1,59 @@
# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline using multiple ExitHandlers."""
from kfp import compiler
from kfp import dsl
from kfp.dsl import component
@component
def print_op(message: str):
"""Prints a message."""
print(message)
@component
def fail_op(message: str):
"""Fails."""
import sys
print(message)
sys.exit(1)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline(message: str = 'Hello World!'):
first_exit_task = print_op(message='First exit handler has worked!')
with dsl.ExitHandler(first_exit_task):
first_exit_print_task = print_op(message=message)
print(first_exit_print_task.outputs)
fail_op(message='Task failed.')
second_exit_task = print_op(message='Second exit handler has worked!')
with dsl.ExitHandler(second_exit_task):
print_op(message=message)
third_exit_task = print_op(message='Third exit handler has worked!')
with dsl.ExitHandler(third_exit_task):
print_op(message=message)
if __name__ == '__main__':
compiler.Compiler().compile(
pipeline_func=my_pipeline,
package_path=__file__.replace('.py', '.yaml'))

View File

@ -0,0 +1,387 @@
components:
comp-exit-handler-1:
dag:
tasks:
fail-op:
cachingOptions:
enableCache: true
componentRef:
name: comp-fail-op
inputs:
parameters:
message:
runtimeValue:
constant: Task failed.
taskInfo:
name: fail-op
print-op-2:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-2
inputs:
parameters:
message:
componentInputParameter: pipelinechannel--message
taskInfo:
name: print-op-2
inputDefinitions:
parameters:
pipelinechannel--message:
parameterType: STRING
comp-exit-handler-2:
dag:
tasks:
print-op-4:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-4
inputs:
parameters:
message:
componentInputParameter: pipelinechannel--message
taskInfo:
name: print-op-4
inputDefinitions:
parameters:
pipelinechannel--message:
parameterType: STRING
comp-exit-handler-3:
dag:
tasks:
print-op-6:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-6
inputs:
parameters:
message:
componentInputParameter: pipelinechannel--message
taskInfo:
name: print-op-6
inputDefinitions:
parameters:
pipelinechannel--message:
parameterType: STRING
comp-fail-op:
executorLabel: exec-fail-op
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op:
executorLabel: exec-print-op
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-2:
executorLabel: exec-print-op-2
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-3:
executorLabel: exec-print-op-3
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-4:
executorLabel: exec-print-op-4
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-5:
executorLabel: exec-print-op-5
inputDefinitions:
parameters:
message:
parameterType: STRING
comp-print-op-6:
executorLabel: exec-print-op-6
inputDefinitions:
parameters:
message:
parameterType: STRING
deploymentSpec:
executors:
exec-fail-op:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- fail_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef fail_op(message: str):\n \"\"\"Fails.\"\"\"\n import sys\n\
\ print(message)\n sys.exit(1)\n\n"
image: python:3.7
exec-print-op:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-2:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-3:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-4:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-5:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
exec-print-op-6:
container:
args:
- --executor_input
- '{{$}}'
- --function_to_execute
- print_op
command:
- sh
- -c
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
\ && \"$0\" \"$@\"\n"
- sh
- -ec
- 'program_path=$(mktemp -d)
printf "%s" "$0" > "$program_path/ephemeral_component.py"
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
'
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
\ print(message)\n\n"
image: python:3.7
pipelineInfo:
name: pipeline-with-multiple-exit-handlers
root:
dag:
tasks:
exit-handler-1:
componentRef:
name: comp-exit-handler-1
inputs:
parameters:
pipelinechannel--message:
componentInputParameter: message
taskInfo:
name: exit-handler-1
exit-handler-2:
componentRef:
name: comp-exit-handler-2
inputs:
parameters:
pipelinechannel--message:
componentInputParameter: message
taskInfo:
name: exit-handler-2
exit-handler-3:
componentRef:
name: comp-exit-handler-3
inputs:
parameters:
pipelinechannel--message:
componentInputParameter: message
taskInfo:
name: exit-handler-3
print-op:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op
dependentTasks:
- exit-handler-1
inputs:
parameters:
message:
runtimeValue:
constant: First exit handler has worked!
taskInfo:
name: print-op
triggerPolicy:
strategy: ALL_UPSTREAM_TASKS_COMPLETED
print-op-3:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-3
dependentTasks:
- exit-handler-2
inputs:
parameters:
message:
runtimeValue:
constant: Second exit handler has worked!
taskInfo:
name: print-op-3
triggerPolicy:
strategy: ALL_UPSTREAM_TASKS_COMPLETED
print-op-5:
cachingOptions:
enableCache: true
componentRef:
name: comp-print-op-5
dependentTasks:
- exit-handler-3
inputs:
parameters:
message:
runtimeValue:
constant: Third exit handler has worked!
taskInfo:
name: print-op-5
triggerPolicy:
strategy: ALL_UPSTREAM_TASKS_COMPLETED
inputDefinitions:
parameters:
message:
defaultValue: Hello World!
parameterType: STRING
schemaVersion: 2.1.0
sdkVersion: kfp-2.0.0-beta.1

View File

@ -118,7 +118,9 @@ class Pipeline:
# Add the root group.
self.groups = [
tasks_group.TasksGroup(
group_type=tasks_group.TasksGroupType.PIPELINE, name=name)
group_type=tasks_group.TasksGroupType.PIPELINE,
name=name,
is_root=True)
]
self._group_id = 0
@ -174,6 +176,7 @@ class Pipeline:
self.tasks[task_name] = task
if add_to_group:
task.parent_task_group = self.groups[-1]
self.groups[-1].tasks.append(task)
return task_name

View File

@ -69,6 +69,10 @@ class PipelineTask:
args: Mapping[str, Any],
):
"""Initializes a PipelineTask instance."""
# import within __init__ to avoid circular import
from kfp.components.tasks_group import TasksGroup
self.parent_task_group: Union[None, TasksGroup] = None
args = args or {}
for input_name, argument_value in args.items():
@ -558,5 +562,9 @@ class PipelineTask:
task2 = my_component(text='2nd task').after(task1)
"""
for task in tasks:
if task.parent_task_group is not self.parent_task_group:
raise ValueError(
f'Cannot use .after() across inner pipelines or DSL control flow features. Tried to set {self.name} after {task.name}, but these tasks do not belong to the same pipeline or are not enclosed in the same control flow context manager.'
)
self._task_spec.dependent_tasks.append(task.name)
return self
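The new check is exercised by the tests in the next file; below is a condensed, hypothetical example of the pattern it rejects (pipeline name is illustrative, the component mirrors those tests):

from kfp import dsl

@dsl.component
def print_op(message: str):
    print(message)

@dsl.pipeline(name='cross-group-after')  # illustrative name
def my_pipeline():
    first_exit_task = print_op(message='First exit task.')
    with dsl.ExitHandler(first_exit_task):
        inside_first = print_op(message='Inside first exit handler.')

    second_exit_task = print_op(message='Second exit task.')
    with dsl.ExitHandler(second_exit_task):
        # inside_first belongs to the first exit handler group, so this raises:
        # ValueError: Cannot use .after() across inner pipelines or DSL control flow features. ...
        print_op(message='Inside second exit handler.').after(inside_first)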

View File

@ -13,10 +13,14 @@
# limitations under the License.
"""Tests for kfp.components.pipeline_task."""
import os
import tempfile
import textwrap
import unittest
from absl.testing import parameterized
from kfp import compiler
from kfp import dsl
from kfp.components import pipeline_task
from kfp.components import placeholders
from kfp.components import structures
@ -301,5 +305,136 @@ class PipelineTaskTest(parameterized.TestCase):
self.assertEqual('test_name', task._task_spec.display_name)
class TestCannotUseAfterCrossDAG(unittest.TestCase):
def test_inner_task_prevented(self):
with self.assertRaisesRegex(ValueError,
r'Cannot use \.after\(\) across'):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
first_print_op = print_op(
message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.').after(
first_print_op)
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_exit_handler_task_prevented(self):
with self.assertRaisesRegex(ValueError,
r'Cannot use \.after\(\) across'):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
first_print_op = print_op(
message='Inside first exit handler.')
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
x = print_op(message='Inside second exit handler.')
x.after(first_exit_task)
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_within_same_exit_handler_permitted(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
first_exit_task = print_op(message='First exit task.')
with dsl.ExitHandler(first_exit_task):
first_print_op = print_op(
message='First task inside first exit handler.')
second_print_op = print_op(
message='Second task inside first exit handler.').after(
first_print_op)
second_exit_task = print_op(message='Second exit task.')
with dsl.ExitHandler(second_exit_task):
print_op(message='Inside second exit handler.')
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_outside_of_condition_blocked(self):
with self.assertRaisesRegex(ValueError,
r'Cannot use \.after\(\) across'):
@dsl.component
def print_op(message: str):
print(message)
@dsl.component
def return_1() -> int:
return 1
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
return_1_task = return_1()
with dsl.Condition(return_1_task.output == 1):
one = print_op(message='1')
two = print_op(message='2')
three = print_op(message='3').after(one)
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
def test_inside_of_condition_permitted(self):
@dsl.component
def print_op(message: str):
print(message)
@dsl.component
def return_1() -> int:
return 1
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
def my_pipeline():
return_1_task = return_1()
with dsl.Condition(return_1_task.output == '1'):
one = print_op(message='1')
two = print_op(message='2').after(one)
three = print_op(message='3')
with tempfile.TemporaryDirectory() as tempdir:
package_path = os.path.join(tempdir, 'pipeline.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=package_path)
if __name__ == '__main__':
unittest.main()

View File

@ -44,12 +44,14 @@ class TasksGroup:
groups: A list of TasksGroups in this group.
display_name: The optional user given name of the group.
dependencies: A list of tasks or groups this group depends on.
is_root: Whether the TasksGroup is the root group of a pipeline.
"""
def __init__(
self,
group_type: TasksGroupType,
name: Optional[str] = None,
is_root: bool = False,
):
"""Create a new instance of TasksGroup.
@ -62,6 +64,7 @@ class TasksGroup:
self.groups = list()
self.display_name = name
self.dependencies = []
self.is_root = is_root
def __enter__(self):
if not pipeline_context.Pipeline.get_default_pipeline():
@ -116,7 +119,11 @@ class ExitHandler(TasksGroup):
name: Optional[str] = None,
):
"""Initializes an ExitHandler task group."""
super().__init__(group_type=TasksGroupType.EXIT_HANDLER, name=name)
super().__init__(
group_type=TasksGroupType.EXIT_HANDLER,
name=name,
is_root=False,
)
if exit_task.dependent_tasks:
raise ValueError('exit_task cannot depend on any other tasks.')
@ -151,6 +158,7 @@ class Condition(TasksGroup):
self,
condition: pipeline_channel.ConditionOperator,
name: Optional[str] = None,
is_root=False,
):
"""Initializes a conditional task group."""
super().__init__(group_type=TasksGroupType.CONDITION, name=name)
@ -182,7 +190,11 @@ class ParallelFor(TasksGroup):
name: Optional[str] = None,
):
"""Initializes a for loop task group."""
super().__init__(group_type=TasksGroupType.FOR_LOOP, name=name)
super().__init__(
group_type=TasksGroupType.FOR_LOOP,
name=name,
is_root=False,
)
if isinstance(items, pipeline_channel.PipelineChannel):
self.loop_argument = for_loop.LoopArgument.from_pipeline_channel(