fix(sdk.v2)!: Block task dependency referencing tasks inside a sibling condition or loop group. (#7050)

This commit is contained in:
Chen Sun 2021-12-13 11:08:37 -08:00 committed by GitHub
parent 85d74337d6
commit 6dfaeebd92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 1 deletions

View File

@ -749,7 +749,6 @@ class Compiler:
group2 = task2_groups[common_groups_len:]
return (group1, group2)
# TODO: revisit for dependency that breaks through DAGs.
def _get_dependencies(
self,
pipeline: pipeline_context.Pipeline,
@ -780,6 +779,10 @@ class Compiler:
dependent on G2. Basically dependency only exists in the first
uncommon ancesters in their ancesters chain. Only sibling
groups/tasks can have dependencies.
Raises:
RuntimeError: if a task depends on a task inside a condition or loop
group.
"""
dependencies = collections.defaultdict(set)
for task in pipeline.tasks.values():
@ -806,6 +809,19 @@ class Compiler:
task1=upstream_task,
task2=task,
)
# If a task depends on a condition group or a loop group, it
# must explicitly dependent on a task inside the group. This
# should not be allowed, because it leads to ambiguous
# expectations for runtime behaviors.
dependent_group = group_name_to_group.get(
upstream_groups[0], None)
if isinstance(dependent_group,
(tasks_group.Condition, tasks_group.ParallelFor)):
raise RuntimeError(
f'Task {task.name} cannot dependent on any task inside'
f' the group: {upstream_groups[0]}.')
dependencies[downstream_groups[0]].add(upstream_groups[0])
return dependencies

View File

@ -107,6 +107,7 @@ class CompilerTest(parameterized.TestCase):
args:
- {inputValue: generate_explanation}
""")
@dsl.pipeline(name='test-boolean-pipeline')
def simple_pipeline():
predict_op(generate_explanation=True)
@ -480,6 +481,52 @@ class CompilerTest(parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'Invalid pipeline name: '):
compiler.Compiler()._validate_pipeline_name('my_pipeline')
def test_invalid_after_dependency(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(text: str):
with dsl.Condition(text == 'a'):
producer_task = producer_op()
dummy_op().after(producer_task)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot dependent on any task inside the group:'):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.json')
def test_invalid_data_dependency(self):
@dsl.component
def producer_op() -> str:
return 'a'
@dsl.component
def dummy_op(msg: str = ''):
pass
@dsl.pipeline(name='test-pipeline')
def my_pipeline(text: bool):
with dsl.ParallelFor(['a, b']):
producer_task = producer_op()
dummy_op(msg=producer_task.output)
with self.assertRaisesRegex(
RuntimeError,
'Task dummy-op cannot dependent on any task inside the group:'):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='result.json')
if __name__ == '__main__':
unittest.main()