fix(sdk): Resolves issue when using ParallelFor with param and depending tasks (#11903)

Signed-off-by: Mai Nakagawa <nakagawa.mai@gmail.com>
This commit is contained in:
Mai Nakagawa 2025-05-10 00:49:41 +09:00 committed by GitHub
parent 18bed6c70d
commit ef94ccd734
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 1 deletions

View File

@ -838,6 +838,23 @@ implementation:
with dsl.ParallelFor(items=single_param_task.output) as item:
print_and_return(text=item)
def test_compile_parallel_for_with_param_and_depending_task(self):
@dsl.component
def print_op(s: str):
print(s)
@dsl.pipeline
def my_pipeline(param: str):
with dsl.ParallelFor(items=['a', 'b']) as item:
parallel_tasks = print_op(s=item)
print_op(s=param).after(parallel_tasks)
with tempfile.TemporaryDirectory() as tempdir:
output_yaml = os.path.join(tempdir, 'result.yaml')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=output_yaml)
def test_cannot_compile_parallel_for_with_single_artifact(self):
with self.assertRaisesRegex(

View File

@ -761,7 +761,9 @@ def get_dependencies(
# then make this validation dsl.Collected-aware
elif isinstance(upstream_parent_group, tasks_group.ParallelFor):
upstream_tasks_that_downstream_consumers_from = [
channel.task.name for channel in task._channel_inputs
channel.task.name
for channel in task._channel_inputs
if channel.task is not None
]
has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from
# don't raise for .after