diff --git a/kubernetes_platform/python/kfp/kubernetes/volume.py b/kubernetes_platform/python/kfp/kubernetes/volume.py index 49e46532d8..3af850600a 100644 --- a/kubernetes_platform/python/kfp/kubernetes/volume.py +++ b/kubernetes_platform/python/kfp/kubernetes/volume.py @@ -64,7 +64,7 @@ def CreatePVC( def mount_pvc( task: PipelineTask, - pvc_name: str, + pvc_name: Union[str, 'PipelineChannel'], mount_path: str, ) -> PipelineTask: """Mount a PersistentVolumeClaim to the task's container. @@ -81,7 +81,9 @@ def mount_pvc( msg = common.get_existing_kubernetes_config_as_message(task) pvc_mount = pb.PvcMount(mount_path=mount_path) - _assign_pvc_name_to_msg(pvc_mount, pvc_name) + pvc_name_from_task = _assign_pvc_name_to_msg(pvc_mount, pvc_name) + if pvc_name_from_task: + task.after(pvc_name.task) msg.pvc_mount.append(pvc_mount) task.platform_config['kubernetes'] = json_format.MessageToDict(msg) @@ -99,16 +101,22 @@ def DeletePVC(pvc_name: str): return dsl.ContainerSpec(image='argostub/deletepvc') -def _assign_pvc_name_to_msg(msg: message.Message, - pvc_name: Union[str, 'PipelineChannel']) -> None: +def _assign_pvc_name_to_msg( + msg: message.Message, + pvc_name: Union[str, 'PipelineChannel'], +) -> bool: + """Assigns pvc_name to the msg's pvc_reference oneof. Returns True if pvc_name is an upstream task output. Else, returns False.""" if isinstance(pvc_name, str): msg.constant = pvc_name + return False elif hasattr(pvc_name, 'task_name'): if pvc_name.task_name is None: msg.component_input_parameter = pvc_name.name + return False else: msg.task_output_parameter.producer_task = pvc_name.task_name msg.task_output_parameter.output_parameter_key = pvc_name.name + return True else: raise ValueError( f'Argument for {"pvc_name"!r} must be an instance of str or PipelineChannel. Got unknown input type: {type(pvc_name)!r}. ' diff --git a/kubernetes_platform/python/test/snapshot/data/create_mount_delete_dynamic_pvc.yaml b/kubernetes_platform/python/test/snapshot/data/create_mount_delete_dynamic_pvc.yaml index f0065b4f40..1e6142a253 100644 --- a/kubernetes_platform/python/test/snapshot/data/create_mount_delete_dynamic_pvc.yaml +++ b/kubernetes_platform/python/test/snapshot/data/create_mount_delete_dynamic_pvc.yaml @@ -109,6 +109,8 @@ root: enableCache: true componentRef: name: comp-comp + dependentTasks: + - createpvc taskInfo: name: comp comp-2: @@ -118,6 +120,7 @@ root: name: comp-comp-2 dependentTasks: - comp + - createpvc taskInfo: name: comp-2 createpvc: diff --git a/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc.yaml b/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc.yaml index eae5ed2d3f..d06fbdc422 100644 --- a/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc.yaml +++ b/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc.yaml @@ -81,6 +81,8 @@ root: enableCache: true componentRef: name: comp-comp + dependentTasks: + - createpvc taskInfo: name: comp createpvc: diff --git a/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc_from_task_output.yaml b/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc_from_task_output.yaml index 4eae38e4df..0876ab3427 100644 --- a/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc_from_task_output.yaml +++ b/kubernetes_platform/python/test/snapshot/data/create_mount_delete_existing_pvc_from_task_output.yaml @@ -113,6 +113,8 @@ root: enableCache: true componentRef: name: comp-comp + dependentTasks: + - createpvc taskInfo: name: comp createpvc: diff --git a/sdk/python/kfp/components/pipeline_channel.py b/sdk/python/kfp/components/pipeline_channel.py index 95344d412c..26ad27eae3 100644 --- a/sdk/python/kfp/components/pipeline_channel.py +++ b/sdk/python/kfp/components/pipeline_channel.py @@ -97,6 +97,14 @@ class PipelineChannel(abc.ABC): # so that serialization and unserialization remain consistent # (i.e. None => '' => None) self.task_name = task_name or None + from kfp.components import pipeline_context + + default_pipeline = pipeline_context.Pipeline.get_default_pipeline() + if self.task_name is not None and default_pipeline is not None and default_pipeline.tasks: + self.task = pipeline_context.Pipeline.get_default_pipeline().tasks[ + self.task_name] + else: + self.task = None @property def full_name(self) -> str: diff --git a/sdk/python/kfp/components/pipeline_channel_test.py b/sdk/python/kfp/components/pipeline_channel_test.py index 334349f97d..060fe4ad23 100644 --- a/sdk/python/kfp/components/pipeline_channel_test.py +++ b/sdk/python/kfp/components/pipeline_channel_test.py @@ -16,6 +16,7 @@ import unittest from absl.testing import parameterized +from kfp import dsl from kfp.components import pipeline_channel @@ -155,5 +156,19 @@ class PipelineChannelTest(parameterized.TestCase): self.assertListEqual([p1, p2, p3], params) +class TestCanAccessTask(unittest.TestCase): + + def test(self): + + @dsl.component + def comp() -> str: + return 'text' + + @dsl.pipeline + def my_pipeline(): + op1 = comp() + self.assertEqual(op1.output.task, op1) + + if __name__ == '__main__': unittest.main()