pipelines/sdk/python/kfp/containers/entrypoint_utils.py

188 lines
6.5 KiB
Python

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.machinery
import importlib.util
import sys
from typing import Callable, Dict, Optional, Union

from absl import logging
from google.protobuf import json_format

from kfp.components import _python_op
from kfp.containers import _gcs_helper
from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.dsl import artifact
# Path prefixes that mark a file as living on a remote filesystem rather than
# on local disk.
_REMOTE_FS_PREFIX = ['gs://', 'hdfs://', 's3://']
# Fixed module name under which user code loaded from a Python source file is
# registered.
_USER_MODULE = 'user_module'
def get_parameter_from_output(file_path: str, param_name: str):
    """Reads a single output parameter from an executor output metadata file.

    Args:
        file_path: Path of the JSON-serialized ExecutorOutput message, read
            through `GCSHelper.read_from_gcs_path`.
        param_name: Key of the wanted entry in the message's `parameters` map.

    Returns:
        The content of whichever field of the parameter's `value` oneof is set.
    """
    executor_output = pipeline_spec_pb2.ExecutorOutput()
    metadata_json = _gcs_helper.GCSHelper.read_from_gcs_path(file_path)
    json_format.Parse(text=metadata_json, message=executor_output)

    param = executor_output.parameters[param_name]
    populated_field = param.WhichOneof('value')
    return getattr(param, populated_field)
def get_artifact_from_output(
        file_path: str, output_name: str) -> artifact.Artifact:
    """Builds a Python artifact from an executor output metadata file.

    Args:
        file_path: Path of the JSON-serialized ExecutorOutput message, read
            through `GCSHelper.read_from_gcs_path`.
        output_name: Key of the wanted entry in the message's `artifacts` map.

    Returns:
        The result of `artifact.Artifact.deserialize` applied to the JSON form
        of the first runtime artifact stored under `output_name`.
    """
    executor_output = pipeline_spec_pb2.ExecutorOutput()
    metadata_json = _gcs_helper.GCSHelper.read_from_gcs_path(file_path)
    json_format.Parse(text=metadata_json, message=executor_output)

    # Only index 0 is read: each output is assumed to hold exactly one
    # artifact.
    runtime_artifact = executor_output.artifacts[output_name].artifacts[0]

    # Round-trip through JSON to turn the runtime artifact into a Python one.
    return artifact.Artifact.deserialize(
        json_format.MessageToJson(runtime_artifact, sort_keys=True))
def import_func_from_source(source_path: str, fn_name: str) -> Callable:
    """Imports a function from a Python file.

    The implementation is borrowed from
    https://github.com/tensorflow/tfx/blob/8f25a4d1cc92dfc8c3a684dfc8b82699513cafb5/tfx/utils/import_utils.py#L50

    The file is executed as a module registered in `sys.modules` under the
    fixed name `_USER_MODULE`.

    Args:
        source_path: The local path to the Python source file.
        fn_name: The function name, which can be found in the source file.

    Returns:
        The attribute named `fn_name` of the loaded module.

    Raises:
        RuntimeError: If `source_path` starts with one of the remote
            filesystem prefixes in `_REMOTE_FS_PREFIX`.
        ImportError: If loading the source file fails with an I/O error, or
            the loaded module has no attribute named `fn_name`. The original
            error is attached as `__cause__`.
    """
    # str.startswith accepts a tuple of candidates, so no temporary list of
    # booleans is needed.
    if source_path.startswith(tuple(_REMOTE_FS_PREFIX)):
        raise RuntimeError('Only local source file can be imported. Please make '
                           'sure the user code is built into executor container. '
                           'Got file path: %s' % source_path)

    try:
        loader = importlib.machinery.SourceFileLoader(
            fullname=_USER_MODULE,
            path=source_path,
        )
        spec = importlib.util.spec_from_loader(
            loader.name, loader, origin=source_path)
        module = importlib.util.module_from_spec(spec)
        # Register the module before executing it, mirroring what the regular
        # import machinery does.
        sys.modules[loader.name] = module
        loader.exec_module(module)
    except IOError as err:
        # Chain the I/O failure so the real reason (e.g. missing file) is not
        # lost behind the generic ImportError.
        raise ImportError(
            '{} in {} not found in import_func_from_source()'.format(
                fn_name, source_path
            )) from err

    try:
        return getattr(module, fn_name)
    except AttributeError as err:
        raise ImportError(
            '{} in {} not found in import_func_from_source()'.format(
                fn_name, source_path
            )) from err
