[SDK] Add first party component label (#3861)
* add OOB component dict and utility function * add test * add a transformer, which appends the component name label * add transformer function, compiler and test * move telemetry test * fix none uri * applies comments * revert dependency on frozendict * fixes some tests * resolve comments
This commit is contained in:
parent
da4acbbd73
commit
1e2b9d4e7e
|
@ -12,7 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Callable, Dict, Optional, Text
|
from typing import Callable, Dict, Optional, Text
|
||||||
|
|
||||||
from ..dsl._container_op import BaseOp, ContainerOp
|
from ..dsl._container_op import BaseOp, ContainerOp
|
||||||
|
|
||||||
# Pod label indicating the SDK type from which the pipeline is
|
# Pod label indicating the SDK type from which the pipeline is
|
||||||
|
@ -20,6 +22,16 @@ from ..dsl._container_op import BaseOp, ContainerOp
|
||||||
_SDK_ENV_LABEL = 'pipelines.kubeflow.org/pipeline-sdk-type'
|
_SDK_ENV_LABEL = 'pipelines.kubeflow.org/pipeline-sdk-type'
|
||||||
_SDK_ENV_DEFAULT = 'kfp'
|
_SDK_ENV_DEFAULT = 'kfp'
|
||||||
|
|
||||||
|
# Common URL prefix of KFP OOB component paths.
|
||||||
|
_OOB_COMPONENT_PATH_PREFIX = 'https://raw.githubusercontent.com/kubeflow/'\
|
||||||
|
'pipelines'
|
||||||
|
|
||||||
|
# Key for component origin path pod label.
|
||||||
|
COMPONENT_PATH_LABEL_KEY = 'pipelines.kubeflow.org/component_origin_path'
|
||||||
|
|
||||||
|
# Key for component spec digest pod label.
|
||||||
|
COMPONENT_DIGEST_LABEL_KEY = 'pipelines.kubeflow.org/component_digest'
|
||||||
|
|
||||||
|
|
||||||
def get_default_telemetry_labels() -> Dict[Text, Text]:
|
def get_default_telemetry_labels() -> Dict[Text, Text]:
|
||||||
"""Returns the default pod labels for telemetry purpose."""
|
"""Returns the default pod labels for telemetry purpose."""
|
||||||
|
@ -68,3 +80,40 @@ def add_pod_labels(labels: Optional[Dict[Text, Text]] = None) -> Callable:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
return _add_pod_labels
|
return _add_pod_labels
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_suffix(string: Text, suffix: Text) -> Text:
|
||||||
|
"""Removes the suffix from a string."""
|
||||||
|
if suffix and string.endswith(suffix):
|
||||||
|
return string[:-len(suffix)]
|
||||||
|
else:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def add_name_for_oob_components() -> Callable:
    """Returns an op transformer that labels pods of KFP OOB components.

    The returned callable takes a task, inspects its `_component_ref`
    attribute (if any) and returns the same task. When the reference url
    starts with `_OOB_COMPONENT_PATH_PREFIX`, the component path is attached
    under `COMPONENT_PATH_LABEL_KEY`. When the reference carries a digest, it
    is attached under `COMPONENT_DIGEST_LABEL_KEY`. Tasks whose url does not
    match the prefix are returned untouched.
    """
    # Kubernetes rejects label values longer than 63 characters.
    max_label_length = 63

    def _add_name_for_oob_components(task):
        # Detect the component origin uri in component_ref if exists, and
        # attach the OOB component name as a pod label.
        component_ref = getattr(task, '_component_ref', None)
        if not component_ref:
            return task

        if component_ref.url:
            origin_path = _remove_suffix(
                component_ref.url, 'component.yaml').rstrip('/')
            # Only include KFP OOB components.
            if not origin_path.startswith(_OOB_COMPONENT_PATH_PREFIX):
                return task
            # Keep only what follows the first seven '/' separators, e.g.
            # 'https://raw.githubusercontent.com/kubeflow/pipelines/<ref>/
            # components/google-cloud/storage/download' becomes
            # 'google-cloud/storage/download'.
            origin_path = origin_path.split('/', 7)[-1]
            # Clean the label to comply with the k8s label convention.
            origin_path = re.sub('[^-a-z0-9A-Z_.]', '.', origin_path)
            origin_path_label = origin_path[-max_label_length:].strip('-_.')
            task.add_pod_label(COMPONENT_PATH_LABEL_KEY, origin_path_label)

        if component_ref.digest:
            # The digest can be longer than a label value may be (presumably
            # a 64-character hex string -- TODO confirm), so only its leading
            # characters are preserved.
            task.add_pod_label(
                COMPONENT_DIGEST_LABEL_KEY,
                component_ref.digest[:max_label_length])

        return task

    return _add_name_for_oob_components
|
|
@ -27,7 +27,7 @@ from kfp.dsl import _for_loop
|
||||||
from .. import dsl
|
from .. import dsl
|
||||||
from ._k8s_helper import convert_k8s_obj_to_json, sanitize_k8s_name
|
from ._k8s_helper import convert_k8s_obj_to_json, sanitize_k8s_name
|
||||||
from ._op_to_template import _op_to_template
|
from ._op_to_template import _op_to_template
|
||||||
from ._default_transformers import add_pod_env, add_pod_labels, get_default_telemetry_labels
|
from ._default_transformers import add_pod_env, add_pod_labels, add_name_for_oob_components, get_default_telemetry_labels
|
||||||
|
|
||||||
from ..components.structures import InputSpec
|
from ..components.structures import InputSpec
|
||||||
from ..components._yaml_utils import dump_yaml
|
from ..components._yaml_utils import dump_yaml
|
||||||
|
@ -836,6 +836,7 @@ class Compiler(object):
|
||||||
if allow_telemetry:
|
if allow_telemetry:
|
||||||
pod_labels = get_default_telemetry_labels()
|
pod_labels = get_default_telemetry_labels()
|
||||||
op_transformers.append(add_pod_labels(pod_labels))
|
op_transformers.append(add_pod_labels(pod_labels))
|
||||||
|
op_transformers.append(add_name_for_oob_components())
|
||||||
|
|
||||||
op_transformers.extend(pipeline_conf.op_transformers)
|
op_transformers.extend(pipeline_conf.op_transformers)
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@ import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from kfp import components
|
||||||
|
from kfp.compiler._default_transformers import COMPONENT_DIGEST_LABEL_KEY, COMPONENT_PATH_LABEL_KEY
|
||||||
from kfp.dsl._component import component
|
from kfp.dsl._component import component
|
||||||
from kfp.dsl import ContainerOp, pipeline
|
from kfp.dsl import ContainerOp, pipeline
|
||||||
from kfp.dsl.types import Integer, InconsistentTypeException
|
from kfp.dsl.types import Integer, InconsistentTypeException
|
||||||
|
@ -40,6 +42,11 @@ def some_op():
|
||||||
command=['sleep 1'],
|
command=['sleep 1'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_TEST_GCS_DOWNLOAD_COMPONENT_URL = 'https://raw.githubusercontent.com/kubeflow/'\
|
||||||
|
'pipelines/2dac60c400ad8767b452649d08f328df'\
|
||||||
|
'af230f96/components/google-cloud/storage/'\
|
||||||
|
'download/component.yaml'
|
||||||
|
|
||||||
|
|
||||||
class TestCompiler(unittest.TestCase):
|
class TestCompiler(unittest.TestCase):
|
||||||
# Define the places of samples covered by unit tests.
|
# Define the places of samples covered by unit tests.
|
||||||
|
@ -711,6 +718,27 @@ implementation:
|
||||||
container = template.get('container', None)
|
container = template.get('container', None)
|
||||||
if container:
|
if container:
|
||||||
self.assertEqual(template['retryStrategy']['limit'], 5)
|
self.assertEqual(template['retryStrategy']['limit'], 5)
|
||||||
|
|
||||||
|
def test_oob_component_label(self):
    """Checks that a compiled OOB component task carries telemetry labels.

    Builds a one-task pipeline from the component loaded from
    `_TEST_GCS_DOWNLOAD_COMPONENT_URL`, compiles it, and verifies that every
    container template has the expected component path label and a
    non-None component digest label.
    """
    download_op = components.load_component_from_url(
        _TEST_GCS_DOWNLOAD_COMPONENT_URL)

    @dsl.pipeline(name='some_pipeline')
    def some_pipeline():
        download_op('gs://some_bucket/some_dir/some_file')

    workflow = compiler.Compiler()._compile(some_pipeline)

    # Only templates with a container section are checked for labels.
    container_templates = [
        template for template in workflow['spec']['templates']
        if template.get('container', None)
    ]
    for template in container_templates:
        labels = template['metadata']['labels']
        self.assertEqual(
            labels[COMPONENT_PATH_LABEL_KEY],
            'google-cloud.storage.download')
        self.assertIsNotNone(labels.get(COMPONENT_DIGEST_LABEL_KEY))

    self.assertTrue(
        bool(container_templates), 'download task not found in workflow.')
|
||||||
|
|
||||||
def test_image_pull_policy(self):
|
def test_image_pull_policy(self):
|
||||||
def some_op():
|
def some_op():
|
||||||
|
|
Loading…
Reference in New Issue