feat(sdk): support more than one exit handler per pipeline (#8088)
* add compiler test pipeline with multiple exit handlers * remove blocker of multiple exit handlers * move exit handler builder logic to pipeline_spec_builder * build all exit handlers per pipeline * add compiler test with IR inspection * prevent usage of cross-pipeline after * test cross-pipeline after is prevented * update existing task dependency logic and tests * add v2 sample test * remove cross-pipeline .after * prevent cross-dag data dependency for dsl features * add compiler test pipeline with nested exit handlers * add support for nested exit handlers * clean up pipeline with nested exit handlers * remove sample with multiple exit handlers * remove compiler test with nested exit handlers * add compilation guard against nested exit handlers in subdag * update release notes
This commit is contained in:
parent
e728d0871b
commit
bdff332ac6
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
# In tests, we install a KFP package from the PR under test. Users should not
|
||||
|
@ -32,5 +33,4 @@ def my_pipeline(
|
|||
generate_task = generate_op()
|
||||
with dsl.ParallelFor(generate_task.output) as item:
|
||||
concat_task = concat_op(a=item.a, b=item.b)
|
||||
concat_task.after(print_task)
|
||||
print_task_2 = print_op(text=concat_task.output)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import os
|
||||
from kfp import dsl
|
||||
from typing import List
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
# In tests, we install a KFP package from the PR under test. Users should not
|
||||
# normally need to specify `kfp_package_path` in their component definitions.
|
||||
_KFP_PACKAGE_PATH = os.getenv('KFP_PACKAGE_PATH')
|
||||
|
@ -21,12 +22,10 @@ def concat_op(a: str, b: str) -> str:
|
|||
|
||||
@dsl.pipeline(name='pipeline-with-loop-static')
|
||||
def my_pipeline(
|
||||
greeting: str = 'this is a test for looping through parameters',
|
||||
):
|
||||
greeting: str = 'this is a test for looping through parameters',):
|
||||
print_task = print_op(text=greeting)
|
||||
static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]
|
||||
|
||||
with dsl.ParallelFor(static_loop_arguments) as item:
|
||||
concat_task = concat_op(a=item.a, b=item.b)
|
||||
concat_task.after(print_task)
|
||||
print_task_2 = print_op(text=concat_task.output)
|
||||
print_task_2 = print_op(text=concat_task.output)
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
## Breaking Changes
|
||||
|
||||
### For Pipeline Authors
|
||||
* Add support for task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
|
||||
|
||||
### For Component Authors
|
||||
|
||||
|
@ -14,6 +13,8 @@
|
|||
## Bug Fixes and Other Changes
|
||||
* Enable overriding caching options at submission time [\#7912](https://github.com/kubeflow/pipelines/pull/7912)
|
||||
* Allow artifact inputs in pipeline definition. [\#8044](https://github.com/kubeflow/pipelines/pull/8044)
|
||||
* Support task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
|
||||
* Support multiple exit handlers per pipeline [\#8088](https://github.com/kubeflow/pipelines/pull/8088)
|
||||
|
||||
## Documentation Updates
|
||||
|
||||
|
|
|
@ -44,6 +44,7 @@ CONFIG = {
|
|||
'component_with_pip_index_urls',
|
||||
'container_component_with_no_inputs',
|
||||
'two_step_pipeline_containerized',
|
||||
'pipeline_with_multiple_exit_handlers',
|
||||
],
|
||||
'test_data_dir': 'sdk/python/kfp/compiler/test_data/pipelines',
|
||||
'config': {
|
||||
|
|
|
@ -148,8 +148,6 @@ class Compiler:
|
|||
if not dsl_pipeline.tasks:
|
||||
raise ValueError('Task is missing from pipeline.')
|
||||
|
||||
self._validate_exit_handler(dsl_pipeline)
|
||||
|
||||
pipeline_inputs = pipeline_meta.inputs or {}
|
||||
|
||||
# Verify that pipeline_parameters_override contains only input names
|
||||
|
@ -186,45 +184,6 @@ class Compiler:
|
|||
|
||||
return pipeline_spec
|
||||
|
||||
def _validate_exit_handler(self,
|
||||
pipeline: pipeline_context.Pipeline) -> None:
|
||||
"""Makes sure there is only one global exit handler.
|
||||
|
||||
This is temporary to be compatible with KFP v1.
|
||||
|
||||
Raises:
|
||||
ValueError if there are more than one exit handler.
|
||||
"""
|
||||
|
||||
def _validate_exit_handler_helper(
|
||||
group: tasks_group.TasksGroup,
|
||||
exiting_task_names: List[str],
|
||||
handler_exists: bool,
|
||||
) -> None:
|
||||
|
||||
if isinstance(group, dsl.ExitHandler):
|
||||
if handler_exists or len(exiting_task_names) > 1:
|
||||
raise ValueError(
|
||||
'Only one global exit_handler is allowed and all ops need to be included.'
|
||||
)
|
||||
handler_exists = True
|
||||
|
||||
if group.tasks:
|
||||
exiting_task_names.extend([x.name for x in group.tasks])
|
||||
|
||||
for group in group.groups:
|
||||
_validate_exit_handler_helper(
|
||||
group=group,
|
||||
exiting_task_names=exiting_task_names,
|
||||
handler_exists=handler_exists,
|
||||
)
|
||||
|
||||
_validate_exit_handler_helper(
|
||||
group=pipeline.groups[0],
|
||||
exiting_task_names=[],
|
||||
handler_exists=False,
|
||||
)
|
||||
|
||||
def _create_pipeline_spec(
|
||||
self,
|
||||
pipeline_args: List[pipeline_channel.PipelineChannel],
|
||||
|
@ -301,49 +260,11 @@ class Compiler:
|
|||
name_to_for_loop_group=name_to_for_loop_group,
|
||||
)
|
||||
|
||||
# TODO: refactor to support multiple exit handler per pipeline.
|
||||
if pipeline.groups[0].groups:
|
||||
first_group = pipeline.groups[0].groups[0]
|
||||
if isinstance(first_group, dsl.ExitHandler):
|
||||
exit_task = first_group.exit_task
|
||||
exit_task_name = component_utils.sanitize_task_name(
|
||||
exit_task.name)
|
||||
exit_handler_group_task_name = component_utils.sanitize_task_name(
|
||||
first_group.name)
|
||||
input_parameters_in_current_dag = [
|
||||
input_name for input_name in
|
||||
pipeline_spec.root.input_definitions.parameters
|
||||
]
|
||||
exit_task_task_spec = builder.build_task_spec_for_exit_task(
|
||||
task=exit_task,
|
||||
dependent_task=exit_handler_group_task_name,
|
||||
pipeline_inputs=pipeline_spec.root.input_definitions,
|
||||
)
|
||||
|
||||
exit_task_component_spec = builder.build_component_spec_for_exit_task(
|
||||
task=exit_task)
|
||||
|
||||
exit_task_container_spec = builder.build_container_spec_for_task(
|
||||
task=exit_task)
|
||||
|
||||
# Add exit task task spec
|
||||
pipeline_spec.root.dag.tasks[exit_task_name].CopyFrom(
|
||||
exit_task_task_spec)
|
||||
|
||||
# Add exit task component spec if it does not exist.
|
||||
component_name = exit_task_task_spec.component_ref.name
|
||||
if component_name not in pipeline_spec.components:
|
||||
pipeline_spec.components[component_name].CopyFrom(
|
||||
exit_task_component_spec)
|
||||
|
||||
# Add exit task container spec if it does not exist.
|
||||
executor_label = exit_task_component_spec.executor_label
|
||||
if executor_label not in deployment_config.executors:
|
||||
deployment_config.executors[
|
||||
executor_label].container.CopyFrom(
|
||||
exit_task_container_spec)
|
||||
pipeline_spec.deployment_spec.update(
|
||||
json_format.MessageToDict(deployment_config))
|
||||
builder.build_exit_handler_groups_recursively(
|
||||
parent_group=root_group,
|
||||
pipeline_spec=pipeline_spec,
|
||||
deployment_config=deployment_config,
|
||||
)
|
||||
|
||||
return pipeline_spec
|
||||
|
||||
|
@ -705,14 +626,12 @@ class Compiler:
|
|||
task2=task,
|
||||
)
|
||||
|
||||
# If a task depends on a condition group or a loop group, it
|
||||
# must explicitly dependent on a task inside the group. This
|
||||
# should not be allowed, because it leads to ambiguous
|
||||
# expectations for runtime behaviors.
|
||||
# a task cannot depend on a task created in a for loop group since individual PipelineTask variables are reassigned after each loop iteration
|
||||
dependent_group = group_name_to_group.get(
|
||||
upstream_groups[0], None)
|
||||
if isinstance(dependent_group,
|
||||
(tasks_group.Condition, tasks_group.ParallelFor)):
|
||||
(tasks_group.ParallelFor, tasks_group.Condition,
|
||||
tasks_group.ExitHandler)):
|
||||
raise RuntimeError(
|
||||
f'Task {task.name} cannot dependent on any task inside'
|
||||
f' the group: {upstream_groups[0]}.')
|
||||
|
|
|
@ -440,7 +440,7 @@ class TestCompilePipeline(parameterized.TestCase):
|
|||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path='result.yaml')
|
||||
|
||||
def test_invalid_after_dependency(self):
|
||||
def test_invalid_data_dependency_loop(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
|
@ -451,30 +451,7 @@ class TestCompilePipeline(parameterized.TestCase):
|
|||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(text: str):
|
||||
with dsl.Condition(text == 'a'):
|
||||
producer_task = producer_op()
|
||||
|
||||
dummy_op().after(producer_task)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'Task dummy-op cannot dependent on any task inside the group:'):
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path='result.yaml')
|
||||
|
||||
def test_invalid_data_dependency(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(text: bool):
|
||||
def my_pipeline(val: bool):
|
||||
with dsl.ParallelFor(['a, b']):
|
||||
producer_task = producer_op()
|
||||
|
||||
|
@ -483,8 +460,125 @@ class TestCompilePipeline(parameterized.TestCase):
|
|||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'Task dummy-op cannot dependent on any task inside the group:'):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_valid_data_dependency_loop(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(val: bool):
|
||||
with dsl.ParallelFor(['a, b']):
|
||||
producer_task = producer_op()
|
||||
dummy_op(msg=producer_task.output)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path='result.yaml')
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_invalid_data_dependency_condition(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(val: bool):
|
||||
with dsl.Condition(val == False):
|
||||
producer_task = producer_op()
|
||||
|
||||
dummy_op(msg=producer_task.output)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'Task dummy-op cannot dependent on any task inside the group:'):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_valid_data_dependency_condition(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(val: bool):
|
||||
with dsl.Condition(val == False):
|
||||
producer_task = producer_op()
|
||||
dummy_op(msg=producer_task.output)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_invalid_data_dependency_exit_handler(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(val: bool):
|
||||
first_producer = producer_op()
|
||||
with dsl.ExitHandler(first_producer):
|
||||
producer_task = producer_op()
|
||||
|
||||
dummy_op(msg=producer_task.output)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'Task dummy-op cannot dependent on any task inside the group:'):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_valid_data_dependency_exit_handler(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(val: bool):
|
||||
first_producer = producer_op()
|
||||
with dsl.ExitHandler(first_producer):
|
||||
producer_task = producer_op()
|
||||
dummy_op(msg=producer_task.output)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_use_task_final_status_in_non_exit_op(self):
|
||||
|
||||
|
@ -527,7 +621,6 @@ implementation:
|
|||
pipeline_func=my_pipeline, package_path='result.yaml')
|
||||
|
||||
|
||||
# pylint: disable=import-outside-toplevel,unused-import,import-error,redefined-outer-name,reimported
|
||||
class V2NamespaceAliasTest(unittest.TestCase):
|
||||
"""Test that imports of both modules and objects are aliased (e.g. all
|
||||
import path variants work)."""
|
||||
|
@ -536,7 +629,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
|
|||
# the kfp.v2 module is loaded. Due to the way we run tests in CI/CD, we cannot ensure that the kfp.v2 module will first be loaded in these tests,
|
||||
# so we do not test for the DeprecationWarning here.
|
||||
|
||||
def test_import_namespace(self): # pylint: disable=no-self-use
|
||||
def test_import_namespace(self):
|
||||
from kfp import v2
|
||||
|
||||
@v2.dsl.component
|
||||
|
@ -560,7 +653,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
|
|||
with open(temp_filepath, 'r') as f:
|
||||
yaml.load(f)
|
||||
|
||||
def test_import_modules(self): # pylint: disable=no-self-use
|
||||
def test_import_modules(self):
|
||||
from kfp.v2 import compiler
|
||||
from kfp.v2 import dsl
|
||||
|
||||
|
@ -584,7 +677,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
|
|||
with open(temp_filepath, 'r') as f:
|
||||
yaml.load(f)
|
||||
|
||||
def test_import_object(self): # pylint: disable=no-self-use
|
||||
def test_import_object(self):
|
||||
from kfp.v2.compiler import Compiler
|
||||
from kfp.v2.dsl import component
|
||||
from kfp.v2.dsl import pipeline
|
||||
|
@ -1125,5 +1218,82 @@ class TestSetRetryCompilation(unittest.TestCase):
|
|||
self.assertEqual(retry_policy.backoff_max_duration.seconds, 3600)
|
||||
|
||||
|
||||
from google.protobuf import json_format
|
||||
|
||||
|
||||
class TestMultipleExitHandlerCompilation(unittest.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
|
||||
@dsl.component
|
||||
def print_op(message: str):
|
||||
print(message)
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline():
|
||||
first_exit_task = print_op(message='First exit task.')
|
||||
|
||||
with dsl.ExitHandler(first_exit_task):
|
||||
print_op(message='Inside first exit handler.')
|
||||
|
||||
second_exit_task = print_op(message='Second exit task.')
|
||||
with dsl.ExitHandler(second_exit_task):
|
||||
print_op(message='Inside second exit handler.')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
package_path = os.path.join(tempdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
pipeline_spec = pipeline_spec_from_file(package_path)
|
||||
# check that the exit handler dags exist
|
||||
self.assertEqual(
|
||||
pipeline_spec.components['comp-exit-handler-1'].dag
|
||||
.tasks['print-op-2'].inputs.parameters['message'].runtime_value
|
||||
.constant.string_value, 'Inside first exit handler.')
|
||||
self.assertEqual(
|
||||
pipeline_spec.components['comp-exit-handler-2'].dag
|
||||
.tasks['print-op-4'].inputs.parameters['message'].runtime_value
|
||||
.constant.string_value, 'Inside second exit handler.')
|
||||
# check that the exit handler dags are in the root dag
|
||||
self.assertIn('exit-handler-1', pipeline_spec.root.dag.tasks)
|
||||
self.assertIn('exit-handler-2', pipeline_spec.root.dag.tasks)
|
||||
# check that the exit tasks are in the root dag
|
||||
self.assertIn('print-op', pipeline_spec.root.dag.tasks)
|
||||
self.assertEqual(
|
||||
pipeline_spec.root.dag.tasks['print-op'].inputs
|
||||
.parameters['message'].runtime_value.constant.string_value,
|
||||
'First exit task.')
|
||||
self.assertIn('print-op-3', pipeline_spec.root.dag.tasks)
|
||||
self.assertEqual(
|
||||
pipeline_spec.root.dag.tasks['print-op-3'].inputs
|
||||
.parameters['message'].runtime_value.constant.string_value,
|
||||
'Second exit task.')
|
||||
|
||||
def test_nested_unsupported(self):
|
||||
|
||||
@dsl.component
|
||||
def print_op(message: str):
|
||||
print(message)
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline():
|
||||
first_exit_task = print_op(message='First exit task.')
|
||||
|
||||
with dsl.ExitHandler(first_exit_task):
|
||||
print_op(message='Inside first exit handler.')
|
||||
|
||||
second_exit_task = print_op(message='Second exit task.')
|
||||
with dsl.ExitHandler(second_exit_task):
|
||||
print_op(message='Inside second exit handler.')
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r'ExitHandler can only be used within the outermost scope of a pipeline function definition\.'
|
||||
):
|
||||
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path='output.yaml')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -23,6 +23,7 @@ from kfp import dsl
|
|||
from kfp.compiler import pipeline_spec_builder as builder
|
||||
from kfp.components import for_loop
|
||||
from kfp.components import pipeline_channel
|
||||
from kfp.components import pipeline_context
|
||||
from kfp.components import pipeline_task
|
||||
from kfp.components import placeholders
|
||||
from kfp.components import structures
|
||||
|
@ -34,6 +35,13 @@ from kfp.components.types import type_utils
|
|||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
GroupOrTaskType = Union[tasks_group.TasksGroup, pipeline_task.PipelineTask]
|
||||
# must be defined here to avoid circular imports
|
||||
group_type_to_dsl_class = {
|
||||
tasks_group.TasksGroupType.PIPELINE: pipeline_context.Pipeline,
|
||||
tasks_group.TasksGroupType.CONDITION: tasks_group.Condition,
|
||||
tasks_group.TasksGroupType.FOR_LOOP: tasks_group.ParallelFor,
|
||||
tasks_group.TasksGroupType.EXIT_HANDLER: tasks_group.ExitHandler,
|
||||
}
|
||||
|
||||
|
||||
def _additional_input_name_for_pipeline_channel(
|
||||
|
@ -772,7 +780,7 @@ def build_task_spec_for_exit_task(
|
|||
pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
|
||||
.ALL_UPSTREAM_TASKS_COMPLETED)
|
||||
|
||||
for input_name, input_spec in task.component_spec.inputs.items():
|
||||
for input_name, input_spec in (task.component_spec.inputs or {}).items():
|
||||
if type_utils.is_task_final_status_type(input_spec.type):
|
||||
pipeline_task_spec.inputs.parameters[
|
||||
input_name].task_final_status.producer_task = dependent_task
|
||||
|
@ -1184,6 +1192,61 @@ def build_spec_by_group(
|
|||
)
|
||||
|
||||
|
||||
def build_exit_handler_groups_recursively(
|
||||
parent_group: tasks_group.TasksGroup,
|
||||
pipeline_spec: pipeline_spec_pb2.PipelineSpec,
|
||||
deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
|
||||
):
|
||||
if not parent_group.groups:
|
||||
return
|
||||
for group in parent_group.groups:
|
||||
if isinstance(group, dsl.ExitHandler):
|
||||
exit_task = group.exit_task
|
||||
exit_task_name = utils.sanitize_task_name(exit_task.name)
|
||||
exit_handler_group_task_name = utils.sanitize_task_name(group.name)
|
||||
|
||||
exit_task_task_spec = builder.build_task_spec_for_exit_task(
|
||||
task=exit_task,
|
||||
dependent_task=exit_handler_group_task_name,
|
||||
pipeline_inputs=pipeline_spec.root.input_definitions,
|
||||
)
|
||||
|
||||
exit_task_component_spec = builder.build_component_spec_for_exit_task(
|
||||
task=exit_task)
|
||||
|
||||
exit_task_container_spec = builder.build_container_spec_for_task(
|
||||
task=exit_task)
|
||||
|
||||
# remove this if block to support nested exit handlers
|
||||
if not parent_group.is_root:
|
||||
raise ValueError(
|
||||
f'{dsl.ExitHandler.__name__} can only be used within the outermost scope of a pipeline function definition. Using an {dsl.ExitHandler.__name__} within {group_type_to_dsl_class[parent_group.group_type].__name__} {parent_group.name} is not allowed.'
|
||||
)
|
||||
|
||||
parent_dag = pipeline_spec.root.dag if parent_group.is_root else pipeline_spec.components[
|
||||
utils.sanitize_component_name(parent_group.name)].dag
|
||||
|
||||
parent_dag.tasks[exit_task_name].CopyFrom(exit_task_task_spec)
|
||||
|
||||
# Add exit task component spec if it does not exist.
|
||||
component_name = exit_task_task_spec.component_ref.name
|
||||
if component_name not in pipeline_spec.components:
|
||||
pipeline_spec.components[component_name].CopyFrom(
|
||||
exit_task_component_spec)
|
||||
|
||||
# Add exit task container spec if it does not exist.
|
||||
executor_label = exit_task_component_spec.executor_label
|
||||
if executor_label not in deployment_config.executors:
|
||||
deployment_config.executors[executor_label].container.CopyFrom(
|
||||
exit_task_container_spec)
|
||||
pipeline_spec.deployment_spec.update(
|
||||
json_format.MessageToDict(deployment_config))
|
||||
build_exit_handler_groups_recursively(
|
||||
parent_group=group,
|
||||
pipeline_spec=pipeline_spec,
|
||||
deployment_config=deployment_config)
|
||||
|
||||
|
||||
def get_parent_groups(
|
||||
root_group: tasks_group.TasksGroup,
|
||||
) -> Tuple[Mapping[str, List[GroupOrTaskType]], Mapping[str,
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2022 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pipeline using multiple ExitHandlers."""
|
||||
|
||||
from kfp import compiler
|
||||
from kfp import dsl
|
||||
from kfp.dsl import component
|
||||
|
||||
|
||||
@component
|
||||
def print_op(message: str):
|
||||
"""Prints a message."""
|
||||
print(message)
|
||||
|
||||
|
||||
@component
|
||||
def fail_op(message: str):
|
||||
"""Fails."""
|
||||
import sys
|
||||
print(message)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline(message: str = 'Hello World!'):
|
||||
|
||||
first_exit_task = print_op(message='First exit handler has worked!')
|
||||
|
||||
with dsl.ExitHandler(first_exit_task):
|
||||
first_exit_print_task = print_op(message=message)
|
||||
print(first_exit_print_task.outputs)
|
||||
fail_op(message='Task failed.')
|
||||
|
||||
second_exit_task = print_op(message='Second exit handler has worked!')
|
||||
|
||||
with dsl.ExitHandler(second_exit_task):
|
||||
print_op(message=message)
|
||||
|
||||
third_exit_task = print_op(message='Third exit handler has worked!')
|
||||
|
||||
with dsl.ExitHandler(third_exit_task):
|
||||
print_op(message=message)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline,
|
||||
package_path=__file__.replace('.py', '.yaml'))
|
|
@ -0,0 +1,387 @@
|
|||
components:
|
||||
comp-exit-handler-1:
|
||||
dag:
|
||||
tasks:
|
||||
fail-op:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-fail-op
|
||||
inputs:
|
||||
parameters:
|
||||
message:
|
||||
runtimeValue:
|
||||
constant: Task failed.
|
||||
taskInfo:
|
||||
name: fail-op
|
||||
print-op-2:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-op-2
|
||||
inputs:
|
||||
parameters:
|
||||
message:
|
||||
componentInputParameter: pipelinechannel--message
|
||||
taskInfo:
|
||||
name: print-op-2
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
pipelinechannel--message:
|
||||
parameterType: STRING
|
||||
comp-exit-handler-2:
|
||||
dag:
|
||||
tasks:
|
||||
print-op-4:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-op-4
|
||||
inputs:
|
||||
parameters:
|
||||
message:
|
||||
componentInputParameter: pipelinechannel--message
|
||||
taskInfo:
|
||||
name: print-op-4
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
pipelinechannel--message:
|
||||
parameterType: STRING
|
||||
comp-exit-handler-3:
|
||||
dag:
|
||||
tasks:
|
||||
print-op-6:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-op-6
|
||||
inputs:
|
||||
parameters:
|
||||
message:
|
||||
componentInputParameter: pipelinechannel--message
|
||||
taskInfo:
|
||||
name: print-op-6
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
pipelinechannel--message:
|
||||
parameterType: STRING
|
||||
comp-fail-op:
|
||||
executorLabel: exec-fail-op
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
parameterType: STRING
|
||||
comp-print-op:
|
||||
executorLabel: exec-print-op
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
parameterType: STRING
|
||||
comp-print-op-2:
|
||||
executorLabel: exec-print-op-2
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
parameterType: STRING
|
||||
comp-print-op-3:
|
||||
executorLabel: exec-print-op-3
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
parameterType: STRING
|
||||
comp-print-op-4:
|
||||
executorLabel: exec-print-op-4
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
parameterType: STRING
|
||||
comp-print-op-5:
|
||||
executorLabel: exec-print-op-5
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
parameterType: STRING
|
||||
comp-print-op-6:
|
||||
executorLabel: exec-print-op-6
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
parameterType: STRING
|
||||
deploymentSpec:
|
||||
executors:
|
||||
exec-fail-op:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- fail_op
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef fail_op(message: str):\n \"\"\"Fails.\"\"\"\n import sys\n\
|
||||
\ print(message)\n sys.exit(1)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-op:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_op
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
|
||||
\ print(message)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-op-2:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_op
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
|
||||
\ print(message)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-op-3:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_op
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
|
||||
\ print(message)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-op-4:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_op
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
|
||||
\ print(message)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-op-5:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_op
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
|
||||
\ print(message)\n\n"
|
||||
image: python:3.7
|
||||
exec-print-op-6:
|
||||
container:
|
||||
args:
|
||||
- --executor_input
|
||||
- '{{$}}'
|
||||
- --function_to_execute
|
||||
- print_op
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\
|
||||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
|
||||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
|
||||
\ && \"$0\" \"$@\"\n"
|
||||
- sh
|
||||
- -ec
|
||||
- 'program_path=$(mktemp -d)
|
||||
|
||||
printf "%s" "$0" > "$program_path/ephemeral_component.py"
|
||||
|
||||
python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
|
||||
|
||||
'
|
||||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
|
||||
\ *\n\ndef print_op(message: str):\n \"\"\"Prints a message.\"\"\"\n\
|
||||
\ print(message)\n\n"
|
||||
image: python:3.7
|
||||
pipelineInfo:
|
||||
name: pipeline-with-multiple-exit-handlers
|
||||
root:
|
||||
dag:
|
||||
tasks:
|
||||
exit-handler-1:
|
||||
componentRef:
|
||||
name: comp-exit-handler-1
|
||||
inputs:
|
||||
parameters:
|
||||
pipelinechannel--message:
|
||||
componentInputParameter: message
|
||||
taskInfo:
|
||||
name: exit-handler-1
|
||||
exit-handler-2:
|
||||
componentRef:
|
||||
name: comp-exit-handler-2
|
||||
inputs:
|
||||
parameters:
|
||||
pipelinechannel--message:
|
||||
componentInputParameter: message
|
||||
taskInfo:
|
||||
name: exit-handler-2
|
||||
exit-handler-3:
|
||||
componentRef:
|
||||
name: comp-exit-handler-3
|
||||
inputs:
|
||||
parameters:
|
||||
pipelinechannel--message:
|
||||
componentInputParameter: message
|
||||
taskInfo:
|
||||
name: exit-handler-3
|
||||
print-op:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-op
|
||||
dependentTasks:
|
||||
- exit-handler-1
|
||||
inputs:
|
||||
parameters:
|
||||
message:
|
||||
runtimeValue:
|
||||
constant: First exit handler has worked!
|
||||
taskInfo:
|
||||
name: print-op
|
||||
triggerPolicy:
|
||||
strategy: ALL_UPSTREAM_TASKS_COMPLETED
|
||||
print-op-3:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-op-3
|
||||
dependentTasks:
|
||||
- exit-handler-2
|
||||
inputs:
|
||||
parameters:
|
||||
message:
|
||||
runtimeValue:
|
||||
constant: Second exit handler has worked!
|
||||
taskInfo:
|
||||
name: print-op-3
|
||||
triggerPolicy:
|
||||
strategy: ALL_UPSTREAM_TASKS_COMPLETED
|
||||
print-op-5:
|
||||
cachingOptions:
|
||||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-print-op-5
|
||||
dependentTasks:
|
||||
- exit-handler-3
|
||||
inputs:
|
||||
parameters:
|
||||
message:
|
||||
runtimeValue:
|
||||
constant: Third exit handler has worked!
|
||||
taskInfo:
|
||||
name: print-op-5
|
||||
triggerPolicy:
|
||||
strategy: ALL_UPSTREAM_TASKS_COMPLETED
|
||||
inputDefinitions:
|
||||
parameters:
|
||||
message:
|
||||
defaultValue: Hello World!
|
||||
parameterType: STRING
|
||||
schemaVersion: 2.1.0
|
||||
sdkVersion: kfp-2.0.0-beta.1
|
|
@ -118,7 +118,9 @@ class Pipeline:
|
|||
# Add the root group.
|
||||
self.groups = [
|
||||
tasks_group.TasksGroup(
|
||||
group_type=tasks_group.TasksGroupType.PIPELINE, name=name)
|
||||
group_type=tasks_group.TasksGroupType.PIPELINE,
|
||||
name=name,
|
||||
is_root=True)
|
||||
]
|
||||
self._group_id = 0
|
||||
|
||||
|
@ -174,6 +176,7 @@ class Pipeline:
|
|||
|
||||
self.tasks[task_name] = task
|
||||
if add_to_group:
|
||||
task.parent_task_group = self.groups[-1]
|
||||
self.groups[-1].tasks.append(task)
|
||||
|
||||
return task_name
|
||||
|
|
|
@ -69,6 +69,10 @@ class PipelineTask:
|
|||
args: Mapping[str, Any],
|
||||
):
|
||||
"""Initilizes a PipelineTask instance."""
|
||||
# import within __init__ to avoid circular import
|
||||
from kfp.components.tasks_group import TasksGroup
|
||||
|
||||
self.parent_task_group: Union[None, TasksGroup] = None
|
||||
args = args or {}
|
||||
|
||||
for input_name, argument_value in args.items():
|
||||
|
@ -558,5 +562,9 @@ class PipelineTask:
|
|||
task2 = my_component(text='2nd task').after(task1)
|
||||
"""
|
||||
for task in tasks:
|
||||
if task.parent_task_group is not self.parent_task_group:
|
||||
raise ValueError(
|
||||
f'Cannot use .after() across inner pipelines or DSL control flow features. Tried to set {self.name} after {task.name}, but these tasks do not belong to the same pipeline or are not enclosed in the same control flow content manager.'
|
||||
)
|
||||
self._task_spec.dependent_tasks.append(task.name)
|
||||
return self
|
||||
|
|
|
@ -13,10 +13,14 @@
|
|||
# limitations under the License.
|
||||
"""Tests for kfp.components.pipeline_task."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import textwrap
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from kfp import compiler
|
||||
from kfp import dsl
|
||||
from kfp.components import pipeline_task
|
||||
from kfp.components import placeholders
|
||||
from kfp.components import structures
|
||||
|
@ -301,5 +305,136 @@ class PipelineTaskTest(parameterized.TestCase):
|
|||
self.assertEqual('test_name', task._task_spec.display_name)
|
||||
|
||||
|
||||
class TestCannotUseAfterCrossDAG(unittest.TestCase):
|
||||
|
||||
def test_inner_task_prevented(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'Cannot use \.after\(\) across'):
|
||||
|
||||
@dsl.component
|
||||
def print_op(message: str):
|
||||
print(message)
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline():
|
||||
first_exit_task = print_op(message='First exit task.')
|
||||
|
||||
with dsl.ExitHandler(first_exit_task):
|
||||
first_print_op = print_op(
|
||||
message='Inside first exit handler.')
|
||||
|
||||
second_exit_task = print_op(message='Second exit task.')
|
||||
with dsl.ExitHandler(second_exit_task):
|
||||
print_op(message='Inside second exit handler.').after(
|
||||
first_print_op)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
package_path = os.path.join(tempdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_exit_handler_task_prevented(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'Cannot use \.after\(\) across'):
|
||||
|
||||
@dsl.component
|
||||
def print_op(message: str):
|
||||
print(message)
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline():
|
||||
first_exit_task = print_op(message='First exit task.')
|
||||
|
||||
with dsl.ExitHandler(first_exit_task):
|
||||
first_print_op = print_op(
|
||||
message='Inside first exit handler.')
|
||||
|
||||
second_exit_task = print_op(message='Second exit task.')
|
||||
with dsl.ExitHandler(second_exit_task):
|
||||
x = print_op(message='Inside second exit handler.')
|
||||
x.after(first_exit_task)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
package_path = os.path.join(tempdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_within_same_exit_handler_permitted(self):
|
||||
|
||||
@dsl.component
|
||||
def print_op(message: str):
|
||||
print(message)
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline():
|
||||
first_exit_task = print_op(message='First exit task.')
|
||||
|
||||
with dsl.ExitHandler(first_exit_task):
|
||||
first_print_op = print_op(
|
||||
message='First task inside first exit handler.')
|
||||
second_print_op = print_op(
|
||||
message='Second task inside first exit handler.').after(
|
||||
first_print_op)
|
||||
|
||||
second_exit_task = print_op(message='Second exit task.')
|
||||
with dsl.ExitHandler(second_exit_task):
|
||||
print_op(message='Inside second exit handler.')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
package_path = os.path.join(tempdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_outside_of_condition_blocked(self):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'Cannot use \.after\(\) across'):
|
||||
|
||||
@dsl.component
|
||||
def print_op(message: str):
|
||||
print(message)
|
||||
|
||||
@dsl.component
|
||||
def return_1() -> int:
|
||||
return 1
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline():
|
||||
return_1_task = return_1()
|
||||
|
||||
with dsl.Condition(return_1_task.output == 1):
|
||||
one = print_op(message='1')
|
||||
two = print_op(message='2')
|
||||
three = print_op(message='3').after(one)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
package_path = os.path.join(tempdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
def test_inside_of_condition_permitted(self):
|
||||
|
||||
@dsl.component
|
||||
def print_op(message: str):
|
||||
print(message)
|
||||
|
||||
@dsl.component
|
||||
def return_1() -> int:
|
||||
return 1
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
|
||||
def my_pipeline():
|
||||
return_1_task = return_1()
|
||||
|
||||
with dsl.Condition(return_1_task.output == '1'):
|
||||
one = print_op(message='1')
|
||||
two = print_op(message='2').after(one)
|
||||
three = print_op(message='3')
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
package_path = os.path.join(tempdir, 'pipeline.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=package_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -44,12 +44,14 @@ class TasksGroup:
|
|||
groups: A list of TasksGroups in this group.
|
||||
display_name: The optional user given name of the group.
|
||||
dependencies: A list of tasks or groups this group depends on.
|
||||
is_root: If TasksGroup is root group.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_type: TasksGroupType,
|
||||
name: Optional[str] = None,
|
||||
is_root: bool = False,
|
||||
):
|
||||
"""Create a new instance of TasksGroup.
|
||||
|
||||
|
@ -62,6 +64,7 @@ class TasksGroup:
|
|||
self.groups = list()
|
||||
self.display_name = name
|
||||
self.dependencies = []
|
||||
self.is_root = is_root
|
||||
|
||||
def __enter__(self):
|
||||
if not pipeline_context.Pipeline.get_default_pipeline():
|
||||
|
@ -116,7 +119,11 @@ class ExitHandler(TasksGroup):
|
|||
name: Optional[str] = None,
|
||||
):
|
||||
"""Initializes a Condition task group."""
|
||||
super().__init__(group_type=TasksGroupType.EXIT_HANDLER, name=name)
|
||||
super().__init__(
|
||||
group_type=TasksGroupType.EXIT_HANDLER,
|
||||
name=name,
|
||||
is_root=False,
|
||||
)
|
||||
|
||||
if exit_task.dependent_tasks:
|
||||
raise ValueError('exit_task cannot depend on any other tasks.')
|
||||
|
@ -151,6 +158,7 @@ class Condition(TasksGroup):
|
|||
self,
|
||||
condition: pipeline_channel.ConditionOperator,
|
||||
name: Optional[str] = None,
|
||||
is_root=False,
|
||||
):
|
||||
"""Initializes a conditional task group."""
|
||||
super().__init__(group_type=TasksGroupType.CONDITION, name=name)
|
||||
|
@ -182,7 +190,11 @@ class ParallelFor(TasksGroup):
|
|||
name: Optional[str] = None,
|
||||
):
|
||||
"""Initializes a for loop task group."""
|
||||
super().__init__(group_type=TasksGroupType.FOR_LOOP, name=name)
|
||||
super().__init__(
|
||||
group_type=TasksGroupType.FOR_LOOP,
|
||||
name=name,
|
||||
is_root=False,
|
||||
)
|
||||
|
||||
if isinstance(items, pipeline_channel.PipelineChannel):
|
||||
self.loop_argument = for_loop.LoopArgument.from_pipeline_channel(
|
||||
|
|
Loading…
Reference in New Issue