This commit is contained in:
parent
ad75187f94
commit
a30e093f67
|
|
@ -114,8 +114,10 @@ def argo_artifact_to_uri(artifact: dict) -> str:
|
||||||
|
|
||||||
|
|
||||||
def is_tfx_pod(pod) -> bool:
|
def is_tfx_pod(pod) -> bool:
|
||||||
# Later versions of TFX pods do not match the pattern in command line, but now they have this label.
|
# The label defaults to 'tfx', but is overridable.
|
||||||
if pod.metadata.labels.get(KFP_SDK_TYPE_LABEL_KEY) == TFX_SDK_TYPE_VALUE:
|
# Official tfx templates override the value to 'tfx-template', so
|
||||||
|
# we loosely match the word 'tfx'.
|
||||||
|
if TFX_SDK_TYPE_VALUE in pod.metadata.labels.get(KFP_SDK_TYPE_LABEL_KEY, ''):
|
||||||
return True
|
return True
|
||||||
main_containers = [container for container in pod.spec.containers if container.name == 'main']
|
main_containers = [container for container in pod.spec.containers if container.name == 'main']
|
||||||
if len(main_containers) != 1:
|
if len(main_containers) != 1:
|
||||||
|
|
|
||||||
|
|
@ -282,7 +282,10 @@ func isKFPCacheEnabled(pod *corev1.Pod) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isTFXPod(pod *corev1.Pod) bool {
|
func isTFXPod(pod *corev1.Pod) bool {
|
||||||
if pod.Labels[SdkTypeLabel] == TfxSdkTypeLabel {
|
// The label defaults to 'tfx', but is overridable.
|
||||||
|
// Official tfx templates override the value to 'tfx-template', so
|
||||||
|
// we loosely match the word 'tfx'.
|
||||||
|
if strings.Contains(pod.Labels[SdkTypeLabel], TfxSdkTypeLabel) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
containers := pod.Spec.Containers
|
containers := pod.Spec.Containers
|
||||||
|
|
|
||||||
|
|
@ -139,6 +139,12 @@ func TestMutatePodIfCachedWithTFXPod2(t *testing.T) {
|
||||||
patchOperation, err := MutatePodIfCached(GetFakeRequestFromPod(&tfxPod), fakeClientManager)
|
patchOperation, err := MutatePodIfCached(GetFakeRequestFromPod(&tfxPod), fakeClientManager)
|
||||||
assert.Nil(t, patchOperation)
|
assert.Nil(t, patchOperation)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
// test variation 2
|
||||||
|
tfxPod = *fakePod.DeepCopy()
|
||||||
|
tfxPod.Labels["pipelines.kubeflow.org/pipeline-sdk-type"] = "tfx-template"
|
||||||
|
patchOperation, err = MutatePodIfCached(GetFakeRequestFromPod(&tfxPod), fakeClientManager)
|
||||||
|
assert.Nil(t, patchOperation)
|
||||||
|
assert.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMutatePodIfCachedWithKfpV2Pod(t *testing.T) {
|
func TestMutatePodIfCachedWithKfpV2Pod(t *testing.T) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue