feat(sdk): support more than one exit handler per pipeline (#8088)
* add compiler test pipeline with multiple exit handlers
* remove blocker of multiple exit handlers
* move exit handler builder logic to pipeline_spec_builder
* build all exit handlers per pipeline
* add compiler test with IR inspection
* prevent usage of cross-pipeline after
* test cross-pipeline after is prevented
* update existing task dependency logic and tests
* add v2 sample test
* remove cross-pipeline .after
* prevent cross-dag data dependency for dsl features
* add compiler test pipeline with nested exit handlers
* add support for nested exit handlers
* clean up pipeline with nested exit handlers
* remove sample with multiple exit handlers
* remove compiler test with nested exit handlers
* add compilation guard against nested exit handlers in subdag
* update release notes
Parent: e728d0871b
Commit: bdff332ac6
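To illustrate the headline change, here is a minimal sketch of a pipeline that compiles only after this commit (component and pipeline names are hypothetical, modeled on the tests in the diff below):

    from kfp import compiler, dsl


    @dsl.component
    def cleanup(message: str):
        print(message)


    @dsl.component
    def work(message: str):
        print(message)


    @dsl.pipeline(name='two-exit-handlers')
    def my_pipeline():
        # Each ExitHandler gets its own exit task; previously the compiler
        # raised 'Only one global exit_handler is allowed ...'.
        first_cleanup = cleanup(message='first cleanup')
        with dsl.ExitHandler(first_cleanup):
            work(message='first group')

        second_cleanup = cleanup(message='second cleanup')
        with dsl.ExitHandler(second_cleanup):
            work(message='second group')


    compiler.Compiler().compile(my_pipeline, package_path='pipeline.yaml')

Exit handlers must still sit at the pipeline's outermost scope; nesting one inside another raises a ValueError at compile time (see the guard added in pipeline_spec_builder below).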
@@ -1,4 +1,5 @@
 import os
+
 from kfp import dsl
 
 # In tests, we install a KFP package from the PR under test. Users should not
@@ -32,5 +33,4 @@ def my_pipeline(
    generate_task = generate_op()
    with dsl.ParallelFor(generate_task.output) as item:
        concat_task = concat_op(a=item.a, b=item.b)
-        concat_task.after(print_task)
        print_task_2 = print_op(text=concat_task.output)

@@ -1,7 +1,8 @@
 import os
-from kfp import dsl
 from typing import List
 
+from kfp import dsl
+
 # In tests, we install a KFP package from the PR under test. Users should not
 # normally need to specify `kfp_package_path` in their component definitions.
 _KFP_PACKAGE_PATH = os.getenv('KFP_PACKAGE_PATH')
@@ -21,12 +22,10 @@ def concat_op(a: str, b: str) -> str:
 
 @dsl.pipeline(name='pipeline-with-loop-static')
 def my_pipeline(
-    greeting: str = 'this is a test for looping through parameters',
-):
+    greeting: str = 'this is a test for looping through parameters',):
     print_task = print_op(text=greeting)
     static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]
 
     with dsl.ParallelFor(static_loop_arguments) as item:
         concat_task = concat_op(a=item.a, b=item.b)
-        concat_task.after(print_task)
-        print_task_2 = print_op(text=concat_task.output)
+        print_task_2 = print_op(text=concat_task.output)

@@ -5,7 +5,6 @@
 ## Breaking Changes
 
 ### For Pipeline Authors
-* Add support for task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
 
 ### For Component Authors
 
@@ -14,6 +13,8 @@
 ## Bug Fixes and Other Changes
 * Enable overriding caching options at submission time [\#7912](https://github.com/kubeflow/pipelines/pull/7912)
 * Allow artifact inputs in pipeline definition. [\#8044](https://github.com/kubeflow/pipelines/pull/8044)
+* Support task-level retry policy [\#7867](https://github.com/kubeflow/pipelines/pull/7867)
+* Support multiple exit handlers per pipeline [\#8088](https://github.com/kubeflow/pipelines/pull/8088)
 
 ## Documentation Updates
 

|
@ -44,6 +44,7 @@ CONFIG = {
|
||||||
'component_with_pip_index_urls',
|
'component_with_pip_index_urls',
|
||||||
'container_component_with_no_inputs',
|
'container_component_with_no_inputs',
|
||||||
'two_step_pipeline_containerized',
|
'two_step_pipeline_containerized',
|
||||||
|
'pipeline_with_multiple_exit_handlers',
|
||||||
],
|
],
|
||||||
'test_data_dir': 'sdk/python/kfp/compiler/test_data/pipelines',
|
'test_data_dir': 'sdk/python/kfp/compiler/test_data/pipelines',
|
||||||
'config': {
|
'config': {
|
||||||
|
|
|
@@ -148,8 +148,6 @@ class Compiler:
         if not dsl_pipeline.tasks:
             raise ValueError('Task is missing from pipeline.')
 
-        self._validate_exit_handler(dsl_pipeline)
-
         pipeline_inputs = pipeline_meta.inputs or {}
 
         # Verify that pipeline_parameters_override contains only input names
@@ -186,45 +184,6 @@ class Compiler:
 
         return pipeline_spec
 
-    def _validate_exit_handler(self,
-                               pipeline: pipeline_context.Pipeline) -> None:
-        """Makes sure there is only one global exit handler.
-
-        This is temporary to be compatible with KFP v1.
-
-        Raises:
-            ValueError if there are more than one exit handler.
-        """
-
-        def _validate_exit_handler_helper(
-            group: tasks_group.TasksGroup,
-            exiting_task_names: List[str],
-            handler_exists: bool,
-        ) -> None:
-
-            if isinstance(group, dsl.ExitHandler):
-                if handler_exists or len(exiting_task_names) > 1:
-                    raise ValueError(
-                        'Only one global exit_handler is allowed and all ops need to be included.'
-                    )
-                handler_exists = True
-
-            if group.tasks:
-                exiting_task_names.extend([x.name for x in group.tasks])
-
-            for group in group.groups:
-                _validate_exit_handler_helper(
-                    group=group,
-                    exiting_task_names=exiting_task_names,
-                    handler_exists=handler_exists,
-                )
-
-        _validate_exit_handler_helper(
-            group=pipeline.groups[0],
-            exiting_task_names=[],
-            handler_exists=False,
-        )
-
     def _create_pipeline_spec(
         self,
         pipeline_args: List[pipeline_channel.PipelineChannel],
@@ -301,49 +260,11 @@ class Compiler:
             name_to_for_loop_group=name_to_for_loop_group,
         )
 
-        # TODO: refactor to support multiple exit handler per pipeline.
-        if pipeline.groups[0].groups:
-            first_group = pipeline.groups[0].groups[0]
-            if isinstance(first_group, dsl.ExitHandler):
-                exit_task = first_group.exit_task
-                exit_task_name = component_utils.sanitize_task_name(
-                    exit_task.name)
-                exit_handler_group_task_name = component_utils.sanitize_task_name(
-                    first_group.name)
-                input_parameters_in_current_dag = [
-                    input_name for input_name in
-                    pipeline_spec.root.input_definitions.parameters
-                ]
-                exit_task_task_spec = builder.build_task_spec_for_exit_task(
-                    task=exit_task,
-                    dependent_task=exit_handler_group_task_name,
-                    pipeline_inputs=pipeline_spec.root.input_definitions,
-                )
-
-                exit_task_component_spec = builder.build_component_spec_for_exit_task(
-                    task=exit_task)
-
-                exit_task_container_spec = builder.build_container_spec_for_task(
-                    task=exit_task)
-
-                # Add exit task task spec
-                pipeline_spec.root.dag.tasks[exit_task_name].CopyFrom(
-                    exit_task_task_spec)
-
-                # Add exit task component spec if it does not exist.
-                component_name = exit_task_task_spec.component_ref.name
-                if component_name not in pipeline_spec.components:
-                    pipeline_spec.components[component_name].CopyFrom(
-                        exit_task_component_spec)
-
-                # Add exit task container spec if it does not exist.
-                executor_label = exit_task_component_spec.executor_label
-                if executor_label not in deployment_config.executors:
-                    deployment_config.executors[
-                        executor_label].container.CopyFrom(
-                            exit_task_container_spec)
-                    pipeline_spec.deployment_spec.update(
-                        json_format.MessageToDict(deployment_config))
+        builder.build_exit_handler_groups_recursively(
+            parent_group=root_group,
+            pipeline_spec=pipeline_spec,
+            deployment_config=deployment_config,
+        )
 
         return pipeline_spec
@@ -705,14 +626,12 @@ class Compiler:
                     task2=task,
                 )
 
-            # If a task depends on a condition group or a loop group, it
-            # must explicitly dependent on a task inside the group. This
-            # should not be allowed, because it leads to ambiguous
-            # expectations for runtime behaviors.
+            # a task cannot depend on a task created in a for loop group since individual PipelineTask variables are reassigned after each loop iteration
             dependent_group = group_name_to_group.get(
                 upstream_groups[0], None)
             if isinstance(dependent_group,
-                          (tasks_group.Condition, tasks_group.ParallelFor)):
+                          (tasks_group.ParallelFor, tasks_group.Condition,
+                           tasks_group.ExitHandler)):
                 raise RuntimeError(
                     f'Task {task.name} cannot dependent on any task inside'
                     f' the group: {upstream_groups[0]}.')

@@ -440,7 +440,7 @@ class TestCompilePipeline(parameterized.TestCase):
             compiler.Compiler().compile(
                 pipeline_func=my_pipeline, package_path='result.yaml')
 
-    def test_invalid_after_dependency(self):
+    def test_invalid_data_dependency_loop(self):
 
         @dsl.component
         def producer_op() -> str:
@@ -451,30 +451,7 @@ class TestCompilePipeline(parameterized.TestCase):
             pass
 
         @dsl.pipeline(name='test-pipeline')
-        def my_pipeline(text: str):
-            with dsl.Condition(text == 'a'):
-                producer_task = producer_op()
-
-            dummy_op().after(producer_task)
-
-        with self.assertRaisesRegex(
-                RuntimeError,
-                'Task dummy-op cannot dependent on any task inside the group:'):
-            compiler.Compiler().compile(
-                pipeline_func=my_pipeline, package_path='result.yaml')
-
-    def test_invalid_data_dependency(self):
-
-        @dsl.component
-        def producer_op() -> str:
-            return 'a'
-
-        @dsl.component
-        def dummy_op(msg: str = ''):
-            pass
-
-        @dsl.pipeline(name='test-pipeline')
-        def my_pipeline(text: bool):
+        def my_pipeline(val: bool):
             with dsl.ParallelFor(['a, b']):
                 producer_task = producer_op()
@@ -483,8 +460,125 @@ class TestCompilePipeline(parameterized.TestCase):
         with self.assertRaisesRegex(
                 RuntimeError,
                 'Task dummy-op cannot dependent on any task inside the group:'):
+            with tempfile.TemporaryDirectory() as tmpdir:
+                package_path = os.path.join(tmpdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_valid_data_dependency_loop(self):
+
+        @dsl.component
+        def producer_op() -> str:
+            return 'a'
+
+        @dsl.component
+        def dummy_op(msg: str = ''):
+            pass
+
+        @dsl.pipeline(name='test-pipeline')
+        def my_pipeline(val: bool):
+            with dsl.ParallelFor(['a, b']):
+                producer_task = producer_op()
+                dummy_op(msg=producer_task.output)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            package_path = os.path.join(tmpdir, 'pipeline.yaml')
             compiler.Compiler().compile(
-                pipeline_func=my_pipeline, package_path='result.yaml')
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_invalid_data_dependency_condition(self):
+
+        @dsl.component
+        def producer_op() -> str:
+            return 'a'
+
+        @dsl.component
+        def dummy_op(msg: str = ''):
+            pass
+
+        @dsl.pipeline(name='test-pipeline')
+        def my_pipeline(val: bool):
+            with dsl.Condition(val == False):
+                producer_task = producer_op()
+
+            dummy_op(msg=producer_task.output)
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                'Task dummy-op cannot dependent on any task inside the group:'):
+            with tempfile.TemporaryDirectory() as tmpdir:
+                package_path = os.path.join(tmpdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_valid_data_dependency_condition(self):
+
+        @dsl.component
+        def producer_op() -> str:
+            return 'a'
+
+        @dsl.component
+        def dummy_op(msg: str = ''):
+            pass
+
+        @dsl.pipeline(name='test-pipeline')
+        def my_pipeline(val: bool):
+            with dsl.Condition(val == False):
+                producer_task = producer_op()
+                dummy_op(msg=producer_task.output)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            package_path = os.path.join(tmpdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_invalid_data_dependency_exit_handler(self):
+
+        @dsl.component
+        def producer_op() -> str:
+            return 'a'
+
+        @dsl.component
+        def dummy_op(msg: str = ''):
+            pass
+
+        @dsl.pipeline(name='test-pipeline')
+        def my_pipeline(val: bool):
+            first_producer = producer_op()
+            with dsl.ExitHandler(first_producer):
+                producer_task = producer_op()
+
+            dummy_op(msg=producer_task.output)
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                'Task dummy-op cannot dependent on any task inside the group:'):
+            with tempfile.TemporaryDirectory() as tmpdir:
+                package_path = os.path.join(tmpdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_valid_data_dependency_exit_handler(self):
+
+        @dsl.component
+        def producer_op() -> str:
+            return 'a'
+
+        @dsl.component
+        def dummy_op(msg: str = ''):
+            pass
+
+        @dsl.pipeline(name='test-pipeline')
+        def my_pipeline(val: bool):
+            first_producer = producer_op()
+            with dsl.ExitHandler(first_producer):
+                producer_task = producer_op()
+                dummy_op(msg=producer_task.output)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            package_path = os.path.join(tmpdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
 
     def test_use_task_final_status_in_non_exit_op(self):
@@ -527,7 +621,6 @@ implementation:
                 pipeline_func=my_pipeline, package_path='result.yaml')
 
 
-# pylint: disable=import-outside-toplevel,unused-import,import-error,redefined-outer-name,reimported
 class V2NamespaceAliasTest(unittest.TestCase):
     """Test that imports of both modules and objects are aliased (e.g. all
     import path variants work)."""
@@ -536,7 +629,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
     # the kfp.v2 module is loaded. Due to the way we run tests in CI/CD, we cannot ensure that the kfp.v2 module will first be loaded in these tests,
     # so we do not test for the DeprecationWarning here.
 
-    def test_import_namespace(self):  # pylint: disable=no-self-use
+    def test_import_namespace(self):
         from kfp import v2
 
         @v2.dsl.component
@@ -560,7 +653,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
         with open(temp_filepath, 'r') as f:
             yaml.load(f)
 
-    def test_import_modules(self):  # pylint: disable=no-self-use
+    def test_import_modules(self):
         from kfp.v2 import compiler
         from kfp.v2 import dsl
 
@@ -584,7 +677,7 @@ class V2NamespaceAliasTest(unittest.TestCase):
         with open(temp_filepath, 'r') as f:
             yaml.load(f)
 
-    def test_import_object(self):  # pylint: disable=no-self-use
+    def test_import_object(self):
         from kfp.v2.compiler import Compiler
         from kfp.v2.dsl import component
         from kfp.v2.dsl import pipeline
@@ -1125,5 +1218,82 @@ class TestSetRetryCompilation(unittest.TestCase):
         self.assertEqual(retry_policy.backoff_max_duration.seconds, 3600)
 
 
+from google.protobuf import json_format
+
+
+class TestMultipleExitHandlerCompilation(unittest.TestCase):
+
+    def test_basic(self):
+
+        @dsl.component
+        def print_op(message: str):
+            print(message)
+
+        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+        def my_pipeline():
+            first_exit_task = print_op(message='First exit task.')
+
+            with dsl.ExitHandler(first_exit_task):
+                print_op(message='Inside first exit handler.')
+
+            second_exit_task = print_op(message='Second exit task.')
+            with dsl.ExitHandler(second_exit_task):
+                print_op(message='Inside second exit handler.')
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+            pipeline_spec = pipeline_spec_from_file(package_path)
+        # check that the exit handler dags exist
+        self.assertEqual(
+            pipeline_spec.components['comp-exit-handler-1'].dag
+            .tasks['print-op-2'].inputs.parameters['message'].runtime_value
+            .constant.string_value, 'Inside first exit handler.')
+        self.assertEqual(
+            pipeline_spec.components['comp-exit-handler-2'].dag
+            .tasks['print-op-4'].inputs.parameters['message'].runtime_value
+            .constant.string_value, 'Inside second exit handler.')
+        # check that the exit handler dags are in the root dag
+        self.assertIn('exit-handler-1', pipeline_spec.root.dag.tasks)
+        self.assertIn('exit-handler-2', pipeline_spec.root.dag.tasks)
+        # check that the exit tasks are in the root dag
+        self.assertIn('print-op', pipeline_spec.root.dag.tasks)
+        self.assertEqual(
+            pipeline_spec.root.dag.tasks['print-op'].inputs
+            .parameters['message'].runtime_value.constant.string_value,
+            'First exit task.')
+        self.assertIn('print-op-3', pipeline_spec.root.dag.tasks)
+        self.assertEqual(
+            pipeline_spec.root.dag.tasks['print-op-3'].inputs
+            .parameters['message'].runtime_value.constant.string_value,
+            'Second exit task.')
+
+    def test_nested_unsupported(self):
+
+        @dsl.component
+        def print_op(message: str):
+            print(message)
+
+        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+        def my_pipeline():
+            first_exit_task = print_op(message='First exit task.')
+
+            with dsl.ExitHandler(first_exit_task):
+                print_op(message='Inside first exit handler.')
+
+                second_exit_task = print_op(message='Second exit task.')
+                with dsl.ExitHandler(second_exit_task):
+                    print_op(message='Inside second exit handler.')
+
+        with self.assertRaisesRegex(
+                ValueError,
+                r'ExitHandler can only be used within the outermost scope of a pipeline function definition\.'
+        ):
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path='output.yaml')
+
+
 if __name__ == '__main__':
     unittest.main()

@@ -23,6 +23,7 @@ from kfp import dsl
 from kfp.compiler import pipeline_spec_builder as builder
 from kfp.components import for_loop
 from kfp.components import pipeline_channel
+from kfp.components import pipeline_context
 from kfp.components import pipeline_task
 from kfp.components import placeholders
 from kfp.components import structures
@@ -34,6 +35,13 @@ from kfp.components.types import type_utils
 from kfp.pipeline_spec import pipeline_spec_pb2
 
 GroupOrTaskType = Union[tasks_group.TasksGroup, pipeline_task.PipelineTask]
+# must be defined here to avoid circular imports
+group_type_to_dsl_class = {
+    tasks_group.TasksGroupType.PIPELINE: pipeline_context.Pipeline,
+    tasks_group.TasksGroupType.CONDITION: tasks_group.Condition,
+    tasks_group.TasksGroupType.FOR_LOOP: tasks_group.ParallelFor,
+    tasks_group.TasksGroupType.EXIT_HANDLER: tasks_group.ExitHandler,
+}
 
 
 def _additional_input_name_for_pipeline_channel(
@@ -772,7 +780,7 @@ def build_task_spec_for_exit_task(
         pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
         .ALL_UPSTREAM_TASKS_COMPLETED)
 
-    for input_name, input_spec in task.component_spec.inputs.items():
+    for input_name, input_spec in (task.component_spec.inputs or {}).items():
         if type_utils.is_task_final_status_type(input_spec.type):
             pipeline_task_spec.inputs.parameters[
                 input_name].task_final_status.producer_task = dependent_task
@@ -1184,6 +1192,61 @@ def build_spec_by_group(
     )
 
 
+def build_exit_handler_groups_recursively(
+    parent_group: tasks_group.TasksGroup,
+    pipeline_spec: pipeline_spec_pb2.PipelineSpec,
+    deployment_config: pipeline_spec_pb2.PipelineDeploymentConfig,
+):
+    if not parent_group.groups:
+        return
+    for group in parent_group.groups:
+        if isinstance(group, dsl.ExitHandler):
+            exit_task = group.exit_task
+            exit_task_name = utils.sanitize_task_name(exit_task.name)
+            exit_handler_group_task_name = utils.sanitize_task_name(group.name)
+
+            exit_task_task_spec = builder.build_task_spec_for_exit_task(
+                task=exit_task,
+                dependent_task=exit_handler_group_task_name,
+                pipeline_inputs=pipeline_spec.root.input_definitions,
+            )
+
+            exit_task_component_spec = builder.build_component_spec_for_exit_task(
+                task=exit_task)
+
+            exit_task_container_spec = builder.build_container_spec_for_task(
+                task=exit_task)
+
+            # remove this if block to support nested exit handlers
+            if not parent_group.is_root:
+                raise ValueError(
+                    f'{dsl.ExitHandler.__name__} can only be used within the outermost scope of a pipeline function definition. Using an {dsl.ExitHandler.__name__} within {group_type_to_dsl_class[parent_group.group_type].__name__} {parent_group.name} is not allowed.'
+                )
+
+            parent_dag = pipeline_spec.root.dag if parent_group.is_root else pipeline_spec.components[
+                utils.sanitize_component_name(parent_group.name)].dag
+
+            parent_dag.tasks[exit_task_name].CopyFrom(exit_task_task_spec)
+
+            # Add exit task component spec if it does not exist.
+            component_name = exit_task_task_spec.component_ref.name
+            if component_name not in pipeline_spec.components:
+                pipeline_spec.components[component_name].CopyFrom(
+                    exit_task_component_spec)
+
+            # Add exit task container spec if it does not exist.
+            executor_label = exit_task_component_spec.executor_label
+            if executor_label not in deployment_config.executors:
+                deployment_config.executors[executor_label].container.CopyFrom(
+                    exit_task_container_spec)
+                pipeline_spec.deployment_spec.update(
+                    json_format.MessageToDict(deployment_config))
+        build_exit_handler_groups_recursively(
+            parent_group=group,
+            pipeline_spec=pipeline_spec,
+            deployment_config=deployment_config)
+
+
 def get_parent_groups(
     root_group: tasks_group.TasksGroup,
 ) -> Tuple[Mapping[str, List[GroupOrTaskType]], Mapping[str,

@@ -0,0 +1,59 @@
+# Copyright 2022 The Kubeflow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pipeline using multiple ExitHandlers."""
+
+from kfp import compiler
+from kfp import dsl
+from kfp.dsl import component
+
+
+@component
+def print_op(message: str):
+    """Prints a message."""
+    print(message)
+
+
+@component
+def fail_op(message: str):
+    """Fails."""
+    import sys
+    print(message)
+    sys.exit(1)
+
+
+@dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+def my_pipeline(message: str = 'Hello World!'):
+
+    first_exit_task = print_op(message='First exit handler has worked!')
+
+    with dsl.ExitHandler(first_exit_task):
+        first_exit_print_task = print_op(message=message)
+        print(first_exit_print_task.outputs)
+        fail_op(message='Task failed.')
+
+    second_exit_task = print_op(message='Second exit handler has worked!')
+
+    with dsl.ExitHandler(second_exit_task):
+        print_op(message=message)
+
+    third_exit_task = print_op(message='Third exit handler has worked!')
+
+    with dsl.ExitHandler(third_exit_task):
+        print_op(message=message)
+
+
+if __name__ == '__main__':
+    compiler.Compiler().compile(
+        pipeline_func=my_pipeline,
+        package_path=__file__.replace('.py', '.yaml'))

@@ -0,0 +1,387 @@
+components:
+  comp-exit-handler-1:
+    dag:
+      tasks:
+        fail-op:
+          cachingOptions:
+            enableCache: true
+          componentRef:
+            name: comp-fail-op
+          inputs:
+            parameters:
+              message:
+                runtimeValue:
+                  constant: Task failed.
+          taskInfo:
+            name: fail-op
+        print-op-2:
+          cachingOptions:
+            enableCache: true
+          componentRef:
+            name: comp-print-op-2
+          inputs:
+            parameters:
+              message:
+                componentInputParameter: pipelinechannel--message
+          taskInfo:
+            name: print-op-2
+    inputDefinitions:
+      parameters:
+        pipelinechannel--message:
+          parameterType: STRING
+  comp-exit-handler-2:
+    dag:
+      tasks:
+        print-op-4:
+          cachingOptions:
+            enableCache: true
+          componentRef:
+            name: comp-print-op-4
+          inputs:
+            parameters:
+              message:
+                componentInputParameter: pipelinechannel--message
+          taskInfo:
+            name: print-op-4
+    inputDefinitions:
+      parameters:
+        pipelinechannel--message:
+          parameterType: STRING
+  comp-exit-handler-3:
+    dag:
+      tasks:
+        print-op-6:
+          cachingOptions:
+            enableCache: true
+          componentRef:
+            name: comp-print-op-6
+          inputs:
+            parameters:
+              message:
+                componentInputParameter: pipelinechannel--message
+          taskInfo:
+            name: print-op-6
+    inputDefinitions:
+      parameters:
+        pipelinechannel--message:
+          parameterType: STRING
+  comp-fail-op:
+    executorLabel: exec-fail-op
+    inputDefinitions:
+      parameters:
+        message:
+          parameterType: STRING
+  comp-print-op:
+    executorLabel: exec-print-op
+    inputDefinitions:
+      parameters:
+        message:
+          parameterType: STRING
+  comp-print-op-2:
+    executorLabel: exec-print-op-2
+    inputDefinitions:
+      parameters:
+        message:
+          parameterType: STRING
+  comp-print-op-3:
+    executorLabel: exec-print-op-3
+    inputDefinitions:
+      parameters:
+        message:
+          parameterType: STRING
+  comp-print-op-4:
+    executorLabel: exec-print-op-4
+    inputDefinitions:
+      parameters:
+        message:
+          parameterType: STRING
+  comp-print-op-5:
+    executorLabel: exec-print-op-5
+    inputDefinitions:
+      parameters:
+        message:
+          parameterType: STRING
+  comp-print-op-6:
+    executorLabel: exec-print-op-6
+    inputDefinitions:
+      parameters:
+        message:
+          parameterType: STRING
+deploymentSpec:
+  executors:
+    exec-fail-op:
+      container:
+        args:
+        - --executor_input
+        - '{{$}}'
+        - --function_to_execute
+        - fail_op
+        command:
+        - sh
+        - -c
+        - "\nif ! [ -x \"$(command -v pip)\" ]; then\n    python3 -m ensurepip ||\
+          \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
+          \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
+          \ && \"$0\" \"$@\"\n"
+        - sh
+        - -ec
+        - 'program_path=$(mktemp -d)
+
+          printf "%s" "$0" > "$program_path/ephemeral_component.py"
+
+          python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
+
+          '
+        - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
+          \ *\n\ndef fail_op(message: str):\n    \"\"\"Fails.\"\"\"\n    import sys\n\
+          \    print(message)\n    sys.exit(1)\n\n"
+        image: python:3.7
+    exec-print-op:
+      container:
+        args:
+        - --executor_input
+        - '{{$}}'
+        - --function_to_execute
+        - print_op
+        command:
+        - sh
+        - -c
+        - "\nif ! [ -x \"$(command -v pip)\" ]; then\n    python3 -m ensurepip ||\
+          \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
+          \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
+          \ && \"$0\" \"$@\"\n"
+        - sh
+        - -ec
+        - 'program_path=$(mktemp -d)
+
+          printf "%s" "$0" > "$program_path/ephemeral_component.py"
+
+          python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
+
+          '
+        - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
+          \ *\n\ndef print_op(message: str):\n    \"\"\"Prints a message.\"\"\"\n\
+          \    print(message)\n\n"
+        image: python:3.7
+    exec-print-op-2:
+      container:
+        args:
+        - --executor_input
+        - '{{$}}'
+        - --function_to_execute
+        - print_op
+        command:
+        - sh
+        - -c
+        - "\nif ! [ -x \"$(command -v pip)\" ]; then\n    python3 -m ensurepip ||\
+          \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
+          \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
+          \ && \"$0\" \"$@\"\n"
+        - sh
+        - -ec
+        - 'program_path=$(mktemp -d)
+
+          printf "%s" "$0" > "$program_path/ephemeral_component.py"
+
+          python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
+
+          '
+        - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
+          \ *\n\ndef print_op(message: str):\n    \"\"\"Prints a message.\"\"\"\n\
+          \    print(message)\n\n"
+        image: python:3.7
+    exec-print-op-3:
+      container:
+        args:
+        - --executor_input
+        - '{{$}}'
+        - --function_to_execute
+        - print_op
+        command:
+        - sh
+        - -c
+        - "\nif ! [ -x \"$(command -v pip)\" ]; then\n    python3 -m ensurepip ||\
+          \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
+          \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
+          \ && \"$0\" \"$@\"\n"
+        - sh
+        - -ec
+        - 'program_path=$(mktemp -d)
+
+          printf "%s" "$0" > "$program_path/ephemeral_component.py"
+
+          python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
+
+          '
+        - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
+          \ *\n\ndef print_op(message: str):\n    \"\"\"Prints a message.\"\"\"\n\
+          \    print(message)\n\n"
+        image: python:3.7
+    exec-print-op-4:
+      container:
+        args:
+        - --executor_input
+        - '{{$}}'
+        - --function_to_execute
+        - print_op
+        command:
+        - sh
+        - -c
+        - "\nif ! [ -x \"$(command -v pip)\" ]; then\n    python3 -m ensurepip ||\
+          \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
+          \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
+          \ && \"$0\" \"$@\"\n"
+        - sh
+        - -ec
+        - 'program_path=$(mktemp -d)
+
+          printf "%s" "$0" > "$program_path/ephemeral_component.py"
+
+          python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
+
+          '
+        - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
+          \ *\n\ndef print_op(message: str):\n    \"\"\"Prints a message.\"\"\"\n\
+          \    print(message)\n\n"
+        image: python:3.7
+    exec-print-op-5:
+      container:
+        args:
+        - --executor_input
+        - '{{$}}'
+        - --function_to_execute
+        - print_op
+        command:
+        - sh
+        - -c
+        - "\nif ! [ -x \"$(command -v pip)\" ]; then\n    python3 -m ensurepip ||\
+          \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
+          \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
+          \ && \"$0\" \"$@\"\n"
+        - sh
+        - -ec
+        - 'program_path=$(mktemp -d)
+
+          printf "%s" "$0" > "$program_path/ephemeral_component.py"
+
+          python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
+
+          '
+        - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
+          \ *\n\ndef print_op(message: str):\n    \"\"\"Prints a message.\"\"\"\n\
+          \    print(message)\n\n"
+        image: python:3.7
+    exec-print-op-6:
+      container:
+        args:
+        - --executor_input
+        - '{{$}}'
+        - --function_to_execute
+        - print_op
+        command:
+        - sh
+        - -c
+        - "\nif ! [ -x \"$(command -v pip)\" ]; then\n    python3 -m ensurepip ||\
+          \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\
+          \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.0.0-beta.1'\
+          \ && \"$0\" \"$@\"\n"
+        - sh
+        - -ec
+        - 'program_path=$(mktemp -d)
+
+          printf "%s" "$0" > "$program_path/ephemeral_component.py"
+
+          python3 -m kfp.components.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@"
+
+          '
+        - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\
+          \ *\n\ndef print_op(message: str):\n    \"\"\"Prints a message.\"\"\"\n\
+          \    print(message)\n\n"
+        image: python:3.7
+pipelineInfo:
+  name: pipeline-with-multiple-exit-handlers
+root:
+  dag:
+    tasks:
+      exit-handler-1:
+        componentRef:
+          name: comp-exit-handler-1
+        inputs:
+          parameters:
+            pipelinechannel--message:
+              componentInputParameter: message
+        taskInfo:
+          name: exit-handler-1
+      exit-handler-2:
+        componentRef:
+          name: comp-exit-handler-2
+        inputs:
+          parameters:
+            pipelinechannel--message:
+              componentInputParameter: message
+        taskInfo:
+          name: exit-handler-2
+      exit-handler-3:
+        componentRef:
+          name: comp-exit-handler-3
+        inputs:
+          parameters:
+            pipelinechannel--message:
+              componentInputParameter: message
+        taskInfo:
+          name: exit-handler-3
+      print-op:
+        cachingOptions:
+          enableCache: true
+        componentRef:
+          name: comp-print-op
+        dependentTasks:
+        - exit-handler-1
+        inputs:
+          parameters:
+            message:
+              runtimeValue:
+                constant: First exit handler has worked!
+        taskInfo:
+          name: print-op
+        triggerPolicy:
+          strategy: ALL_UPSTREAM_TASKS_COMPLETED
+      print-op-3:
+        cachingOptions:
+          enableCache: true
+        componentRef:
+          name: comp-print-op-3
+        dependentTasks:
+        - exit-handler-2
+        inputs:
+          parameters:
+            message:
+              runtimeValue:
+                constant: Second exit handler has worked!
+        taskInfo:
+          name: print-op-3
+        triggerPolicy:
+          strategy: ALL_UPSTREAM_TASKS_COMPLETED
+      print-op-5:
+        cachingOptions:
+          enableCache: true
+        componentRef:
+          name: comp-print-op-5
+        dependentTasks:
+        - exit-handler-3
+        inputs:
+          parameters:
+            message:
+              runtimeValue:
+                constant: Third exit handler has worked!
+        taskInfo:
+          name: print-op-5
+        triggerPolicy:
+          strategy: ALL_UPSTREAM_TASKS_COMPLETED
+  inputDefinitions:
+    parameters:
+      message:
+        defaultValue: Hello World!
+        parameterType: STRING
+schemaVersion: 2.1.0
+sdkVersion: kfp-2.0.0-beta.1

@@ -118,7 +118,9 @@ class Pipeline:
         # Add the root group.
         self.groups = [
             tasks_group.TasksGroup(
-                group_type=tasks_group.TasksGroupType.PIPELINE, name=name)
+                group_type=tasks_group.TasksGroupType.PIPELINE,
+                name=name,
+                is_root=True)
         ]
         self._group_id = 0
 
@@ -174,6 +176,7 @@ class Pipeline:
 
         self.tasks[task_name] = task
         if add_to_group:
+            task.parent_task_group = self.groups[-1]
             self.groups[-1].tasks.append(task)
 
         return task_name

@@ -69,6 +69,10 @@ class PipelineTask:
         args: Mapping[str, Any],
     ):
         """Initilizes a PipelineTask instance."""
+        # import within __init__ to avoid circular import
+        from kfp.components.tasks_group import TasksGroup
+
+        self.parent_task_group: Union[None, TasksGroup] = None
         args = args or {}
 
         for input_name, argument_value in args.items():
@@ -558,5 +562,9 @@ class PipelineTask:
             task2 = my_component(text='2nd task').after(task1)
         """
         for task in tasks:
+            if task.parent_task_group is not self.parent_task_group:
+                raise ValueError(
+                    f'Cannot use .after() across inner pipelines or DSL control flow features. Tried to set {self.name} after {task.name}, but these tasks do not belong to the same pipeline or are not enclosed in the same control flow content manager.'
+                )
             self._task_spec.dependent_tasks.append(task.name)
         return self

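The `.after()` guard added above is user-facing: a task can now only be sequenced after tasks in the same task group. A minimal sketch of code it rejects, reusing the hypothetical cleanup/work components from the sketch near the top of this page:

    @dsl.pipeline(name='cross-group-after')
    def my_pipeline():
        first_cleanup = cleanup(message='first cleanup')
        with dsl.ExitHandler(first_cleanup):
            inner = work(message='inside handler')

        # Raises ValueError: "Cannot use .after() across inner pipelines or
        # DSL control flow features." because `inner` lives inside the
        # ExitHandler group while this task sits at the pipeline root.
        work(message='outside handler').after(inner)

The tests below exercise exactly this pattern for ExitHandler and Condition groups.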
@@ -13,10 +13,14 @@
 # limitations under the License.
 """Tests for kfp.components.pipeline_task."""
+import os
+import tempfile
 import textwrap
 import unittest
 
 from absl.testing import parameterized
+from kfp import compiler
+from kfp import dsl
 from kfp.components import pipeline_task
 from kfp.components import placeholders
 from kfp.components import structures
@@ -301,5 +305,136 @@ class PipelineTaskTest(parameterized.TestCase):
         self.assertEqual('test_name', task._task_spec.display_name)
 
 
+class TestCannotUseAfterCrossDAG(unittest.TestCase):
+
+    def test_inner_task_prevented(self):
+        with self.assertRaisesRegex(ValueError,
+                                    r'Cannot use \.after\(\) across'):
+
+            @dsl.component
+            def print_op(message: str):
+                print(message)
+
+            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+            def my_pipeline():
+                first_exit_task = print_op(message='First exit task.')
+
+                with dsl.ExitHandler(first_exit_task):
+                    first_print_op = print_op(
+                        message='Inside first exit handler.')
+
+                second_exit_task = print_op(message='Second exit task.')
+                with dsl.ExitHandler(second_exit_task):
+                    print_op(message='Inside second exit handler.').after(
+                        first_print_op)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_exit_handler_task_prevented(self):
+        with self.assertRaisesRegex(ValueError,
+                                    r'Cannot use \.after\(\) across'):
+
+            @dsl.component
+            def print_op(message: str):
+                print(message)
+
+            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+            def my_pipeline():
+                first_exit_task = print_op(message='First exit task.')
+
+                with dsl.ExitHandler(first_exit_task):
+                    first_print_op = print_op(
+                        message='Inside first exit handler.')
+
+                second_exit_task = print_op(message='Second exit task.')
+                with dsl.ExitHandler(second_exit_task):
+                    x = print_op(message='Inside second exit handler.')
+                    x.after(first_exit_task)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_within_same_exit_handler_permitted(self):
+
+        @dsl.component
+        def print_op(message: str):
+            print(message)
+
+        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+        def my_pipeline():
+            first_exit_task = print_op(message='First exit task.')
+
+            with dsl.ExitHandler(first_exit_task):
+                first_print_op = print_op(
+                    message='First task inside first exit handler.')
+                second_print_op = print_op(
+                    message='Second task inside first exit handler.').after(
+                        first_print_op)
+
+            second_exit_task = print_op(message='Second exit task.')
+            with dsl.ExitHandler(second_exit_task):
+                print_op(message='Inside second exit handler.')
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_outside_of_condition_blocked(self):
+        with self.assertRaisesRegex(ValueError,
+                                    r'Cannot use \.after\(\) across'):
+
+            @dsl.component
+            def print_op(message: str):
+                print(message)
+
+            @dsl.component
+            def return_1() -> int:
+                return 1
+
+            @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+            def my_pipeline():
+                return_1_task = return_1()
+
+                with dsl.Condition(return_1_task.output == 1):
+                    one = print_op(message='1')
+                    two = print_op(message='2')
+                three = print_op(message='3').after(one)
+
+            with tempfile.TemporaryDirectory() as tempdir:
+                package_path = os.path.join(tempdir, 'pipeline.yaml')
+                compiler.Compiler().compile(
+                    pipeline_func=my_pipeline, package_path=package_path)
+
+    def test_inside_of_condition_permitted(self):
+
+        @dsl.component
+        def print_op(message: str):
+            print(message)
+
+        @dsl.component
+        def return_1() -> int:
+            return 1
+
+        @dsl.pipeline(name='pipeline-with-multiple-exit-handlers')
+        def my_pipeline():
+            return_1_task = return_1()
+
+            with dsl.Condition(return_1_task.output == '1'):
+                one = print_op(message='1')
+                two = print_op(message='2').after(one)
+            three = print_op(message='3')
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            package_path = os.path.join(tempdir, 'pipeline.yaml')
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path=package_path)
+
+
 if __name__ == '__main__':
     unittest.main()

@@ -44,12 +44,14 @@ class TasksGroup:
         groups: A list of TasksGroups in this group.
         display_name: The optional user given name of the group.
         dependencies: A list of tasks or groups this group depends on.
+        is_root: If TasksGroup is root group.
     """
 
     def __init__(
         self,
         group_type: TasksGroupType,
         name: Optional[str] = None,
+        is_root: bool = False,
     ):
         """Create a new instance of TasksGroup.
@@ -62,6 +64,7 @@ class TasksGroup:
         self.groups = list()
         self.display_name = name
         self.dependencies = []
+        self.is_root = is_root
 
     def __enter__(self):
         if not pipeline_context.Pipeline.get_default_pipeline():
@@ -116,7 +119,11 @@ class ExitHandler(TasksGroup):
         name: Optional[str] = None,
     ):
         """Initializes a Condition task group."""
-        super().__init__(group_type=TasksGroupType.EXIT_HANDLER, name=name)
+        super().__init__(
+            group_type=TasksGroupType.EXIT_HANDLER,
+            name=name,
+            is_root=False,
+        )
 
         if exit_task.dependent_tasks:
             raise ValueError('exit_task cannot depend on any other tasks.')
@@ -151,6 +158,7 @@ class Condition(TasksGroup):
         self,
         condition: pipeline_channel.ConditionOperator,
         name: Optional[str] = None,
+        is_root=False,
     ):
         """Initializes a conditional task group."""
         super().__init__(group_type=TasksGroupType.CONDITION, name=name)
@@ -182,7 +190,11 @@ class ParallelFor(TasksGroup):
         name: Optional[str] = None,
     ):
         """Initializes a for loop task group."""
-        super().__init__(group_type=TasksGroupType.FOR_LOOP, name=name)
+        super().__init__(
+            group_type=TasksGroupType.FOR_LOOP,
+            name=name,
+            is_root=False,
+        )
 
         if isinstance(items, pipeline_channel.PipelineChannel):
             self.loop_argument = for_loop.LoopArgument.from_pipeline_channel(