diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index acc41c06c1..bdd0033925 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -838,6 +838,23 @@ implementation: with dsl.ParallelFor(items=single_param_task.output) as item: print_and_return(text=item) + def test_compile_parallel_for_with_param_and_depending_task(self): + + @dsl.component + def print_op(s: str): + print(s) + + @dsl.pipeline + def my_pipeline(param: str): + with dsl.ParallelFor(items=['a', 'b']) as item: + parallel_tasks = print_op(s=item) + print_op(s=param).after(parallel_tasks) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'result.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + def test_cannot_compile_parallel_for_with_single_artifact(self): with self.assertRaisesRegex( diff --git a/sdk/python/kfp/compiler/compiler_utils.py b/sdk/python/kfp/compiler/compiler_utils.py index 9be45ac867..1924716f15 100644 --- a/sdk/python/kfp/compiler/compiler_utils.py +++ b/sdk/python/kfp/compiler/compiler_utils.py @@ -761,7 +761,9 @@ def get_dependencies( # then make this validation dsl.Collected-aware elif isinstance(upstream_parent_group, tasks_group.ParallelFor): upstream_tasks_that_downstream_consumers_from = [ - channel.task.name for channel in task._channel_inputs + channel.task.name + for channel in task._channel_inputs + if channel.task is not None ] has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from # don't raise for .after