fix(sdk): fix bug when `dsl.importer` argument is provided by loop variable (#10116)

This commit is contained in:
Connor McCarthy 2023-10-18 13:37:56 -07:00 committed by GitHub
parent faba9223ee
commit 73d51c8a23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 43 deletions

View File

@@ -9,6 +9,7 @@
## Bug fixes and other changes
* Fix type on `dsl.ParallelFor` sub-DAG output when a `dsl.Collected` is used. Non-functional fix. [\#10069](https://github.com/kubeflow/pipelines/pull/10069)
* Fix bug when `dsl.importer` argument is provided by a `dsl.ParallelFor` loop variable. [\#10116](https://github.com/kubeflow/pipelines/pull/10116)
## Documentation updates

View File

@@ -128,11 +128,35 @@ def build_task_spec_for_task(
task._task_spec.retry_policy.to_proto())
for input_name, input_value in task.inputs.items():
# since LoopArgument and LoopArgumentVariable are narrower types than PipelineParameterChannel, start with it
if isinstance(input_value, for_loop.LoopArgument):
if isinstance(input_value,
pipeline_channel.PipelineArtifactChannel) or (
isinstance(input_value, for_loop.Collected) and
input_value.is_artifact_channel):
component_input_parameter = (
compiler_utils.additional_input_name_for_pipeline_channel(
input_value))
assert component_input_parameter in parent_component_inputs.parameters, \
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.parameters[
input_name].component_input_parameter = (
component_input_parameter)
elif isinstance(input_value, for_loop.LoopArgumentVariable):
component_input_parameter = (
compiler_utils.additional_input_name_for_pipeline_channel(
input_value.loop_argument))
assert component_input_parameter in parent_component_inputs.parameters, \
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.parameters[
input_name].component_input_parameter = (
component_input_parameter)
pipeline_task_spec.inputs.parameters[
input_name].parameter_expression_selector = (
f'parseJson(string_value)["{input_value.subvar_name}"]')
elif isinstance(input_value,
pipeline_channel.PipelineArtifactChannel) or (
isinstance(input_value, for_loop.Collected) and
input_value.is_artifact_channel):
if input_value.task_name:
# Value is produced by an upstream task.
@@ -200,31 +224,6 @@ def build_task_spec_for_task(
input_name].component_input_parameter = (
component_input_parameter)
elif isinstance(input_value, for_loop.LoopArgument):
component_input_parameter = (
compiler_utils.additional_input_name_for_pipeline_channel(
input_value))
assert component_input_parameter in parent_component_inputs.parameters, \
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.parameters[
input_name].component_input_parameter = (
component_input_parameter)
elif isinstance(input_value, for_loop.LoopArgumentVariable):
component_input_parameter = (
compiler_utils.additional_input_name_for_pipeline_channel(
input_value.loop_argument))
assert component_input_parameter in parent_component_inputs.parameters, \
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
pipeline_task_spec.inputs.parameters[
input_name].component_input_parameter = (
component_input_parameter)
pipeline_task_spec.inputs.parameters[
input_name].parameter_expression_selector = (
f'parseJson(string_value)["{input_value.subvar_name}"]')
elif isinstance(input_value, str):
# Handle extra input due to string concat
pipeline_channels = (
@@ -572,7 +571,7 @@ def build_importer_spec_for_task(
importer_spec.metadata.CopyFrom(metadata_protobuf_struct)
if isinstance(task.importer_spec.artifact_uri,
pipeline_channel.PipelineParameterChannel):
pipeline_channel.PipelineChannel):
importer_spec.artifact_uri.runtime_parameter = 'uri'
elif isinstance(task.importer_spec.artifact_uri, str):
importer_spec.artifact_uri.constant.string_value = task.importer_spec.artifact_uri

View File

@@ -77,7 +77,7 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
Attributes:
items_or_pipeline_channel: The raw items or the PipelineChannel object
items_or_pipeline_channel: The raw items or the PipelineParameterChannel object
this LoopArgument is associated to.
"""
LOOP_ITEM_NAME_BASE = 'loop-item'
@@ -85,7 +85,7 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
def __init__(
self,
items: Union[ItemList, pipeline_channel.PipelineChannel],
items: Union[ItemList, pipeline_channel.PipelineParameterChannel],
name_code: Optional[str] = None,
name_override: Optional[str] = None,
**kwargs,
@@ -99,8 +99,8 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
name_code: A unique code used to identify these loop arguments.
Should match the code for the ParallelFor ops_group which created
these LoopArguments. This prevents parameter name collisions.
name_override: The override name for PipelineChannel.
**kwargs: Any other keyword arguments passed down to PipelineChannel.
name_override: The override name for PipelineParameterChannel.
**kwargs: Any other keyword arguments passed down to PipelineParameterChannel.
"""
if (name_code is None) == (name_override is None):
raise ValueError(
@@ -112,17 +112,19 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
else:
super().__init__(name=name_override, **kwargs)
if not isinstance(items,
(list, tuple, pipeline_channel.PipelineChannel)):
if not isinstance(
items,
(list, tuple, pipeline_channel.PipelineParameterChannel)):
raise TypeError(
f'Expected list, tuple, or PipelineChannel, got {items}.')
f'Expected list, tuple, or PipelineParameterChannel, got {items}.'
)
if isinstance(items, tuple):
items = list(items)
self.items_or_pipeline_channel = items
self.is_with_items_loop_argument = not isinstance(
items, pipeline_channel.PipelineChannel)
items, pipeline_channel.PipelineParameterChannel)
self._referenced_subvars: Dict[str, LoopArgumentVariable] = {}
if isinstance(items, list) and isinstance(items[0], dict):
@@ -154,9 +156,10 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
@classmethod
def from_pipeline_channel(
cls,
channel: pipeline_channel.PipelineChannel,
channel: pipeline_channel.PipelineParameterChannel,
) -> 'LoopArgument':
"""Creates a LoopArgument object from a PipelineChannel object."""
"""Creates a LoopArgument object from a PipelineParameterChannel
object."""
return LoopArgument(
items=channel,
name_override=channel.name + '-' + cls.LOOP_ITEM_NAME_BASE,
@@ -191,7 +194,7 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
or (cls.LOOP_ITEM_PARAM_NAME_BASE + '-') in name
class LoopArgumentVariable(pipeline_channel.PipelineChannel):
class LoopArgumentVariable(pipeline_channel.PipelineParameterChannel):
"""Represents a subvariable for a loop argument.
This is used for cases where we're looping over maps, each of which contains
@@ -246,7 +249,7 @@ class LoopArgumentVariable(pipeline_channel.PipelineChannel):
@property
def items_or_pipeline_channel(
self) -> Union[ItemList, pipeline_channel.PipelineChannel]:
self) -> Union[ItemList, pipeline_channel.PipelineParameterChannel]:
"""Returns the loop argument items."""
return self.loop_argument.items_or_pipeline_chanenl

View File

@@ -14,8 +14,8 @@
import unittest
from kfp import dsl
from kfp.dsl import Dataset
from kfp.dsl import importer_node
from kfp.dsl.types.artifact_types import Dataset
class TestImporterSupportsDynamicMetadata(unittest.TestCase):
@@ -184,3 +184,37 @@ class TestImporterSupportsDynamicMetadata(unittest.TestCase):
"prefix2-{{$.inputs.parameters[\'metadata-2\']}}")
self.assertEqual(metadata.struct_value.fields['key'].string_value,
'value')
def test_uri_from_loop(self):
@dsl.component
def make_args() -> list:
return [{'uri': 'gs://foo', 'key': 'foo'}]
@dsl.pipeline
def my_pipeline():
with dsl.ParallelFor(make_args().output) as data:
dsl.importer(
artifact_uri=data.uri,
artifact_class=Dataset,
metadata={'metadata_key': data.key})
self.assertEqual(
my_pipeline.pipeline_spec.deployment_spec['executors']
['exec-importer']['importer']['artifactUri']['runtimeParameter'],
'uri')
self.assertEqual(
my_pipeline.pipeline_spec.deployment_spec['executors']
['exec-importer']['importer']['metadata']['metadata_key'],
"{{$.inputs.parameters[\'metadata\']}}")
self.assertEqual(
my_pipeline.pipeline_spec.components['comp-for-loop-1'].dag
.tasks['importer'].inputs.parameters['metadata']
.component_input_parameter,
'pipelinechannel--make-args-Output-loop-item')
self.assertEqual(
my_pipeline.pipeline_spec.components['comp-for-loop-1'].dag
.tasks['importer'].inputs.parameters['metadata']
.parameter_expression_selector,
'parseJson(string_value)["key"]',
)