fix(sdk.v2)!: Block task dependency referencing tasks inside a sibling condition or loop group. (#7050)
This commit is contained in:
parent
85d74337d6
commit
6dfaeebd92
|
@ -749,7 +749,6 @@ class Compiler:
|
|||
group2 = task2_groups[common_groups_len:]
|
||||
return (group1, group2)
|
||||
|
||||
# TODO: revisit for dependency that breaks through DAGs.
|
||||
def _get_dependencies(
|
||||
self,
|
||||
pipeline: pipeline_context.Pipeline,
|
||||
|
@ -780,6 +779,10 @@ class Compiler:
|
|||
dependent on G2. Basically dependency only exists in the first
|
||||
uncommon ancesters in their ancesters chain. Only sibling
|
||||
groups/tasks can have dependencies.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if a task depends on a task inside a condition or loop
|
||||
group.
|
||||
"""
|
||||
dependencies = collections.defaultdict(set)
|
||||
for task in pipeline.tasks.values():
|
||||
|
@ -806,6 +809,19 @@ class Compiler:
|
|||
task1=upstream_task,
|
||||
task2=task,
|
||||
)
|
||||
|
||||
# If a task depends on a condition group or a loop group, it
|
||||
# must explicitly dependent on a task inside the group. This
|
||||
# should not be allowed, because it leads to ambiguous
|
||||
# expectations for runtime behaviors.
|
||||
dependent_group = group_name_to_group.get(
|
||||
upstream_groups[0], None)
|
||||
if isinstance(dependent_group,
|
||||
(tasks_group.Condition, tasks_group.ParallelFor)):
|
||||
raise RuntimeError(
|
||||
f'Task {task.name} cannot dependent on any task inside'
|
||||
f' the group: {upstream_groups[0]}.')
|
||||
|
||||
dependencies[downstream_groups[0]].add(upstream_groups[0])
|
||||
|
||||
return dependencies
|
||||
|
|
|
@ -107,6 +107,7 @@ class CompilerTest(parameterized.TestCase):
|
|||
args:
|
||||
- {inputValue: generate_explanation}
|
||||
""")
|
||||
|
||||
@dsl.pipeline(name='test-boolean-pipeline')
|
||||
def simple_pipeline():
|
||||
predict_op(generate_explanation=True)
|
||||
|
@ -480,6 +481,52 @@ class CompilerTest(parameterized.TestCase):
|
|||
with self.assertRaisesRegex(ValueError, 'Invalid pipeline name: '):
|
||||
compiler.Compiler()._validate_pipeline_name('my_pipeline')
|
||||
|
||||
def test_invalid_after_dependency(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(text: str):
|
||||
with dsl.Condition(text == 'a'):
|
||||
producer_task = producer_op()
|
||||
|
||||
dummy_op().after(producer_task)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'Task dummy-op cannot dependent on any task inside the group:'):
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path='result.json')
|
||||
|
||||
def test_invalid_data_dependency(self):
|
||||
|
||||
@dsl.component
|
||||
def producer_op() -> str:
|
||||
return 'a'
|
||||
|
||||
@dsl.component
|
||||
def dummy_op(msg: str = ''):
|
||||
pass
|
||||
|
||||
@dsl.pipeline(name='test-pipeline')
|
||||
def my_pipeline(text: bool):
|
||||
with dsl.ParallelFor(['a, b']):
|
||||
producer_task = producer_op()
|
||||
|
||||
dummy_op(msg=producer_task.output)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'Task dummy-op cannot dependent on any task inside the group:'):
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path='result.json')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue