fix(tfx): fix missing mlmd data when sdk label is overridden. Fixes #5303 (#6035)

This commit is contained in:
Yuan (Bob) Gong 2021-07-14 13:24:58 +08:00 committed by GitHub
parent ad75187f94
commit a30e093f67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 3 deletions

View File

@ -114,8 +114,10 @@ def argo_artifact_to_uri(artifact: dict) -> str:
def is_tfx_pod(pod) -> bool:
# Later versions of TFX pods do not match the pattern in command line, but now they have this label.
if pod.metadata.labels.get(KFP_SDK_TYPE_LABEL_KEY) == TFX_SDK_TYPE_VALUE:
# The label defaults to 'tfx', but is overridable.
# Official tfx templates override the value to 'tfx-template', so
# we loosely match the word 'tfx'.
if TFX_SDK_TYPE_VALUE in pod.metadata.labels.get(KFP_SDK_TYPE_LABEL_KEY, ''):
return True
main_containers = [container for container in pod.spec.containers if container.name == 'main']
if len(main_containers) != 1:

View File

@ -282,7 +282,10 @@ func isKFPCacheEnabled(pod *corev1.Pod) bool {
}
func isTFXPod(pod *corev1.Pod) bool {
if pod.Labels[SdkTypeLabel] == TfxSdkTypeLabel {
// The label defaults to 'tfx', but is overridable.
// Official tfx templates override the value to 'tfx-template', so
// we loosely match the word 'tfx'.
if strings.Contains(pod.Labels[SdkTypeLabel], TfxSdkTypeLabel) {
return true
}
containers := pod.Spec.Containers

View File

@ -139,6 +139,12 @@ func TestMutatePodIfCachedWithTFXPod2(t *testing.T) {
patchOperation, err := MutatePodIfCached(GetFakeRequestFromPod(&tfxPod), fakeClientManager)
assert.Nil(t, patchOperation)
assert.Nil(t, err)
// test variation 2
tfxPod = *fakePod.DeepCopy()
tfxPod.Labels["pipelines.kubeflow.org/pipeline-sdk-type"] = "tfx-template"
patchOperation, err = MutatePodIfCached(GetFakeRequestFromPod(&tfxPod), fakeClientManager)
assert.Nil(t, patchOperation)
assert.Nil(t, err)
}
func TestMutatePodIfCachedWithKfpV2Pod(t *testing.T) {