def get_output_artifacts(
        fn: Callable,
        output_uris: Dict[str, str]) -> Dict[str, artifact.Artifact]:
    """Gets the output artifacts from function signature and provided URIs.

    Only outputs whose `_passing_style` is `_python_op.OutputArtifact` and
    that carry a non-empty `type` are considered; all others are skipped.

    Args:
        fn: A user-provided function, whose signature annotates the type of
            output artifacts.
        output_uris: The mapping from output artifact name to its URI.

    Returns:
        A mapping from output artifact name to Python artifact objects, each
        with its `uri` set from `output_uris`.

    Raises:
        KeyError: If an output artifact has no entry in `output_uris`.
    """
    # Inspect the function signature to determine the set of output artifacts.
    spec = _python_op._extract_component_interface(fn)

    result = {}  # Mapping from output name to artifacts.
    # NOTE(review): the extracted spec appears to report `outputs` as None
    # rather than an empty list when `fn` declares no outputs (confirm against
    # _python_op); fall back to an empty sequence so that case yields {}.
    for output in spec.outputs or []:
        if getattr(output, '_passing_style', None) != _python_op.OutputArtifact:
            continue

        type_name = getattr(output, 'type', None)
        if not type_name:
            continue

        # Resolve the artifact class by its type name; fall back to the base
        # Artifact class when it cannot be loaded.
        try:
            artifact_cls = getattr(
                importlib.import_module(artifact.KFP_ARTIFACT_ONTOLOGY_MODULE),
                type_name)
        except (AttributeError, ImportError, ValueError):
            logging.warning((
                'Could not load artifact class %s.%s; using fallback deserialization'
                ' for the relevant artifact. Please make sure that any artifact '
                'classes can be imported within your container or environment.'),
                artifact.KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
            artifact_cls = artifact.Artifact

        if artifact_cls is artifact.Artifact:
            # Provide an empty schema when instantiating a bare-metal artifact.
            art = artifact_cls(instance_schema=artifact.DEFAULT_ARTIFACT_SCHEMA)
        else:
            art = artifact_cls()

        art.uri = output_uris[output.name]
        result[output.name] = art

    return result
def _get_pipeline_value(value: Union[int, float, str]) -> Optional[
        pipeline_spec_pb2.Value]:
    """Wraps a Python primitive in a pipeline `Value` message.

    Args:
        value: An int, float or str to convert, or None.

    Returns:
        A `Value` message with `int_value`, `double_value` or `string_value`
        set according to the type of `value`, or None when `value` is None.

    Raises:
        TypeError: If `value` is neither None nor an int, float or str.
    """
    if value is None:
        return None

    # Checked in order; the first matching type decides which field is set.
    field_by_type = (
        (int, 'int_value'),
        (float, 'double_value'),
        (str, 'string_value'),
    )

    pb_value = pipeline_spec_pb2.Value()
    for py_type, field_name in field_by_type:
        if isinstance(value, py_type):
            setattr(pb_value, field_name, value)
            return pb_value

    raise TypeError('Got unknown type of value: {}'.format(value))
def get_python_value(value: pipeline_spec_pb2.Value) -> Union[int, float, str]:
    """Extracts the Python value held by a pipeline `Value` message.

    Args:
        value: The message whose `value` oneof is inspected.

    Returns:
        The content of whichever field of the `value` oneof is set.
    """
    populated_field = value.WhichOneof('value')
    return getattr(value, populated_field)
def get_executor_output(
        output_artifacts: Dict[str, artifact.Artifact],
        output_params: Dict[str, Optional[Union[int, float, str]]]
) -> pipeline_spec_pb2.ExecutorOutput:
    """Gets the output metadata message.

    Args:
        output_artifacts: Mapping from output name to the artifact whose
            `runtime_artifact` is recorded under that name.
        output_params: Mapping from output name to its primitive value. An
            entry whose value is None is left out of the message.

    Returns:
        An ExecutorOutput message holding one single-element artifact list per
        entry of `output_artifacts` and one parameter value per non-None entry
        of `output_params`.

    Raises:
        TypeError: If a parameter value is neither None nor an int, float or
            str.
    """
    result = pipeline_spec_pb2.ExecutorOutput()

    for name, art in output_artifacts.items():
        result.artifacts[name].CopyFrom(pipeline_spec_pb2.ArtifactList(
            artifacts=[art.runtime_artifact]
        ))

    for name, param in output_params.items():
        pipeline_value = _get_pipeline_value(param)
        if pipeline_value is None:
            # _get_pipeline_value maps None to None, and CopyFrom(None) would
            # raise an opaque TypeError; leave the parameter unset instead.
            continue
        result.parameters[name].CopyFrom(pipeline_value)

    return result