fix(sdk): Resolves issue when using ParallelFor with param and depending tasks (#11903)
Signed-off-by: Mai Nakagawa <nakagawa.mai@gmail.com>
This commit is contained in:
parent
18bed6c70d
commit
ef94ccd734
|
@ -838,6 +838,23 @@ implementation:
|
|||
with dsl.ParallelFor(items=single_param_task.output) as item:
|
||||
print_and_return(text=item)
|
||||
|
||||
def test_compile_parallel_for_with_param_and_depending_task(self):
|
||||
|
||||
@dsl.component
|
||||
def print_op(s: str):
|
||||
print(s)
|
||||
|
||||
@dsl.pipeline
|
||||
def my_pipeline(param: str):
|
||||
with dsl.ParallelFor(items=['a', 'b']) as item:
|
||||
parallel_tasks = print_op(s=item)
|
||||
print_op(s=param).after(parallel_tasks)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
output_yaml = os.path.join(tempdir, 'result.yaml')
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=output_yaml)
|
||||
|
||||
def test_cannot_compile_parallel_for_with_single_artifact(self):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
|
|
|
@ -761,7 +761,9 @@ def get_dependencies(
|
|||
# then make this validation dsl.Collected-aware
|
||||
elif isinstance(upstream_parent_group, tasks_group.ParallelFor):
|
||||
upstream_tasks_that_downstream_consumers_from = [
|
||||
channel.task.name for channel in task._channel_inputs
|
||||
channel.task.name
|
||||
for channel in task._channel_inputs
|
||||
if channel.task is not None
|
||||
]
|
||||
has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from
|
||||
# don't raise for .after
|
||||
|
|
Loading…
Reference in New Issue