diff --git a/sdk/python/kfp/v2/compiler/compiler.py b/sdk/python/kfp/v2/compiler/compiler.py
index ecfbdc3dbc..8c6ba84728 100644
--- a/sdk/python/kfp/v2/compiler/compiler.py
+++ b/sdk/python/kfp/v2/compiler/compiler.py
@@ -749,7 +749,6 @@ class Compiler:
         group2 = task2_groups[common_groups_len:]
         return (group1, group2)
 
-    # TODO: revisit for dependency that breaks through DAGs.
     def _get_dependencies(
         self,
         pipeline: pipeline_context.Pipeline,
@@ -780,6 +779,10 @@ class Compiler:
            dependent on G2. Basically dependency only exists in the first
            uncommon ancesters in their ancesters chain. Only sibling
            groups/tasks can have dependencies.
+
+        Raises:
+            RuntimeError: if a task depends on a task inside a condition or loop
+                group.
         """
        dependencies = collections.defaultdict(set)
        for task in pipeline.tasks.values():
@@ -806,6 +809,19 @@ class Compiler:
                    task1=upstream_task,
                    task2=task,
                )
+
+                # If a task depends on a condition group or a loop group, it
+                # must explicitly depend on a task inside the group. This is
+                # not allowed, because it leads to ambiguous expectations for
+                # the runtime behavior.
+                dependent_group = group_name_to_group.get(
+                    upstream_groups[0], None)
+                if isinstance(dependent_group,
+                              (tasks_group.Condition, tasks_group.ParallelFor)):
+                    raise RuntimeError(
+                        f'Task {task.name} cannot depend on any task inside'
+                        f' the group: {upstream_groups[0]}.')
+
                dependencies[downstream_groups[0]].add(upstream_groups[0])
 
        return dependencies
diff --git a/sdk/python/kfp/v2/compiler/compiler_test.py b/sdk/python/kfp/v2/compiler/compiler_test.py
index c5f8161d9a..ce7bc61562 100644
--- a/sdk/python/kfp/v2/compiler/compiler_test.py
+++ b/sdk/python/kfp/v2/compiler/compiler_test.py
@@ -107,6 +107,7 @@ class CompilerTest(parameterized.TestCase):
          args:
          - {inputValue: generate_explanation}
      """)
+
            @dsl.pipeline(name='test-boolean-pipeline')
            def simple_pipeline():
                predict_op(generate_explanation=True)
@@ -480,6 +481,52 @@ class CompilerTest(parameterized.TestCase):
        with self.assertRaisesRegex(ValueError, 'Invalid pipeline name: '):
            compiler.Compiler()._validate_pipeline_name('my_pipeline')
 
+    def test_invalid_after_dependency(self):
+
+        @dsl.component
+        def producer_op() -> str:
+            return 'a'
+
+        @dsl.component
+        def dummy_op(msg: str = ''):
+            pass
+
+        @dsl.pipeline(name='test-pipeline')
+        def my_pipeline(text: str):
+            with dsl.Condition(text == 'a'):
+                producer_task = producer_op()
+
+            dummy_op().after(producer_task)
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                'Task dummy-op cannot depend on any task inside the group:'):
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path='result.json')
+
+    def test_invalid_data_dependency(self):
+
+        @dsl.component
+        def producer_op() -> str:
+            return 'a'
+
+        @dsl.component
+        def dummy_op(msg: str = ''):
+            pass
+
+        @dsl.pipeline(name='test-pipeline')
+        def my_pipeline(text: bool):
+            with dsl.ParallelFor(['a', 'b']):
+                producer_task = producer_op()
+
+            dummy_op(msg=producer_task.output)
+
+        with self.assertRaisesRegex(
+                RuntimeError,
+                'Task dummy-op cannot depend on any task inside the group:'):
+            compiler.Compiler().compile(
+                pipeline_func=my_pipeline, package_path='result.json')
+
 
 if __name__ == '__main__':
    unittest.main()