fix(sdk): fix bug when `dsl.importer` argument is provided by loop variable (#10116)
This commit is contained in:
parent
faba9223ee
commit
73d51c8a23
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
## Bug fixes and other changes
|
||||
* Fix type on `dsl.ParallelFor` sub-DAG output when a `dsl.Collected` is used. Non-functional fix. [\#10069](https://github.com/kubeflow/pipelines/pull/10069)
|
||||
* Fix bug when `dsl.importer` argument is provided by a `dsl.ParallelFor` loop variable. [\#10116](https://github.com/kubeflow/pipelines/pull/10116)
|
||||
|
||||
## Documentation updates
|
||||
|
||||
|
|
|
|||
|
|
@ -128,11 +128,35 @@ def build_task_spec_for_task(
|
|||
task._task_spec.retry_policy.to_proto())
|
||||
|
||||
for input_name, input_value in task.inputs.items():
|
||||
# since LoopArgument and LoopArgumentVariable are narrower types than PipelineParameterChannel, start with it
|
||||
if isinstance(input_value, for_loop.LoopArgument):
|
||||
|
||||
if isinstance(input_value,
|
||||
pipeline_channel.PipelineArtifactChannel) or (
|
||||
isinstance(input_value, for_loop.Collected) and
|
||||
input_value.is_artifact_channel):
|
||||
component_input_parameter = (
|
||||
compiler_utils.additional_input_name_for_pipeline_channel(
|
||||
input_value))
|
||||
assert component_input_parameter in parent_component_inputs.parameters, \
|
||||
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
|
||||
pipeline_task_spec.inputs.parameters[
|
||||
input_name].component_input_parameter = (
|
||||
component_input_parameter)
|
||||
|
||||
elif isinstance(input_value, for_loop.LoopArgumentVariable):
|
||||
|
||||
component_input_parameter = (
|
||||
compiler_utils.additional_input_name_for_pipeline_channel(
|
||||
input_value.loop_argument))
|
||||
assert component_input_parameter in parent_component_inputs.parameters, \
|
||||
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
|
||||
pipeline_task_spec.inputs.parameters[
|
||||
input_name].component_input_parameter = (
|
||||
component_input_parameter)
|
||||
pipeline_task_spec.inputs.parameters[
|
||||
input_name].parameter_expression_selector = (
|
||||
f'parseJson(string_value)["{input_value.subvar_name}"]')
|
||||
elif isinstance(input_value,
|
||||
pipeline_channel.PipelineArtifactChannel) or (
|
||||
isinstance(input_value, for_loop.Collected) and
|
||||
input_value.is_artifact_channel):
|
||||
|
||||
if input_value.task_name:
|
||||
# Value is produced by an upstream task.
|
||||
|
|
@ -200,31 +224,6 @@ def build_task_spec_for_task(
|
|||
input_name].component_input_parameter = (
|
||||
component_input_parameter)
|
||||
|
||||
elif isinstance(input_value, for_loop.LoopArgument):
|
||||
|
||||
component_input_parameter = (
|
||||
compiler_utils.additional_input_name_for_pipeline_channel(
|
||||
input_value))
|
||||
assert component_input_parameter in parent_component_inputs.parameters, \
|
||||
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
|
||||
pipeline_task_spec.inputs.parameters[
|
||||
input_name].component_input_parameter = (
|
||||
component_input_parameter)
|
||||
|
||||
elif isinstance(input_value, for_loop.LoopArgumentVariable):
|
||||
|
||||
component_input_parameter = (
|
||||
compiler_utils.additional_input_name_for_pipeline_channel(
|
||||
input_value.loop_argument))
|
||||
assert component_input_parameter in parent_component_inputs.parameters, \
|
||||
f'component_input_parameter: {component_input_parameter} not found. All inputs: {parent_component_inputs}'
|
||||
pipeline_task_spec.inputs.parameters[
|
||||
input_name].component_input_parameter = (
|
||||
component_input_parameter)
|
||||
pipeline_task_spec.inputs.parameters[
|
||||
input_name].parameter_expression_selector = (
|
||||
f'parseJson(string_value)["{input_value.subvar_name}"]')
|
||||
|
||||
elif isinstance(input_value, str):
|
||||
# Handle extra input due to string concat
|
||||
pipeline_channels = (
|
||||
|
|
@ -572,7 +571,7 @@ def build_importer_spec_for_task(
|
|||
importer_spec.metadata.CopyFrom(metadata_protobuf_struct)
|
||||
|
||||
if isinstance(task.importer_spec.artifact_uri,
|
||||
pipeline_channel.PipelineParameterChannel):
|
||||
pipeline_channel.PipelineChannel):
|
||||
importer_spec.artifact_uri.runtime_parameter = 'uri'
|
||||
elif isinstance(task.importer_spec.artifact_uri, str):
|
||||
importer_spec.artifact_uri.constant.string_value = task.importer_spec.artifact_uri
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
|
|||
|
||||
|
||||
Attributes:
|
||||
items_or_pipeline_channel: The raw items or the PipelineChannel object
|
||||
items_or_pipeline_channel: The raw items or the PipelineParameterChannel object
|
||||
this LoopArgument is associated to.
|
||||
"""
|
||||
LOOP_ITEM_NAME_BASE = 'loop-item'
|
||||
|
|
@ -85,7 +85,7 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
items: Union[ItemList, pipeline_channel.PipelineChannel],
|
||||
items: Union[ItemList, pipeline_channel.PipelineParameterChannel],
|
||||
name_code: Optional[str] = None,
|
||||
name_override: Optional[str] = None,
|
||||
**kwargs,
|
||||
|
|
@ -99,8 +99,8 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
|
|||
name_code: A unique code used to identify these loop arguments.
|
||||
Should match the code for the ParallelFor ops_group which created
|
||||
these LoopArguments. This prevents parameter name collisions.
|
||||
name_override: The override name for PipelineChannel.
|
||||
**kwargs: Any other keyword arguments passed down to PipelineChannel.
|
||||
name_override: The override name for PipelineParameterChannel.
|
||||
**kwargs: Any other keyword arguments passed down to PipelineParameterChannel.
|
||||
"""
|
||||
if (name_code is None) == (name_override is None):
|
||||
raise ValueError(
|
||||
|
|
@ -112,17 +112,19 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
|
|||
else:
|
||||
super().__init__(name=name_override, **kwargs)
|
||||
|
||||
if not isinstance(items,
|
||||
(list, tuple, pipeline_channel.PipelineChannel)):
|
||||
if not isinstance(
|
||||
items,
|
||||
(list, tuple, pipeline_channel.PipelineParameterChannel)):
|
||||
raise TypeError(
|
||||
f'Expected list, tuple, or PipelineChannel, got {items}.')
|
||||
f'Expected list, tuple, or PipelineParameterChannel, got {items}.'
|
||||
)
|
||||
|
||||
if isinstance(items, tuple):
|
||||
items = list(items)
|
||||
|
||||
self.items_or_pipeline_channel = items
|
||||
self.is_with_items_loop_argument = not isinstance(
|
||||
items, pipeline_channel.PipelineChannel)
|
||||
items, pipeline_channel.PipelineParameterChannel)
|
||||
self._referenced_subvars: Dict[str, LoopArgumentVariable] = {}
|
||||
|
||||
if isinstance(items, list) and isinstance(items[0], dict):
|
||||
|
|
@ -154,9 +156,10 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
|
|||
@classmethod
|
||||
def from_pipeline_channel(
|
||||
cls,
|
||||
channel: pipeline_channel.PipelineChannel,
|
||||
channel: pipeline_channel.PipelineParameterChannel,
|
||||
) -> 'LoopArgument':
|
||||
"""Creates a LoopArgument object from a PipelineChannel object."""
|
||||
"""Creates a LoopArgument object from a PipelineParameterChannel
|
||||
object."""
|
||||
return LoopArgument(
|
||||
items=channel,
|
||||
name_override=channel.name + '-' + cls.LOOP_ITEM_NAME_BASE,
|
||||
|
|
@ -191,7 +194,7 @@ class LoopArgument(pipeline_channel.PipelineParameterChannel):
|
|||
or (cls.LOOP_ITEM_PARAM_NAME_BASE + '-') in name
|
||||
|
||||
|
||||
class LoopArgumentVariable(pipeline_channel.PipelineChannel):
|
||||
class LoopArgumentVariable(pipeline_channel.PipelineParameterChannel):
|
||||
"""Represents a subvariable for a loop argument.
|
||||
|
||||
This is used for cases where we're looping over maps, each of which contains
|
||||
|
|
@ -246,7 +249,7 @@ class LoopArgumentVariable(pipeline_channel.PipelineChannel):
|
|||
|
||||
@property
|
||||
def items_or_pipeline_channel(
|
||||
self) -> Union[ItemList, pipeline_channel.PipelineChannel]:
|
||||
self) -> Union[ItemList, pipeline_channel.PipelineParameterChannel]:
|
||||
"""Returns the loop argument items."""
|
||||
return self.loop_argument.items_or_pipeline_chanenl
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@
|
|||
import unittest
|
||||
|
||||
from kfp import dsl
|
||||
from kfp.dsl import Dataset
|
||||
from kfp.dsl import importer_node
|
||||
from kfp.dsl.types.artifact_types import Dataset
|
||||
|
||||
|
||||
class TestImporterSupportsDynamicMetadata(unittest.TestCase):
|
||||
|
|
@ -184,3 +184,37 @@ class TestImporterSupportsDynamicMetadata(unittest.TestCase):
|
|||
"prefix2-{{$.inputs.parameters[\'metadata-2\']}}")
|
||||
self.assertEqual(metadata.struct_value.fields['key'].string_value,
|
||||
'value')
|
||||
|
||||
def test_uri_from_loop(self):
|
||||
|
||||
@dsl.component
|
||||
def make_args() -> list:
|
||||
return [{'uri': 'gs://foo', 'key': 'foo'}]
|
||||
|
||||
@dsl.pipeline
|
||||
def my_pipeline():
|
||||
with dsl.ParallelFor(make_args().output) as data:
|
||||
dsl.importer(
|
||||
artifact_uri=data.uri,
|
||||
artifact_class=Dataset,
|
||||
metadata={'metadata_key': data.key})
|
||||
|
||||
self.assertEqual(
|
||||
my_pipeline.pipeline_spec.deployment_spec['executors']
|
||||
['exec-importer']['importer']['artifactUri']['runtimeParameter'],
|
||||
'uri')
|
||||
self.assertEqual(
|
||||
my_pipeline.pipeline_spec.deployment_spec['executors']
|
||||
['exec-importer']['importer']['metadata']['metadata_key'],
|
||||
"{{$.inputs.parameters[\'metadata\']}}")
|
||||
self.assertEqual(
|
||||
my_pipeline.pipeline_spec.components['comp-for-loop-1'].dag
|
||||
.tasks['importer'].inputs.parameters['metadata']
|
||||
.component_input_parameter,
|
||||
'pipelinechannel--make-args-Output-loop-item')
|
||||
self.assertEqual(
|
||||
my_pipeline.pipeline_spec.components['comp-for-loop-1'].dag
|
||||
.tasks['importer'].inputs.parameters['metadata']
|
||||
.parameter_expression_selector,
|
||||
'parseJson(string_value)["key"]',
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue