fix: support setting task dependencies via kfp.kubernetes.mount_pvc (#8999)
* enable accessing .task on pipeline channel * set task dependencies in mount_pvc * update tests
This commit is contained in:
parent
328243b6ea
commit
2bbfd5e89f
|
@ -64,7 +64,7 @@ def CreatePVC(
|
|||
|
||||
def mount_pvc(
|
||||
task: PipelineTask,
|
||||
pvc_name: str,
|
||||
pvc_name: Union[str, 'PipelineChannel'],
|
||||
mount_path: str,
|
||||
) -> PipelineTask:
|
||||
"""Mount a PersistentVolumeClaim to the task's container.
|
||||
|
@ -81,7 +81,9 @@ def mount_pvc(
|
|||
msg = common.get_existing_kubernetes_config_as_message(task)
|
||||
|
||||
pvc_mount = pb.PvcMount(mount_path=mount_path)
|
||||
_assign_pvc_name_to_msg(pvc_mount, pvc_name)
|
||||
pvc_name_from_task = _assign_pvc_name_to_msg(pvc_mount, pvc_name)
|
||||
if pvc_name_from_task:
|
||||
task.after(pvc_name.task)
|
||||
|
||||
msg.pvc_mount.append(pvc_mount)
|
||||
task.platform_config['kubernetes'] = json_format.MessageToDict(msg)
|
||||
|
@ -99,16 +101,22 @@ def DeletePVC(pvc_name: str):
|
|||
return dsl.ContainerSpec(image='argostub/deletepvc')
|
||||
|
||||
|
||||
def _assign_pvc_name_to_msg(msg: message.Message,
|
||||
pvc_name: Union[str, 'PipelineChannel']) -> None:
|
||||
def _assign_pvc_name_to_msg(
|
||||
msg: message.Message,
|
||||
pvc_name: Union[str, 'PipelineChannel'],
|
||||
) -> bool:
|
||||
"""Assigns pvc_name to the msg's pvc_reference oneof. Returns True if pvc_name is an upstream task output. Else, returns False."""
|
||||
if isinstance(pvc_name, str):
|
||||
msg.constant = pvc_name
|
||||
return False
|
||||
elif hasattr(pvc_name, 'task_name'):
|
||||
if pvc_name.task_name is None:
|
||||
msg.component_input_parameter = pvc_name.name
|
||||
return False
|
||||
else:
|
||||
msg.task_output_parameter.producer_task = pvc_name.task_name
|
||||
msg.task_output_parameter.output_parameter_key = pvc_name.name
|
||||
return True
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Argument for {"pvc_name"!r} must be an instance of str or PipelineChannel. Got unknown input type: {type(pvc_name)!r}. '
|
||||
|
|
|
@ -109,6 +109,8 @@ root:
|
|||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-comp
|
||||
dependentTasks:
|
||||
- createpvc
|
||||
taskInfo:
|
||||
name: comp
|
||||
comp-2:
|
||||
|
@ -118,6 +120,7 @@ root:
|
|||
name: comp-comp-2
|
||||
dependentTasks:
|
||||
- comp
|
||||
- createpvc
|
||||
taskInfo:
|
||||
name: comp-2
|
||||
createpvc:
|
||||
|
|
|
@ -81,6 +81,8 @@ root:
|
|||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-comp
|
||||
dependentTasks:
|
||||
- createpvc
|
||||
taskInfo:
|
||||
name: comp
|
||||
createpvc:
|
||||
|
|
|
@ -113,6 +113,8 @@ root:
|
|||
enableCache: true
|
||||
componentRef:
|
||||
name: comp-comp
|
||||
dependentTasks:
|
||||
- createpvc
|
||||
taskInfo:
|
||||
name: comp
|
||||
createpvc:
|
||||
|
|
|
@ -97,6 +97,14 @@ class PipelineChannel(abc.ABC):
|
|||
# so that serialization and unserialization remain consistent
|
||||
# (i.e. None => '' => None)
|
||||
self.task_name = task_name or None
|
||||
from kfp.components import pipeline_context
|
||||
|
||||
default_pipeline = pipeline_context.Pipeline.get_default_pipeline()
|
||||
if self.task_name is not None and default_pipeline is not None and default_pipeline.tasks:
|
||||
self.task = pipeline_context.Pipeline.get_default_pipeline().tasks[
|
||||
self.task_name]
|
||||
else:
|
||||
self.task = None
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
from kfp import dsl
|
||||
from kfp.components import pipeline_channel
|
||||
|
||||
|
||||
|
@ -155,5 +156,19 @@ class PipelineChannelTest(parameterized.TestCase):
|
|||
self.assertListEqual([p1, p2, p3], params)
|
||||
|
||||
|
||||
class TestCanAccessTask(unittest.TestCase):
|
||||
|
||||
def test(self):
|
||||
|
||||
@dsl.component
|
||||
def comp() -> str:
|
||||
return 'text'
|
||||
|
||||
@dsl.pipeline
|
||||
def my_pipeline():
|
||||
op1 = comp()
|
||||
self.assertEqual(op1.output.task, op1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue