165 lines
5.9 KiB
Python
165 lines
5.9 KiB
Python
# Copyright 2020 The Kubeflow Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utility functions for building Importer Node specs."""
|
|
|
|
from typing import Union, Type
|
|
|
|
from kfp.dsl import _container_op
|
|
from kfp.dsl import _pipeline_param
|
|
from kfp.dsl import dsl_utils
|
|
from kfp.pipeline_spec import pipeline_spec_pb2
|
|
from kfp.v2.components.types import artifact_types, type_utils
|
|
|
|
# Key of the importer's single input parameter, which carries the artifact
# uri (used by the importer spec, task spec and component spec builders).
INPUT_KEY: str = 'uri'
# Key of the importer's single output artifact (used by the component spec
# builder and by the importer ContainerOp's file_outputs).
OUTPUT_KEY: str = 'artifact'
|
|
|
|
|
|
def _build_importer_spec(
        artifact_uri: Union[_pipeline_param.PipelineParam, str],
        artifact_type_schema: pipeline_spec_pb2.ArtifactTypeSchema,
) -> pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec:
    """Builds an importer executor spec.

    Args:
        artifact_uri: The artifact uri to import from. A PipelineParam makes
            the spec read the uri from the runtime parameter named INPUT_KEY;
            a plain string is embedded as a constant value.
        artifact_type_schema: The user specified artifact type schema of the
            artifact to be imported. It is copied into the spec.

    Returns:
        An importer spec.

    Raises:
        ValueError: If artifact_uri is neither a PipelineParam nor a str.
            Previously such input silently produced a spec with no
            artifact_uri set.
    """
    importer_spec = pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec()
    importer_spec.type_schema.CopyFrom(artifact_type_schema)

    if isinstance(artifact_uri, _pipeline_param.PipelineParam):
        # The actual uri is supplied at runtime through the task input
        # parameter keyed by INPUT_KEY.
        importer_spec.artifact_uri.runtime_parameter = INPUT_KEY
    elif isinstance(artifact_uri, str):
        importer_spec.artifact_uri.constant_value.string_value = artifact_uri
    else:
        raise ValueError(
            'Importer got unexpected artifact_uri: {} of type: {}.'.format(
                artifact_uri, type(artifact_uri)))

    return importer_spec
|
|
|
|
|
|
def _build_importer_task_spec(
        importer_base_name: str,
        artifact_uri: Union[_pipeline_param.PipelineParam, str],
) -> pipeline_spec_pb2.PipelineTaskSpec:
    """Builds an importer task spec.

    Args:
        importer_base_name: The base name of the importer node. It is
            sanitized with dsl_utils.sanitize_component_name to form the
            component reference name.
        artifact_uri: The artifact uri to import from. How it is wired into
            the INPUT_KEY input parameter depends on its form:
            - A PipelineParam with an op_name is wired as that task's output
              parameter.
            - A PipelineParam without an op_name is wired as a component
              input parameter, referenced by its full_name.
            - A string becomes a constant runtime value.

    Returns:
        An importer node task spec.

    Raises:
        ValueError: If artifact_uri is neither a PipelineParam nor a str.
            Previously such input silently produced a task spec with no uri
            input.
    """
    result = pipeline_spec_pb2.PipelineTaskSpec()
    result.component_ref.name = dsl_utils.sanitize_component_name(
        importer_base_name)

    if isinstance(artifact_uri, _pipeline_param.PipelineParam):
        param = artifact_uri
        uri_input = result.inputs.parameters[INPUT_KEY]
        if param.op_name:
            # The uri is produced by an upstream task's output parameter.
            uri_input.task_output_parameter.producer_task = (
                dsl_utils.sanitize_task_name(param.op_name))
            uri_input.task_output_parameter.output_parameter_key = param.name
        else:
            # No producing op: the uri comes from an input of the enclosing
            # component.
            uri_input.component_input_parameter = param.full_name
    elif isinstance(artifact_uri, str):
        result.inputs.parameters[
            INPUT_KEY].runtime_value.constant_value.string_value = artifact_uri
    else:
        raise ValueError(
            'Importer got unexpected artifact_uri: {} of type: {}.'.format(
                artifact_uri, type(artifact_uri)))

    return result
