diff --git a/backend/metadata_writer/src/metadata_writer.py b/backend/metadata_writer/src/metadata_writer.py index 62a7a502f9..4f31bc9dad 100644 --- a/backend/metadata_writer/src/metadata_writer.py +++ b/backend/metadata_writer/src/metadata_writer.py @@ -114,8 +114,10 @@ def argo_artifact_to_uri(artifact: dict) -> str: def is_tfx_pod(pod) -> bool: - # Later versions of TFX pods do not match the pattern in command line, but now they have this label. - if pod.metadata.labels.get(KFP_SDK_TYPE_LABEL_KEY) == TFX_SDK_TYPE_VALUE: + # The label defaults to 'tfx', but is overridable. + # Official tfx templates override the value to 'tfx-template', so + # we loosely match the word 'tfx'. + if TFX_SDK_TYPE_VALUE in pod.metadata.labels.get(KFP_SDK_TYPE_LABEL_KEY, ''): return True main_containers = [container for container in pod.spec.containers if container.name == 'main'] if len(main_containers) != 1: diff --git a/backend/src/cache/server/mutation.go b/backend/src/cache/server/mutation.go index 4c1b13272d..9b7e7d3876 100644 --- a/backend/src/cache/server/mutation.go +++ b/backend/src/cache/server/mutation.go @@ -282,7 +282,10 @@ func isKFPCacheEnabled(pod *corev1.Pod) bool { } func isTFXPod(pod *corev1.Pod) bool { - if pod.Labels[SdkTypeLabel] == TfxSdkTypeLabel { + // The label defaults to 'tfx', but is overridable. + // Official tfx templates override the value to 'tfx-template', so + // we loosely match the word 'tfx'. + if strings.Contains(pod.Labels[SdkTypeLabel], TfxSdkTypeLabel) { return true } containers := pod.Spec.Containers diff --git a/backend/src/cache/server/mutation_test.go b/backend/src/cache/server/mutation_test.go index cfc1cab10e..23c019df15 100644 --- a/backend/src/cache/server/mutation_test.go +++ b/backend/src/cache/server/mutation_test.go @@ -139,6 +139,12 @@ func TestMutatePodIfCachedWithTFXPod2(t *testing.T) { patchOperation, err := MutatePodIfCached(GetFakeRequestFromPod(&tfxPod), fakeClientManager) assert.Nil(t, patchOperation) assert.Nil(t, err) + // test variation 2 + tfxPod = *fakePod.DeepCopy() + tfxPod.Labels["pipelines.kubeflow.org/pipeline-sdk-type"] = "tfx-template" + patchOperation, err = MutatePodIfCached(GetFakeRequestFromPod(&tfxPod), fakeClientManager) + assert.Nil(t, patchOperation) + assert.Nil(t, err) } func TestMutatePodIfCachedWithKfpV2Pod(t *testing.T) {