From 0bd54b436dfbb91f7ab8853f5e391479e9f0c4ca Mon Sep 17 00:00:00 2001 From: capri-xiyue <52932582+capri-xiyue@users.noreply.github.com> Date: Thu, 29 Apr 2021 21:47:30 -0700 Subject: [PATCH] chore(v2): add name as custom property of MLMD artifacts (#5567) * add name as custom property of MLMD artifacts * fixed test * fixed test * added custom_properties * fixed e2e tests * fixed tests * fixed test --- samples/test/two_step_test.py | 35 +++++------- samples/test/util.py | 89 ++++++++++++++++++------------- v2/component/runtime_info.go | 1 + v2/component/runtime_info_test.go | 8 ++- 4 files changed, 73 insertions(+), 60 deletions(-) diff --git a/samples/test/two_step_test.py b/samples/test/two_step_test.py index 348856dce9..92fbabdc13 100644 --- a/samples/test/two_step_test.py +++ b/samples/test/two_step_test.py @@ -15,11 +15,7 @@ # %% -import sys -import logging import unittest -from dataclasses import dataclass, asdict -from typing import Tuple from pprint import pprint import kfp @@ -28,13 +24,10 @@ import kfp_server_api from .two_step import two_step_pipeline from .util import run_pipeline_func, TestCase, KfpMlmdClient -from ml_metadata import metadata_store -from ml_metadata.proto import metadata_store_pb2 - def verify( - run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, - **kwargs + run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, + **kwargs ): t = unittest.TestCase() t.maxDiff = None # we always want to see full diff @@ -68,10 +61,10 @@ def verify( } }, 'outputs': { - 'artifacts': [{ - 'name': '', - 'type': 'system.Dataset' - }], + 'artifacts': [{'custom_properties': {'name': 'output_dataset_one'}, + 'name': 'output_dataset_one', + 'type': 'system.Dataset' + }], 'parameters': { 'output_parameter_one': 1234 } @@ -83,19 +76,19 @@ def verify( train.get_dict(), { 'name': 'train-op', 'inputs': { - 'artifacts': [{ - 'name': '', - 'type': 'system.Dataset', - }], + 'artifacts': [{'custom_properties': {'name': 'output_dataset_one'}, + 'name': 'output_dataset_one', + 'type': 'system.Dataset', + }], 'parameters': { 'num_steps': 1234 } }, 'outputs': { - 'artifacts': [{ - 'name': '', - 'type': 'system.Model', - }], + 'artifacts': [{'custom_properties': {'name': 'model'}, + 'name': 'model', + 'type': 'system.Model', + }], 'parameters': {} }, 'type': 'kfp.ContainerExecution' diff --git a/samples/test/util.py b/samples/test/util.py index 67d6e9c958..3c5aabefe7 100644 --- a/samples/test/util.py +++ b/samples/test/util.py @@ -1,10 +1,10 @@ -import os -import logging import json +import logging +import os import random +from dataclasses import dataclass, asdict from pprint import pprint from typing import Dict, List, Callable, Optional -from dataclasses import dataclass, asdict import kfp import kfp_server_api @@ -17,9 +17,9 @@ MINUTE = 60 # Add **kwargs, so that when new arguments are added, this doesn't fail for # unknown arguments. def _default_verify_func( - run_id: int, run: kfp_server_api.ApiRun, - mlmd_connection_config: metadata_store_pb2.MetadataStoreClientConfig, - **kwargs + run_id: int, run: kfp_server_api.ApiRun, + mlmd_connection_config: metadata_store_pb2.MetadataStoreClientConfig, + **kwargs ): assert run.status == 'Succeeded' @@ -36,9 +36,9 @@ class TestCase: mode: kfp.dsl.PipelineExecutionMode = kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE arguments: Optional[Dict[str, str]] = None verify_func: Callable[[ - int, kfp_server_api.ApiRun, kfp_server_api. - ApiRunDetail, metadata_store_pb2.MetadataStoreClientConfig - ], None] = _default_verify_func + int, kfp_server_api.ApiRun, kfp_server_api. + ApiRunDetail, metadata_store_pb2.MetadataStoreClientConfig + ], None] = _default_verify_func def run_pipeline_func(test_cases: List[TestCase]): @@ -49,9 +49,9 @@ def run_pipeline_func(test_cases: List[TestCase]): """ def test_wrapper( - run_pipeline: Callable[[Callable, kfp.dsl.PipelineExecutionMode, dict], - kfp_server_api.ApiRunDetail], - mlmd_connection_config: metadata_store_pb2.MetadataStoreClientConfig, + run_pipeline: Callable[[Callable, kfp.dsl.PipelineExecutionMode, dict], + kfp_server_api.ApiRunDetail], + mlmd_connection_config: metadata_store_pb2.MetadataStoreClientConfig, ): for case in test_cases: run_detail = run_pipeline( @@ -83,22 +83,21 @@ def _retry_with_backoff(fn: Callable, retries=5, backoff_in_seconds=1): print(f"Failed after {retires} retries:") raise else: - sleep = (backoff_in_seconds * 2**i + random.uniform(0, 1)) + sleep = (backoff_in_seconds * 2 ** i + random.uniform(0, 1)) print(" Sleep :", str(sleep) + "s") time.sleep(sleep) i += 1 def _run_test(callback): - def main( - output_directory: Optional[str] = None, # example - host: Optional[str] = None, - external_host: Optional[str] = None, - launcher_image: Optional['URI'] = None, - experiment: str = 'v2_sample_test_samples', - metadata_service_host: Optional[str] = None, - metadata_service_port: int = 8080, + output_directory: Optional[str] = None, # example + host: Optional[str] = None, + external_host: Optional[str] = None, + launcher_image: Optional['URI'] = None, + experiment: str = 'v2_sample_test_samples', + metadata_service_host: Optional[str] = None, + metadata_service_port: int = 8080, ): """Test file CLI entrypoint used by Fire. @@ -137,10 +136,10 @@ def _run_test(callback): client = kfp.Client(host=host) def run_pipeline( - pipeline_func: Callable, - mode: kfp.dsl.PipelineExecutionMode = kfp.dsl.PipelineExecutionMode. - V2_COMPATIBLE, - arguments: dict = {}, + pipeline_func: Callable, + mode: kfp.dsl.PipelineExecutionMode = kfp.dsl.PipelineExecutionMode. + V2_COMPATIBLE, + arguments: dict = {}, ) -> kfp_server_api.ApiRunDetail: extra_arguments = {} if mode != kfp.dsl.PipelineExecutionMode.V1_LEGACY: @@ -204,16 +203,32 @@ class KfpArtifact: name: str uri: str type: str + custom_properties: dict @classmethod def new( - cls, mlmd_artifact: metadata_store_pb2.Artifact, - mlmd_artifact_type: metadata_store_pb2.ArtifactType + cls, mlmd_artifact: metadata_store_pb2.Artifact, + mlmd_artifact_type: metadata_store_pb2.ArtifactType ): + custom_properties = {} + for k, v in mlmd_artifact.custom_properties.items(): + raw_value = None + if v.string_value: + raw_value = v.string_value + if v.int_value: + raw_value = v.int_value + custom_properties[k] = raw_value + artifact_name = '' + if mlmd_artifact.name != '': + artifact_name = mlmd_artifact.name + else: + if 'name' in custom_properties.keys(): + artifact_name = custom_properties['name'] return cls( - name=mlmd_artifact.name, + name=artifact_name, type=mlmd_artifact_type.name, uri=mlmd_artifact.uri, + custom_properties=custom_properties ) @@ -248,13 +263,13 @@ class KfpTask: @classmethod def new( - cls, - context: metadata_store_pb2.Context, - execution: metadata_store_pb2.Execution, - execution_types_by_id, # dict[int, metadata_store_pb2.ExecutionType] - events_by_execution_id, # dict[int, List[metadata_store_pb2.Event]] - artifacts_by_id, # dict[int, metadata_store_pb2.Artifact] - artifact_types_by_id, # dict[int, metadata_store_pb2.ArtifactType] + cls, + context: metadata_store_pb2.Context, + execution: metadata_store_pb2.Execution, + execution_types_by_id, # dict[int, metadata_store_pb2.ExecutionType] + events_by_execution_id, # dict[int, List[metadata_store_pb2.Event]] + artifacts_by_id, # dict[int, metadata_store_pb2.Artifact] + artifact_types_by_id, # dict[int, metadata_store_pb2.ArtifactType] ): execution_type = execution_types_by_id[execution.type_id] params = _parse_parameters(execution) @@ -302,8 +317,8 @@ class KfpTask: class KfpMlmdClient: def __init__( - self, - mlmd_connection_config: metadata_store_pb2.MetadataStoreClientConfig + self, + mlmd_connection_config: metadata_store_pb2.MetadataStoreClientConfig ): self.mlmd_store = metadata_store.MetadataStore(mlmd_connection_config) diff --git a/v2/component/runtime_info.go b/v2/component/runtime_info.go index 572a717773..5ea6002374 100644 --- a/v2/component/runtime_info.go +++ b/v2/component/runtime_info.go @@ -325,6 +325,7 @@ func (r *runtimeInfo) generateExecutorInput(genOutputURI generateOutputURI, outp s3Region := os.Getenv("AWS_REGION") rta.Metadata.Fields["s3_region"] = &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: s3Region}} } + rta.Metadata.Fields["name"] = &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: name}} if err := setRuntimeArtifactType(rta, oa.InstanceSchema, oa.SchemaTitle); err != nil { return nil, fmt.Errorf("failed to generate output RuntimeArtifact: %w", err) diff --git a/v2/component/runtime_info_test.go b/v2/component/runtime_info_test.go index f1ec7d8727..2a2a9658b7 100644 --- a/v2/component/runtime_info_test.go +++ b/v2/component/runtime_info_test.go @@ -234,7 +234,9 @@ func TestExecutorInputGeneration(t *testing.T) { Kind: &pipeline_spec.ArtifactTypeSchema_InstanceSchema{InstanceSchema: "title: kfp.Model\ntype: object\nproperties:\n framework:\n type: string\n framework_version:\n type: string\n"}, }, Uri: "gs://my-bucket/some-prefix/pipeline/task/model", - }}}, + Metadata: &structpb.Struct{ + Fields: map[string]*structpb.Value{"name": {Kind: &structpb.Value_StringValue{StringValue: "model"}}}, + }}}}, "metrics": { Artifacts: []*pipeline_spec.RuntimeArtifact{ { @@ -243,7 +245,9 @@ func TestExecutorInputGeneration(t *testing.T) { Kind: &pipeline_spec.ArtifactTypeSchema_SchemaTitle{SchemaTitle: "kfp.Metrics"}, }, Uri: "gs://my-bucket/some-prefix/pipeline/task/metrics", - }}}, + Metadata: &structpb.Struct{ + Fields: map[string]*structpb.Value{"name": {Kind: &structpb.Value_StringValue{StringValue: "metrics"}}}, + }}}}, }, OutputFile: outputMetadataFilepath, },