fix(sdk): fix bug where `dsl.OneOf` with multiple consumers cannot be compiled (#10452)
This commit is contained in:
parent
d4c3f35797
commit
21c5ffebb0
|
|
@ -4,7 +4,8 @@
|
|||
* Support local execution of sequential pipelines [\#10423](https://github.com/kubeflow/pipelines/pull/10423)
|
||||
* Support local execution of `dsl.importer` components [\#10431](https://github.com/kubeflow/pipelines/pull/10431)
|
||||
* Support local execution of pipelines in pipelines [\#10440](https://github.com/kubeflow/pipelines/pull/10440)
|
||||
* Support dsl.ParallelFor over list of Artifacts [\#10441](https://github.com/kubeflow/pipelines/pull/10441)
|
||||
* Support `dsl.ParallelFor` over list of Artifacts [\#10441](https://github.com/kubeflow/pipelines/pull/10441)
|
||||
* Fix bug where `dsl.OneOf` with multiple consumers cannot be compiled [\#10452](https://github.com/kubeflow/pipelines/pull/10452)
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
|
|
|||
|
|
@ -4821,6 +4821,12 @@ class TestDslOneOf(unittest.TestCase):
|
|||
x = dsl.OneOf(print_task_1.outputs['a'],
|
||||
print_task_2.outputs['a'])
|
||||
print_artifact(a=x)
|
||||
# test that the same OneOf object can be consumed multiple times
|
||||
print_artifact(a=x)
|
||||
y = dsl.OneOf(print_task_1.outputs['a'],
|
||||
print_task_2.outputs['a'])
|
||||
# test that different but equivalent OneOf objects can each be consumed
|
||||
print_artifact(a=y)
|
||||
|
||||
# hole punched through if
|
||||
self.assertEqual(
|
||||
|
|
|
|||
|
|
@ -522,6 +522,15 @@ def get_outputs_for_all_groups(
|
|||
break
|
||||
|
||||
elif isinstance(channel, pipeline_channel.OneOfMixin):
|
||||
if channel in processed_oneofs:
|
||||
continue
|
||||
|
||||
# we want to mutate the oneof's inner channels ONLY where they
|
||||
# are used in the oneof, not if they are used separately
|
||||
# for example: we should only modify the copy of
|
||||
# foo.output in dsl.OneOf(foo.output), not if foo.output is
|
||||
# passed to another downstream task
|
||||
channel.channels = [copy.copy(c) for c in channel.channels]
|
||||
for inner_channel in channel.channels:
|
||||
producer_task = pipeline.tasks[inner_channel.task_name]
|
||||
consumer_task = task
|
||||
|
|
@ -548,9 +557,8 @@ def get_outputs_for_all_groups(
|
|||
outputs[upstream_name][channel.name] = channel
|
||||
break
|
||||
|
||||
# copy so we can update the inner channel for the next iteration
|
||||
# use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
|
||||
# this uses more memory than needed and some objects are uncopiable
|
||||
# copy as a mechanism for "freezing" the inner channel
|
||||
# before we make updates for the next iteration
|
||||
outputs[upstream_name][
|
||||
surfaced_output_name] = copy.copy(inner_channel)
|
||||
|
||||
|
|
@ -596,6 +604,13 @@ def get_outputs_for_all_groups(
|
|||
# if the output has already been consumed by a task before it is returned, we don't need to reprocess it
|
||||
if channel in processed_oneofs:
|
||||
continue
|
||||
|
||||
# we want to mutate the oneof's inner channels ONLY where they
|
||||
# are used in the oneof, not if they are used separately
|
||||
# for example: we should only modify the copy of
|
||||
# foo.output in dsl.OneOf(foo.output), not if foo.output is passed
|
||||
# to another downstream task
|
||||
channel.channels = [copy.copy(c) for c in channel.channels]
|
||||
for inner_channel in channel.channels:
|
||||
producer_task = pipeline.tasks[inner_channel.task_name]
|
||||
upstream_groups = task_name_to_parent_groups[
|
||||
|
|
@ -615,9 +630,8 @@ def get_outputs_for_all_groups(
|
|||
outputs[upstream_name][channel.name] = channel
|
||||
break
|
||||
|
||||
# copy so we can update the inner channel for the next iteration
|
||||
# use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
|
||||
# this uses more memory than needed and some objects are uncopiable
|
||||
# copy as a mechanism for "freezing" the inner channel
|
||||
# before we make updates for the next iteration
|
||||
outputs[upstream_name][surfaced_output_name] = copy.copy(
|
||||
inner_channel)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue