fix: support setting task dependencies via kfp.kubernetes.mount_pvc (#8999)

* enable accessing .task on pipeline channel

* set task dependencies in mount_pvc

* update tests
This commit is contained in:
Connor McCarthy 2023-03-16 21:10:54 -07:00 committed by GitHub
parent 328243b6ea
commit 2bbfd5e89f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 42 additions and 4 deletions

View File

@ -64,7 +64,7 @@ def CreatePVC(
def mount_pvc(
task: PipelineTask,
pvc_name: str,
pvc_name: Union[str, 'PipelineChannel'],
mount_path: str,
) -> PipelineTask:
"""Mount a PersistentVolumeClaim to the task's container.
@ -81,7 +81,9 @@ def mount_pvc(
msg = common.get_existing_kubernetes_config_as_message(task)
pvc_mount = pb.PvcMount(mount_path=mount_path)
_assign_pvc_name_to_msg(pvc_mount, pvc_name)
pvc_name_from_task = _assign_pvc_name_to_msg(pvc_mount, pvc_name)
if pvc_name_from_task:
task.after(pvc_name.task)
msg.pvc_mount.append(pvc_mount)
task.platform_config['kubernetes'] = json_format.MessageToDict(msg)
@ -99,16 +101,22 @@ def DeletePVC(pvc_name: str):
return dsl.ContainerSpec(image='argostub/deletepvc')
def _assign_pvc_name_to_msg(msg: message.Message,
pvc_name: Union[str, 'PipelineChannel']) -> None:
def _assign_pvc_name_to_msg(
msg: message.Message,
pvc_name: Union[str, 'PipelineChannel'],
) -> bool:
"""Assigns pvc_name to the msg's pvc_reference oneof. Returns True if pvc_name is an upstream task output. Else, returns False."""
if isinstance(pvc_name, str):
msg.constant = pvc_name
return False
elif hasattr(pvc_name, 'task_name'):
if pvc_name.task_name is None:
msg.component_input_parameter = pvc_name.name
return False
else:
msg.task_output_parameter.producer_task = pvc_name.task_name
msg.task_output_parameter.output_parameter_key = pvc_name.name
return True
else:
raise ValueError(
f'Argument for {"pvc_name"!r} must be an instance of str or PipelineChannel. Got unknown input type: {type(pvc_name)!r}. '

View File

@ -109,6 +109,8 @@ root:
enableCache: true
componentRef:
name: comp-comp
dependentTasks:
- createpvc
taskInfo:
name: comp
comp-2:
@ -118,6 +120,7 @@ root:
name: comp-comp-2
dependentTasks:
- comp
- createpvc
taskInfo:
name: comp-2
createpvc:

View File

@ -81,6 +81,8 @@ root:
enableCache: true
componentRef:
name: comp-comp
dependentTasks:
- createpvc
taskInfo:
name: comp
createpvc:

View File

@ -113,6 +113,8 @@ root:
enableCache: true
componentRef:
name: comp-comp
dependentTasks:
- createpvc
taskInfo:
name: comp
createpvc:

View File

@ -97,6 +97,14 @@ class PipelineChannel(abc.ABC):
# so that serialization and unserialization remain consistent
# (i.e. None => '' => None)
self.task_name = task_name or None
from kfp.components import pipeline_context
default_pipeline = pipeline_context.Pipeline.get_default_pipeline()
if self.task_name is not None and default_pipeline is not None and default_pipeline.tasks:
self.task = pipeline_context.Pipeline.get_default_pipeline().tasks[
self.task_name]
else:
self.task = None
@property
def full_name(self) -> str:

View File

@ -16,6 +16,7 @@
import unittest
from absl.testing import parameterized
from kfp import dsl
from kfp.components import pipeline_channel
@ -155,5 +156,19 @@ class PipelineChannelTest(parameterized.TestCase):
self.assertListEqual([p1, p2, p3], params)
class TestCanAccessTask(unittest.TestCase):
    """Checks that a task's output channel exposes the task that produced it."""

    def test(self):
        # Minimal single-output component used as the producer task.
        @dsl.component
        def comp() -> str:
            return 'text'

        # NOTE(review): the assertion sits inside the pipeline body, so it
        # only runs if @dsl.pipeline executes the function while decorating
        # it -- confirm against the kfp version under test.
        @dsl.pipeline
        def my_pipeline():
            producer_task = comp()
            output_channel = producer_task.output
            self.assertEqual(output_channel.task, producer_task)
if __name__ == '__main__':
unittest.main()