|
|
|
|
|
|
def _build_importer_component_spec(
        importer_base_name: str,
        artifact_type_schema: pipeline_spec_pb2.ArtifactTypeSchema,
) -> pipeline_spec_pb2.ComponentSpec:
    """Assembles the ComponentSpec describing an importer node's interface.

    The component has one STRING input parameter keyed by INPUT_KEY, which
    carries the uri. It has one output artifact keyed by OUTPUT_KEY, typed
    with the given schema.

    Args:
        importer_base_name: The base name of the importer node. It is
            sanitized with dsl_utils.sanitize_executor_label to form the
            executor label.
        artifact_type_schema: The user specified artifact type schema of the
            artifact to be imported. It is copied into the output definition.

    Returns:
        An importer node component spec.
    """
    component_spec = pipeline_spec_pb2.ComponentSpec()
    component_spec.executor_label = dsl_utils.sanitize_executor_label(
        importer_base_name)

    # Indexing a message-valued proto map creates the entry and returns a
    # reference to it, so mutating the local updates the spec in place.
    uri_input = component_spec.input_definitions.parameters[INPUT_KEY]
    uri_input.type = pipeline_spec_pb2.PrimitiveType.STRING

    artifact_output = component_spec.output_definitions.artifacts[OUTPUT_KEY]
    artifact_output.artifact_type.CopyFrom(artifact_type_schema)

    return component_spec
|
|
|
|
|
|
def importer(artifact_uri: Union[_pipeline_param.PipelineParam, str],
             artifact_class: Type[artifact_types.Artifact],
             reimport: bool = False) -> _container_op.ContainerOp:
    """dsl.importer for importing an existing artifact. Only for v2 pipeline.

    Args:
        artifact_uri: The artifact uri to import from. Either a PipelineParam
            (resolved when the pipeline runs) or a constant string.
        artifact_class: The artifact class of the artifact to be imported.
            Its type schema is obtained from
            type_utils.get_artifact_type_schema.
        reimport: Whether to reimport the artifact. Defaults to False.
            NOTE(review): this flag is currently not read anywhere in this
            function, so it has no effect on the generated specs. Confirm
            whether it should be propagated into the importer spec.

    Returns:
        A ContainerOp instance. Its importer_spec, task_spec and
        component_spec attributes are populated for v2 compilation.

    Raises:
        ValueError: If the passed in artifact_uri is neither a PipelineParam
            nor a constant string value.
    """
    if isinstance(artifact_uri, _pipeline_param.PipelineParam):
        input_param = artifact_uri
    elif isinstance(artifact_uri, str):
        # Use the same key the task/component specs use for the uri input.
        input_param = _pipeline_param.PipelineParam(
            name=INPUT_KEY, value=artifact_uri, param_type='String')
    else:
        raise ValueError(
            'Importer got unexpected artifact_uri: {} of type: {}.'.format(
                artifact_uri, type(artifact_uri)))

    # Temporarily silence the reusable-component warning while constructing
    # the op. Restore it in `finally` so a failing constructor does not leave
    # the class-level flag permanently switched on.
    container_op_cls = _container_op.ContainerOp
    old_warn_value = container_op_cls._DISABLE_REUSABLE_COMPONENT_WARNING
    container_op_cls._DISABLE_REUSABLE_COMPONENT_WARNING = True
    try:
        task = container_op_cls(
            name='importer',
            image='importer_image',  # TODO: need a v1 implementation of importer.
            file_outputs={
                OUTPUT_KEY:
                    "{{{{$.outputs.artifacts['{}'].uri}}}}".format(OUTPUT_KEY)
            },
        )
    finally:
        container_op_cls._DISABLE_REUSABLE_COMPONENT_WARNING = old_warn_value

    artifact_type_schema = type_utils.get_artifact_type_schema(artifact_class)
    task.importer_spec = _build_importer_spec(
        artifact_uri=artifact_uri, artifact_type_schema=artifact_type_schema)
    task.task_spec = _build_importer_task_spec(
        importer_base_name=task.name, artifact_uri=artifact_uri)
    task.component_spec = _build_importer_component_spec(
        importer_base_name=task.name, artifact_type_schema=artifact_type_schema)
    task.inputs = [input_param]

    return task
|