chore(sdk): Fix a bug where we missed injecting importer node (#4712)
* Fix bug where we missed injecting importer node
* moved files
* address review comments
This commit is contained in: parent 92a932e9d9, commit 935a9b5ba5
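For orientation, here is a minimal sketch, not part of the diff, of the reworked importer_node helpers this commit introduces: building the executor-side ImporterSpec and the injected importer task spec are now two separate calls. The signatures follow the diff below; the example task and its values are assumed.

# Sketch only: usage of the split importer_node API introduced in this commit.
from kfp.v2.dsl import importer_node
from kfp.v2.proto import pipeline_spec_pb2

# A dependent task whose 'input1' artifact has no producer task (assumed example).
task_spec = pipeline_spec_pb2.PipelineTaskSpec()
task_spec.task_info.name = 'task1'
task_spec.inputs.artifacts['input1'].producer_task = ''
task_spec.inputs.artifacts['input1'].output_artifact_key = 'output1'

# Executor-side importer spec: fed either by a pipeline parameter or a constant URI.
importer_spec = importer_node.build_importer_spec(
    input_type_schema='title: kfp.Artifact',
    pipeline_param_name='output1')  # or: constant_value='gs://bucket/uri'

# Task-side spec for the auto-injected importer node.
importer_task = importer_node.build_importer_task_spec(
    dependent_task=task_spec,
    input_name='input1',
    input_type_schema='title: kfp.Artifact')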
@@ -25,7 +25,7 @@ import kfp
from kfp.compiler._k8s_helper import sanitize_k8s_name
from kfp.components import _python_op
from kfp.v2 import dsl
from kfp.v2.compiler import importer_node
from kfp.v2.dsl import importer_node
from kfp.v2.dsl import type_utils
from kfp.v2.proto import pipeline_spec_pb2
@@ -114,11 +114,13 @@ class Compiler(object):
      # Check if need to insert importer node
      for input_name in task.inputs.artifacts:
        if not task.inputs.artifacts[input_name].producer_task:
          artifact_type = type_utils.get_input_artifact_type_schema(
          type_schema = type_utils.get_input_artifact_type_schema(
              input_name, component_spec.inputs)

          importer_task, importer_spec = importer_node.build_importer_spec(
              task, input_name, artifact_type)
          importer_task = importer_node.build_importer_task_spec(
              dependent_task=task,
              input_name=input_name,
              input_type_schema=type_schema)
          importer_tasks.append(importer_task)

          task.inputs.artifacts[
@@ -126,6 +128,8 @@ class Compiler(object):
          task.inputs.artifacts[
              input_name].output_artifact_key = importer_node.OUTPUT_KEY

          # Retrieve the pre-built importer spec
          importer_spec = op.importer_spec[input_name]
          deployment_config.executors[
              importer_task.executor_label].importer.CopyFrom(importer_spec)
@@ -1,81 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import unittest
from kfp.v2.compiler import importer_node
from kfp.v2.proto import pipeline_spec_pb2 as pb
from google.protobuf import json_format


class ImporterNodeTest(unittest.TestCase):

  def test_build_importer_spec(self):

    dependent_task = {
        'taskInfo': {
            'name': 'task1'
        },
        'inputs': {
            'artifacts': {
                'input1': {
                    'producerTask': '',
                    'outputArtifactKey': 'output1'
                }
            }
        },
        'executorLabel': 'task1_input1_importer'
    }
    dependent_task_spec = pb.PipelineTaskSpec()
    json_format.Parse(json.dumps(dependent_task), dependent_task_spec)

    expected_task = {
        'taskInfo': {
            'name': 'task1_input1_importer'
        },
        'outputs': {
            'artifacts': {
                'result': {
                    'artifactType': {
                        'instanceSchema': 'title: Artifact'
                    }
                }
            }
        },
        'executorLabel': 'task1_input1_importer'
    }
    expected_task_spec = pb.PipelineTaskSpec()
    json_format.Parse(json.dumps(expected_task), expected_task_spec)

    expected_importer = {
        'artifactUri': {
            'runtimeParameter': 'output1'
        },
        'typeSchema': {
            'instanceSchema': 'title: Artifact'
        }
    }
    expected_importer_spec = pb.PipelineDeploymentConfig.ImporterSpec()
    json_format.Parse(json.dumps(expected_importer), expected_importer_spec)

    task_spec, importer_spec = importer_node.build_importer_spec(
        dependent_task_spec, 'input1', 'title: Artifact')

    self.maxDiff = None
    self.assertEqual(expected_task_spec, task_spec)
    self.assertEqual(expected_importer_spec, importer_spec)


if __name__ == '__main__':
  unittest.main()
@@ -23,10 +23,6 @@
}
},
"artifacts": {
"input_6": {
"producerTask": "upstream_input_6_importer",
"outputArtifactKey": "result"
},
"input_7": {
"producerTask": "upstream_input_7_importer",
"outputArtifactKey": "result"
@@ -35,16 +31,20 @@
"producerTask": "upstream_input_5_importer",
"outputArtifactKey": "result"
},
"input_3": {
"producerTask": "upstream_input_3_importer",
"outputArtifactKey": "result"
},
"input_4": {
"producerTask": "upstream_input_4_importer",
"outputArtifactKey": "result"
},
"input_8": {
"producerTask": "upstream_input_8_importer",
"input_6": {
"producerTask": "upstream_input_6_importer",
"outputArtifactKey": "result"
},
"input_3": {
"producerTask": "upstream_input_3_importer",
"input_8": {
"producerTask": "upstream_input_8_importer",
"outputArtifactKey": "result"
}
}
@@ -84,33 +84,18 @@
}
},
"artifacts": {
"input_c": {
"producerTask": "upstream",
"outputArtifactKey": "output_3"
},
"input_b": {
"producerTask": "upstream",
"outputArtifactKey": "output_2"
},
"input_c": {
"producerTask": "upstream",
"outputArtifactKey": "output_3"
}
}
},
"executorLabel": "downstream"
},
{
"taskInfo": {
"name": "upstream_input_6_importer"
},
"outputs": {
"artifacts": {
"result": {
"artifactType": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
}
},
"executorLabel": "upstream_input_6_importer"
},
{
"taskInfo": {
"name": "upstream_input_7_importer"
@@ -141,6 +126,21 @@
},
"executorLabel": "upstream_input_5_importer"
},
{
"taskInfo": {
"name": "upstream_input_3_importer"
},
"outputs": {
"artifacts": {
"result": {
"artifactType": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
}
},
"executorLabel": "upstream_input_3_importer"
},
{
"taskInfo": {
"name": "upstream_input_4_importer"
@@ -156,6 +156,21 @@
},
"executorLabel": "upstream_input_4_importer"
},
{
"taskInfo": {
"name": "upstream_input_6_importer"
},
"outputs": {
"artifacts": {
"result": {
"artifactType": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
}
},
"executorLabel": "upstream_input_6_importer"
},
{
"taskInfo": {
"name": "upstream_input_8_importer"
@@ -170,50 +185,15 @@
}
},
"executorLabel": "upstream_input_8_importer"
},
{
"taskInfo": {
"name": "upstream_input_3_importer"
},
"outputs": {
"artifacts": {
"result": {
"artifactType": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
}
},
"executorLabel": "upstream_input_3_importer"
}
],
"deploymentConfig": {
"@type": "type.googleapis.com/ml_pipelines.PipelineDeploymentConfig",
"executors": {
"upstream_input_7_importer": {
"upstream_input_6_importer": {
"importer": {
"artifactUri": {
"runtimeParameter": "input7"
},
"typeSchema": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
},
"upstream_input_5_importer": {
"importer": {
"artifactUri": {
"runtimeParameter": "input5"
},
"typeSchema": {
"instanceSchema": "title: kfp.Metrics\ntype: object\nproperties:\n"
}
}
},
"upstream_input_8_importer": {
"importer": {
"artifactUri": {
"runtimeParameter": "input8"
"runtimeParameter": "input6"
},
"typeSchema": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
@@ -238,16 +218,6 @@
]
}
},
"upstream_input_4_importer": {
"importer": {
"artifactUri": {
"runtimeParameter": "input4"
},
"typeSchema": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
},
"upstream_input_3_importer": {
"importer": {
"artifactUri": {
@@ -268,10 +238,42 @@
]
}
},
"upstream_input_6_importer": {
"upstream_input_4_importer": {
"importer": {
"artifactUri": {
"runtimeParameter": "input6"
"runtimeParameter": "input4"
},
"typeSchema": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
},
"upstream_input_8_importer": {
"importer": {
"artifactUri": {
"runtimeParameter": "input8"
},
"typeSchema": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
}
}
},
"upstream_input_5_importer": {
"importer": {
"artifactUri": {
"constantValue": {
"stringValue": "gs://bucket/metrics"
}
},
"typeSchema": {
"instanceSchema": "title: kfp.Metrics\ntype: object\nproperties:\n"
}
}
},
"upstream_input_7_importer": {
"importer": {
"artifactUri": {
"runtimeParameter": "input7"
},
"typeSchema": {
"instanceSchema": "title: kfp.Artifact\ntype: object\nproperties:\n"
@@ -280,13 +282,13 @@
}
}
},
"sdkVersion": "kfp-1.0.1-dev20201029",
"sdkVersion": "kfp-1.1.0-alpha.1",
"schemaVersion": "v2alpha1",
"runtimeParameters": {
"input6": {
"input8": {
"type": "STRING",
"defaultValue": {
"stringValue": "gs://bucket/dataset"
"stringValue": "gs://path2"
}
},
"input7": {
@@ -295,16 +297,10 @@
"stringValue": "arbitrary value"
}
},
"input5": {
"input6": {
"type": "STRING",
"defaultValue": {
"stringValue": "gs://bucket/metrics"
}
},
"input8": {
"type": "STRING",
"defaultValue": {
"stringValue": "gs://path2"
"stringValue": "gs://bucket/dataset"
}
}
}
@@ -70,7 +70,6 @@ implementation:
def my_pipeline(input1,
                input3,
                input4,
                input5='gs://bucket/metrics',
                input6='gs://bucket/dataset',
                input7='arbitrary value',
                input8='gs://path2'):
@@ -79,7 +78,7 @@ def my_pipeline(input1,
      input_2=3.1415926,
      input_3=input3,
      input_4=input4,
      input_5=input5,
      input_5='gs://bucket/metrics',
      input_6=input6,
      input_7=input7,
      input_8=input8)
@@ -24,6 +24,7 @@ from kfp.components._naming import _sanitize_python_function_name
from kfp.components._naming import generate_unique_name_conversion_table
from kfp.dsl import types
from kfp.v2.dsl import container_op
from kfp.v2.dsl import importer_node
from kfp.v2.dsl import type_utils
from kfp.v2.proto import pipeline_spec_pb2
@@ -50,6 +51,9 @@ def create_container_op_from_component_and_arguments(
  # might need to append suffix to executor_label to ensure its uniqueness?
  pipeline_task_spec.executor_label = component_spec.name

  # Keep track of auto-injected importer spec.
  importer_spec = {}

  # Check types of the reference arguments and serialize PipelineParams
  arguments = arguments.copy()
  for input_name, argument_value in arguments.items():
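The importer_spec dict introduced just above is the bookkeeping that the compiler change earlier in this diff consumes via op.importer_spec[input_name]. A rough illustration of its shape; the attachment point comes from a later hunk, and the contents shown here are assumed:

# Illustration only: per-op bookkeeping of auto-injected importer specs.
# Maps an artifact input name to the ImporterSpec pre-built for it.
importer_spec = {}  # input name -> PipelineDeploymentConfig.ImporterSpec
...                 # filled in by the argument-handling loop that follows
task.importer_spec = importer_spec  # later read by the compiler as op.importer_spec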
@@ -75,16 +79,34 @@ def create_container_op_from_component_and_arguments(
        pipeline_task_spec.inputs.parameters[
            input_name].runtime_value.runtime_parameter = argument_value.name
      else:
        # argument_value.op_name could be none, in which case an importer node
        # will be inserted later. Use output_artifact_key to preserve the name
        # of pipeline parameter which is needed by importer.
        pipeline_task_spec.inputs.artifacts[input_name].producer_task = (
            argument_value.op_name or '')
        pipeline_task_spec.inputs.artifacts[input_name].output_artifact_key = (
            argument_value.name)
        if argument_value.op_name:
          pipeline_task_spec.inputs.artifacts[input_name].producer_task = (
              argument_value.op_name)
          pipeline_task_spec.inputs.artifacts[
              input_name].output_artifact_key = (
                  argument_value.name)
        else:
          # argument_value.op_name could be none, in which case an importer node
          # will be inserted later.
          pipeline_task_spec.inputs.artifacts[input_name].producer_task = ''
          type_schema = type_utils.get_input_artifact_type_schema(
              input_name, component_spec.inputs)
          importer_spec[input_name] = importer_node.build_importer_spec(
              input_type_schema=type_schema,
              pipeline_param_name=argument_value.name)
    elif isinstance(argument_value, str):
      pipeline_task_spec.inputs.parameters[
          input_name].runtime_value.constant_value.string_value = argument_value
      input_type = component_spec._inputs_dict[input_name].type
      if type_utils.is_parameter_type(input_type):
        pipeline_task_spec.inputs.parameters[
            input_name].runtime_value.constant_value.string_value = (
                argument_value)
      else:
        # An importer node with constant value artifact_uri will be inserted.
        pipeline_task_spec.inputs.artifacts[input_name].producer_task = ''
        type_schema = type_utils.get_input_artifact_type_schema(
            input_name, component_spec.inputs)
        importer_spec[input_name] = importer_node.build_importer_spec(
            input_type_schema=type_schema, constant_value=argument_value)
    elif isinstance(argument_value, int):
      pipeline_task_spec.inputs.parameters[
          input_name].runtime_value.constant_value.int_value = argument_value
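To make the constant-string branch above concrete, here is a small self-contained sketch, not part of the diff, of what that branch produces for an artifact-typed input; the URI and type schema are taken from the updated golden pipeline spec elsewhere in this commit:

# Sketch: result of the constant-value branch for an artifact-typed input.
from kfp.v2.dsl import importer_node

spec = importer_node.build_importer_spec(
    input_type_schema='title: kfp.Metrics\ntype: object\nproperties:\n',
    constant_value='gs://bucket/metrics')

# The injected importer reads the artifact from a constant URI
# instead of a runtime parameter.
assert spec.artifact_uri.constant_value.string_value == 'gs://bucket/metrics'
assert spec.type_schema.instance_schema.startswith('title: kfp.Metrics')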
@@ -186,6 +208,7 @@ def create_container_op_from_component_and_arguments(
  )

  task.task_spec = pipeline_task_spec
  task.importer_spec = importer_spec
  task.container_spec = pipeline_container_spec
  dsl.ContainerOp._DISABLE_REUSABLE_COMPONENT_WARNING = old_warn_value
@@ -13,31 +13,60 @@
# limitations under the License.
"""Utility function for building Importer Node spec."""

from typing import Tuple
from kfp.v2.proto import pipeline_spec_pb2

OUTPUT_KEY = 'result'


def build_importer_spec(
    input_type_schema: str,
    pipeline_param_name: str = None,
    constant_value: str = None
) -> pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec:
  """Builds an importer executor spec.

  Args:
    input_type_schema: The type of the input artifact.
    pipeline_param_name: The name of the pipeline parameter if the importer gets
      its artifact_uri via a pipeline parameter. This argument is mutually
      exclusive with constant_value.
    constant_value: The value of artifact_uri in case a constant value is passed
      directly into the component op. This argument is mutually exclusive with
      pipeline_param_name.

  Returns:
    An importer spec.
  """
  assert bool(pipeline_param_name) != bool(constant_value), (
      'importer spec should be built using either pipeline_param_name or '
      'constant_value.')
  importer_spec = pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec()
  importer_spec.type_schema.instance_schema = input_type_schema
  if pipeline_param_name:
    importer_spec.artifact_uri.runtime_parameter = pipeline_param_name
  elif constant_value:
    importer_spec.artifact_uri.constant_value.string_value = constant_value
  return importer_spec


def build_importer_task_spec(
    dependent_task: pipeline_spec_pb2.PipelineTaskSpec,
    input_name: str,
    input_type_schema: str,
) -> Tuple[pipeline_spec_pb2.PipelineTaskSpec,
           pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec]:
  """Build importer task spec and importer executor spec.
) -> pipeline_spec_pb2.PipelineTaskSpec:
  """Builds an importer task spec.

  Args:
    dependent_task: the task requires importer node.
    input_name: the name of the input artifact needs to be imported.
    input_type_schema: the type of the input artifact.
    dependent_task: The task that requires the importer node.
    input_name: The name of the input artifact that needs to be imported.
    input_type_schema: The type of the input artifact.

  Returns:
    a tuple of task_spec and importer_spec
    An importer node task spec.
  """
  dependent_task_name = dependent_task.task_info.name
  pipeline_parameter_name = (
      dependent_task.inputs.artifacts[input_name].output_artifact_key)

  task_spec = pipeline_spec_pb2.PipelineTaskSpec()
  task_spec.task_info.name = '{}_{}_importer'.format(dependent_task_name,
@@ -46,8 +75,4 @@ def build_importer_spec(
                                                     input_type_schema)
  task_spec.executor_label = task_spec.task_info.name

  importer_spec = pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec()
  importer_spec.artifact_uri.runtime_parameter = pipeline_parameter_name
  importer_spec.type_schema.instance_schema = input_type_schema

  return task_spec, importer_spec
  return task_spec
@@ -0,0 +1,122 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from kfp.v2.dsl import importer_node
from kfp.v2.proto import pipeline_spec_pb2 as pb
from google.protobuf import json_format


class ImporterNodeTest(unittest.TestCase):

  def test_build_importer_task(self):
    dependent_task = {
        'taskInfo': {
            'name': 'task1'
        },
        'inputs': {
            'artifacts': {
                'input1': {
                    'producerTask': '',
                }
            }
        },
        'executorLabel': 'task1_input1_importer'
    }
    dependent_task_spec = pb.PipelineTaskSpec()
    json_format.ParseDict(dependent_task, dependent_task_spec)

    expected_task = {
        'taskInfo': {
            'name': 'task1_input1_importer'
        },
        'outputs': {
            'artifacts': {
                'result': {
                    'artifactType': {
                        'instanceSchema': 'title: kfp.Artifact'
                    }
                }
            }
        },
        'executorLabel': 'task1_input1_importer'
    }
    expected_task_spec = pb.PipelineTaskSpec()
    json_format.ParseDict(expected_task, expected_task_spec)

    task_spec = importer_node.build_importer_task_spec(
        dependent_task=dependent_task_spec,
        input_name='input1',
        input_type_schema='title: kfp.Artifact')

    self.maxDiff = None
    self.assertEqual(expected_task_spec, task_spec)

  def test_build_importer_spec_from_pipeline_param(self):
    expected_importer = {
        'artifactUri': {
            'runtimeParameter': 'param1'
        },
        'typeSchema': {
            'instanceSchema': 'title: kfp.Artifact'
        }
    }
    expected_importer_spec = pb.PipelineDeploymentConfig.ImporterSpec()
    json_format.ParseDict(expected_importer, expected_importer_spec)
    importer_spec = importer_node.build_importer_spec(
        input_type_schema='title: kfp.Artifact', pipeline_param_name='param1')

    self.maxDiff = None
    self.assertEqual(expected_importer_spec, importer_spec)

  def test_build_importer_spec_from_constant_value(self):
    expected_importer = {
        'artifactUri': {
            'constantValue': {
                'stringValue': 'some_uri'
            }
        },
        'typeSchema': {
            'instanceSchema': 'title: kfp.Artifact'
        }
    }
    expected_importer_spec = pb.PipelineDeploymentConfig.ImporterSpec()
    json_format.ParseDict(expected_importer, expected_importer_spec)
    importer_spec = importer_node.build_importer_spec(
        input_type_schema='title: kfp.Artifact', constant_value='some_uri')

    self.maxDiff = None
    self.assertEqual(expected_importer_spec, importer_spec)

  def test_build_importer_spec_with_invalid_inputs_should_fail(self):
    with self.assertRaises(AssertionError) as cm:
      importer_node.build_importer_spec(
          input_type_schema='title: kfp.Artifact',
          pipeline_param_name='param1',
          constant_value='some_uri')
    self.assertEqual(
        'importer spec should be built using either pipeline_param_name or '
        'constant_value.',
        str(cm.exception))

    with self.assertRaises(AssertionError) as cm:
      importer_node.build_importer_spec(input_type_schema='title: kfp.Artifact')
    self.assertEqual(
        'importer spec should be built using either pipeline_param_name or '
        'constant_value.',
        str(cm.exception))


if __name__ == '__main__':
  unittest.main()