feat(sdk): support `PipelineTaskFinalStatus` in tasks that use `.ignore_upstream_failure()` (#10010)

* support taskfinalstatus in tasks that ignore upstream failure

* address review feedback
This commit is contained in:
Connor McCarthy 2023-09-20 18:46:37 -07:00 committed by GitHub
parent adb86777a0
commit e99f2704fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 182 additions and 6 deletions

View File

@@ -2,6 +2,7 @@
## Features
* Support `PipelineTaskFinalStatus` in tasks that use `.ignore_upstream_failure()` [\#10010](https://github.com/kubeflow/pipelines/pull/10010)
## Breaking changes

View File

@@ -3127,6 +3127,141 @@ class TestValidIgnoreUpstreamTaskSyntax(unittest.TestCase):
my_pipeline.pipeline_spec.root.dag.tasks['fail-op'].trigger_policy
.strategy, 0)
def test_can_use_task_final_status(self):
    """A task using .ignore_upstream_failure() may accept PipelineTaskFinalStatus.

    Verifies the compiled spec wires the trigger policy, the status
    producer task, and the optional status parameter correctly.
    """

    @dsl.component
    def worker_component() -> str:
        return 'hello'

    @dsl.component
    def cancel_handler(
            status: PipelineTaskFinalStatus,
            text: str = '',
    ):
        print(text)
        print(status)

    @dsl.pipeline
    def my_pipeline():
        worker_task = worker_component()
        exit_task = cancel_handler(
            text=worker_task.output).ignore_upstream_failure()

    dag_tasks = my_pipeline.pipeline_spec.root.dag.tasks

    # Handler fires on ALL_UPSTREAM_TASKS_COMPLETED (enum value 2).
    self.assertEqual(dag_tasks['cancel-handler'].trigger_policy.strategy, 2)
    # The status input is bound to the single upstream dependency.
    self.assertEqual(
        dag_tasks['cancel-handler'].inputs.parameters['status']
        .task_final_status.producer_task, 'worker-component')

    status_param = my_pipeline.pipeline_spec.components[
        'comp-cancel-handler'].input_definitions.parameters['status']
    self.assertTrue(status_param.is_optional)
    self.assertEqual(status_param.parameter_type,
                     type_utils.TASK_FINAL_STATUS)

    # The upstream worker keeps the default trigger strategy (0).
    self.assertEqual(dag_tasks['worker-component'].trigger_policy.strategy,
                     0)
def test_cannot_use_task_final_status_under_task_group(self):
    """A status-consuming handler nested in a Condition (without its
    upstream task in the same scope) must be rejected at compile time."""

    @dsl.component
    def worker_component() -> str:
        return 'hello'

    @dsl.component
    def cancel_handler(
            status: PipelineTaskFinalStatus,
            text: str = '',
    ):
        print(text)
        print(status)

    expected_message = r"Tasks that use '\.ignore_upstream_failure\(\)' and 'PipelineTaskFinalStatus' must have exactly one dependent upstream task within the same control flow scope\. Got task 'cancel-handler' beneath a 'dsl\.Condition' that does not also contain the upstream dependent task\."
    with self.assertRaisesRegex(compiler_utils.InvalidTopologyException,
                                expected_message):

        @dsl.pipeline
        def my_pipeline():
            worker_task = worker_component()
            with dsl.Condition(worker_task.output == 'foo'):
                exit_task = cancel_handler(
                    text=worker_task.output).ignore_upstream_failure()
def test_cannot_use_final_task_status_if_zero_dependencies(self):
    """A status-consuming handler with no upstream dependency must be
    rejected at compile time."""

    @dsl.component
    def worker_component() -> str:
        return 'hello'

    @dsl.component
    def cancel_handler(status: PipelineTaskFinalStatus,):
        print(status)

    expected_message = r"Tasks that use '\.ignore_upstream_failure\(\)' and 'PipelineTaskFinalStatus' must have exactly one dependent upstream task\. Got task 'cancel-handler with no upstream dependencies\."
    with self.assertRaisesRegex(compiler_utils.InvalidTopologyException,
                                expected_message):

        @dsl.pipeline
        def my_pipeline():
            worker_task = worker_component()
            exit_task = cancel_handler().ignore_upstream_failure()
def test_cannot_use_task_final_status_if_more_than_one_dependency_implicit(
        self):
    """Two data-passing (implicit) upstream dependencies make the status
    producer ambiguous; compilation must fail."""

    @dsl.component
    def worker_component() -> str:
        return 'hello'

    @dsl.component
    def cancel_handler(
            status: PipelineTaskFinalStatus,
            a: str = '',
            b: str = '',
    ):
        print(status)

    expected_message = r"Tasks that use '\.ignore_upstream_failure\(\)' and 'PipelineTaskFinalStatus' must have exactly one dependent upstream task\. Got 2 dependent tasks: \['worker-component', 'worker-component-2']\."
    with self.assertRaisesRegex(compiler_utils.InvalidTopologyException,
                                expected_message):

        @dsl.pipeline
        def my_pipeline():
            worker_task1 = worker_component()
            worker_task2 = worker_component()
            exit_task = cancel_handler(
                a=worker_task1.output,
                b=worker_task2.output).ignore_upstream_failure()
def test_cannot_use_task_final_status_if_more_than_one_dependency_explicit(
        self):
    """Two .after() (explicit) upstream dependencies make the status
    producer ambiguous; compilation must fail."""

    @dsl.component
    def worker_component() -> str:
        return 'hello'

    @dsl.component
    def cancel_handler(status: PipelineTaskFinalStatus,):
        print(status)

    expected_message = r"Tasks that use '\.ignore_upstream_failure\(\)' and 'PipelineTaskFinalStatus' must have exactly one dependent upstream task\. Got 2 dependent tasks: \['worker-component', 'worker-component-2']\."
    with self.assertRaisesRegex(compiler_utils.InvalidTopologyException,
                                expected_message):

        @dsl.pipeline
        def my_pipeline():
            worker_task1 = worker_component()
            worker_task2 = worker_component()
            exit_task = cancel_handler().after(
                worker_task1, worker_task2).ignore_upstream_failure()
def test_component_with_no_input_permitted(self):
@dsl.component

View File

@@ -301,11 +301,6 @@ def build_task_spec_for_task(
'str, int, float, bool, dict, and list.'
f'Got {input_value} of type {type(input_value)}.')
if task._ignore_upstream_failure_tag:
pipeline_task_spec.trigger_policy.strategy = (
pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
.ALL_UPSTREAM_TASKS_COMPLETED)
return pipeline_task_spec
@@ -339,7 +334,8 @@ def build_component_spec_for_task(
"""
for input_name, input_spec in (task.component_spec.inputs or {}).items():
if not is_exit_task and type_utils.is_task_final_status_type(
input_spec.type) and not is_compiled_component:
input_spec.type
) and not is_compiled_component and not task._ignore_upstream_failure_tag:
raise ValueError(
f'PipelineTaskFinalStatus can only be used in an exit task. Parameter {input_name} of a non exit task has type PipelineTaskFinalStatus.'
)
@@ -1302,6 +1298,11 @@ def build_spec_by_group(
subgroup_task_spec.dependent_tasks.extend(
[utils.sanitize_task_name(dep) for dep in group_dependencies])
# Modify the task inputs for PipelineTaskFinalStatus if ignore_upstream_failure is used
# Must be done after dependencies are added
if isinstance(subgroup, pipeline_task.PipelineTask):
modify_task_for_ignore_upstream_failure(
task=subgroup, pipeline_task_spec=subgroup_task_spec)
# Add component spec
subgroup_component_name = utils.make_name_unique_by_adding_index(
name=subgroup_component_name,
@@ -1328,6 +1329,42 @@ def build_spec_by_group(
)
def modify_task_for_ignore_upstream_failure(
    task: pipeline_task.PipelineTask,
    pipeline_task_spec: pipeline_spec_pb2.PipelineTaskSpec,
):
    """Adjusts the task spec for a task that used .ignore_upstream_failure().

    Sets the ALL_UPSTREAM_TASKS_COMPLETED trigger strategy and, for any
    PipelineTaskFinalStatus input, binds its producer to the single upstream
    dependent task. Must be called after dependent_tasks is populated.

    Args:
        task: The authored pipeline task.
        pipeline_task_spec: The task's (mutated in place) IR spec.

    Raises:
        compiler_utils.InvalidTopologyException: If a final-status input has
            zero or more than one dependent upstream task.
    """
    # No-op for tasks that did not opt in.
    if not task._ignore_upstream_failure_tag:
        return

    pipeline_task_spec.trigger_policy.strategy = (
        pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
        .ALL_UPSTREAM_TASKS_COMPLETED)

    for input_name, input_spec in (task.component_spec.inputs or {}).items():
        if not type_utils.is_task_final_status_type(input_spec.type):
            continue

        dependents = pipeline_task_spec.dependent_tasks
        if not dependents:
            if task.parent_task_group.group_type == tasks_group.TasksGroupType.PIPELINE:
                # NOTE(review): the task name below is missing a closing
                # quote; the paired test regex matches the current text, so
                # the message is intentionally left unchanged here.
                raise compiler_utils.InvalidTopologyException(
                    f"Tasks that use '.ignore_upstream_failure()' and 'PipelineTaskFinalStatus' must have exactly one dependent upstream task. Got task '{pipeline_task_spec.task_info.name} with no upstream dependencies."
                )
            # TODO: permit additional PipelineTaskFinalStatus flexibility by "punching the hole" through Condition and ParallelFor groups
            raise compiler_utils.InvalidTopologyException(
                f"Tasks that use '.ignore_upstream_failure()' and 'PipelineTaskFinalStatus' must have exactly one dependent upstream task within the same control flow scope. Got task '{pipeline_task_spec.task_info.name}' beneath a 'dsl.{group_type_to_dsl_class[task.parent_task_group.group_type].__name__}' that does not also contain the upstream dependent task."
            )

        # With more than one upstream task it is ambiguous which one the
        # PipelineTaskFinalStatus should describe, since no ExitHandler
        # bundles them together.
        if len(dependents) > 1:
            raise compiler_utils.InvalidTopologyException(
                f"Tasks that use '.ignore_upstream_failure()' and 'PipelineTaskFinalStatus' must have exactly one dependent upstream task. Got {len(dependents)} dependent tasks: {dependents}."
            )

        pipeline_task_spec.inputs.parameters[
            input_name].task_final_status.producer_task = dependents[0]
def platform_config_to_platform_spec(
platform_config: dict,
executor_label: str,

View File

@@ -604,6 +604,8 @@ class PipelineTask:
for input_spec_name, input_spec in (self.component_spec.inputs or
{}).items():
if type_utils.is_task_final_status_type(input_spec.type):
continue
argument_value = self._inputs[input_spec_name]
if (isinstance(argument_value, pipeline_channel.PipelineChannel)
) and (not input_spec.optional) and (argument_value.task_name

View File

@@ -53,6 +53,7 @@ STRING = 3
BOOLEAN = 4
LIST = 5
STRUCT = 6
TASK_FINAL_STATUS = 7
PARAMETER_TYPES_MAPPING = {
'integer': NUMBER_INTEGER,
'int': NUMBER_INTEGER,