fix(sdk): fix bug where `dsl.OneOf` with multiple consumers cannot be compiled (#10452)

Connor McCarthy 2024-02-08 15:51:37 -08:00 committed by GitHub
parent d4c3f35797
commit 21c5ffebb0
3 changed files with 28 additions and 7 deletions


@@ -4,7 +4,8 @@
* Support local execution of sequential pipelines [\#10423](https://github.com/kubeflow/pipelines/pull/10423)
* Support local execution of `dsl.importer` components [\#10431](https://github.com/kubeflow/pipelines/pull/10431)
* Support local execution of pipelines in pipelines [\#10440](https://github.com/kubeflow/pipelines/pull/10440)
* Support dsl.ParallelFor over list of Artifacts [\#10441](https://github.com/kubeflow/pipelines/pull/10441)
* Support `dsl.ParallelFor` over list of Artifacts [\#10441](https://github.com/kubeflow/pipelines/pull/10441)
* Fix bug where `dsl.OneOf` with multiple consumers cannot be compiled [\#10452](https://github.com/kubeflow/pipelines/pull/10452)
## Breaking changes

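For context on the `dsl.OneOf` changelog entry above: the pattern this fix allows to compile is a single `dsl.OneOf` channel (or an equivalent, separately constructed one) consumed by more than one downstream task. A minimal sketch of that pattern, with illustrative component and pipeline names that are not part of this commit:

```python
from kfp import dsl


@dsl.component
def flip_coin() -> str:
    import random
    return random.choice(['heads', 'tails'])


@dsl.component
def print_and_return(text: str) -> str:
    print(text)
    return text


@dsl.pipeline
def flip_coin_pipeline():
    flip = flip_coin()
    with dsl.If(flip.output == 'heads'):
        t1 = print_and_return(text='got heads')
    with dsl.Else():
        t2 = print_and_return(text='got tails')
    # one dsl.OneOf channel, two consumers: this previously failed to compile
    x = dsl.OneOf(t1.output, t2.output)
    print_and_return(text=x)
    print_and_return(text=x)
```

The test diff below exercises the same idea with artifact outputs, both by reusing the same `dsl.OneOf` object and by constructing a second, equivalent one.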

@@ -4821,6 +4821,12 @@ class TestDslOneOf(unittest.TestCase):
x = dsl.OneOf(print_task_1.outputs['a'],
print_task_2.outputs['a'])
print_artifact(a=x)
# test can be consumed multiple times from same oneof object
print_artifact(a=x)
y = dsl.OneOf(print_task_1.outputs['a'],
print_task_2.outputs['a'])
# test can be consumed multiple times from different equivalent oneof objects
print_artifact(a=y)
# hole punched through if
self.assertEqual(

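The compiler changes that follow guard the `dsl.OneOf` handling with a `processed_oneofs` membership check, so a OneOf consumed by several tasks (or via an equivalent, separately constructed object, as in the test above) is surfaced as a group output only once. A simplified, self-contained sketch of that dedupe-by-equality idea; the class and function here are stand-ins, not the real compiler types:

```python
from dataclasses import dataclass
from typing import List


# Toy stand-in for pipeline_channel.OneOfMixin: equality is defined by the
# inner channels, so two equivalent dsl.OneOf objects compare equal.
@dataclass(frozen=True)
class FakeOneOf:
    channels: tuple


processed_oneofs: List[FakeOneOf] = []


def surface_oneof_once(oneof: FakeOneOf) -> None:
    """Surface a OneOf as a group output only the first time it is seen."""
    if oneof in processed_oneofs:  # membership uses __eq__, not identity
        return
    # ...the real compiler walks oneof.channels and wires up group outputs here...
    processed_oneofs.append(oneof)


x = FakeOneOf(channels=('print-task-1/a', 'print-task-2/a'))
y = FakeOneOf(channels=('print-task-1/a', 'print-task-2/a'))  # equivalent, distinct object
surface_oneof_once(x)
surface_oneof_once(y)  # no-op: y == x, so it counts as already processed
assert len(processed_oneofs) == 1
```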

@@ -522,6 +522,15 @@ def get_outputs_for_all_groups(
break
elif isinstance(channel, pipeline_channel.OneOfMixin):
if channel in processed_oneofs:
continue
# we want to mutate the oneof's inner channels ONLY where they
# are used in the oneof, not if they are used separately
# for example: we should only modify the copy of
# foo.output in dsl.OneOf(foo.output), not if foo.output is
# passed to another downstream task
channel.channels = [copy.copy(c) for c in channel.channels]
for inner_channel in channel.channels:
producer_task = pipeline.tasks[inner_channel.task_name]
consumer_task = task
@@ -548,9 +557,8 @@
outputs[upstream_name][channel.name] = channel
break
# copy so we can update the inner channel for the next iteration
# use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
# this uses more memory than needed and some objects are uncopiable
# copy as a mechanism for "freezing" the inner channel
# before we make updates for the next iteration
outputs[upstream_name][
surfaced_output_name] = copy.copy(inner_channel)
@@ -596,6 +604,13 @@ def get_outputs_for_all_groups(
# if the output has already been consumed by a task before it is returned, we don't need to reprocess it
if channel in processed_oneofs:
continue
# we want to mutate the oneof's inner channels ONLY where they
# are used in the oneof, not if they are used separately
# for example: we should only modify the copy of
# foo.output in dsl.OneOf(foo.output), not if foo.output is passed
# to another downstream task
channel.channels = [copy.copy(c) for c in channel.channels]
for inner_channel in channel.channels:
producer_task = pipeline.tasks[inner_channel.task_name]
upstream_groups = task_name_to_parent_groups[
@@ -615,9 +630,8 @@
outputs[upstream_name][channel.name] = channel
break
# copy so we can update the inner channel for the next iteration
# use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
# this uses more memory than needed and some objects are uncopiable
# copy as a mechanism for "freezing" the inner channel
# before we make updates for the next iteration
outputs[upstream_name][surfaced_output_name] = copy.copy(
inner_channel)
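The `copy.copy(...)` calls above rely on two properties of a shallow copy: a copy taken now is unaffected by attribute updates made to the original on later iterations (the "freezing" described in the new comment), and copying only the OneOf's own list of inner channels leaves other references to the same upstream output untouched, which is why `foo.output` passed directly to another task is not modified. The removed comment also explains why `copy.copy` rather than `copy.deepcopy`: a deep copy would needlessly duplicate the whole pipeline and some objects in it are not copyable. A small generic-Python illustration using a toy class, not the real `PipelineChannel`:

```python
import copy
from dataclasses import dataclass


# Toy stand-in for a pipeline channel; the real class carries more state.
@dataclass
class Channel:
    task_name: str
    name: str


original = Channel(task_name='print-task-1', name='a')

# "Freeze" the channel as it looks right now.
frozen = copy.copy(original)

# A later step mutates the original (in the real code, for the next iteration)...
original.task_name = 'condition-branches-1'

# ...but the frozen copy still refers to the inner task's output.
assert frozen.task_name == 'print-task-1'
assert original.task_name == 'condition-branches-1'
```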