188 lines
6.5 KiB
Python
188 lines
6.5 KiB
Python
# Copyright 2021 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import importlib
import importlib.machinery
import importlib.util
import sys
from typing import Callable, Dict, Optional, Union

from absl import logging
from google.protobuf import json_format

from kfp.components import _python_op
from kfp.containers import _gcs_helper
from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.dsl import artifact
|
|
from kfp.dsl import artifact
|
|
|
|
# Path prefixes that mark a file as living on a remote filesystem (GCS, HDFS,
# S3). A source path beginning with any of these is treated as non-local.
_REMOTE_FS_PREFIX = [
    'gs://',
    'hdfs://',
    's3://',
]

# Fixed module name under which a user-provided Python source file is loaded
# when a function is imported from it.
_USER_MODULE = 'user_module'
|
|
|
|
|
|
def get_parameter_from_output(file_path: str, param_name: str):
  """Gets a parameter value by its name from output metadata JSON.

  Args:
    file_path: Path of a JSON-serialized `ExecutorOutput` message, read via
      `_gcs_helper.GCSHelper.read_from_gcs_path`.
    param_name: Name of the output parameter to look up.

  Returns:
    The Python value held by whichever field of the parameter's `value` oneof
    is set (as extracted by `get_python_value`).

  Raises:
    KeyError: If `param_name` is not among the output's parameters.
  """
  output = pipeline_spec_pb2.ExecutorOutput()
  json_format.Parse(
      text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
      message=output)

  # Indexing a message-valued protobuf map with a missing key inserts an empty
  # default entry, which would later fail with an unrelated getattr error.
  # Check membership first so a missing name is reported clearly.
  if param_name not in output.parameters:
    raise KeyError(
        'Parameter {} not found in executor output {}'.format(
            param_name, file_path))

  return get_python_value(output.parameters[param_name])
|
|
|
|
|
|
def get_artifact_from_output(
    file_path: str, output_name: str) -> artifact.Artifact:
  """Gets an artifact object from output metadata JSON.

  Args:
    file_path: Path of a JSON-serialized `ExecutorOutput` message, read via
      `_gcs_helper.GCSHelper.read_from_gcs_path`.
    output_name: Name of the output whose artifact is requested.

  Returns:
    The first runtime artifact listed under `output_name`, converted to a
    Python artifact via `artifact.Artifact.deserialize`.

  Raises:
    KeyError: If `output_name` is not among the output's artifacts.
    ValueError: If `output_name` is present but lists no artifacts.
  """
  output = pipeline_spec_pb2.ExecutorOutput()
  json_format.Parse(
      text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
      message=output)

  # Indexing a message-valued protobuf map with a missing key inserts an empty
  # ArtifactList, which would then fail with a bare IndexError on [0].
  if output_name not in output.artifacts:
    raise KeyError(
        'Output artifact {} not found in executor output {}'.format(
            output_name, file_path))
  runtime_artifacts = output.artifacts[output_name].artifacts
  if not runtime_artifacts:
    raise ValueError(
        'Output {} in executor output {} contains no artifacts'.format(
            output_name, file_path))

  # Currently we bear the assumption that each output contains only one
  # artifact, so only the first entry is read.
  json_str = json_format.MessageToJson(runtime_artifacts[0], sort_keys=True)

  # Convert the runtime artifact proto into a Python artifact object.
  return artifact.Artifact.deserialize(json_str)
|
|
|
|
|
|
def import_func_from_source(source_path: str, fn_name: str) -> Callable:
  """Imports a function from a Python file.

  The implementation is borrowed from
  https://github.com/tensorflow/tfx/blob/8f25a4d1cc92dfc8c3a684dfc8b82699513cafb5/tfx/utils/import_utils.py#L50

  The file is executed as a module named `_USER_MODULE`, which is registered in
  `sys.modules` (replacing any module previously registered under that name).
  If executing the file fails, that registration is removed again.

  Args:
    source_path: The local path to the Python source file.
    fn_name: The function name, which can be found in the source file.

  Returns:
    A Python function object.

  Raises:
    RuntimeError: If `source_path` starts with one of the remote filesystem
      prefixes in `_REMOTE_FS_PREFIX`.
    ImportError: If the source file cannot be read, or the loaded module has
      no attribute named `fn_name`.
  """
  # str.startswith accepts a tuple of prefixes, replacing a per-prefix loop.
  if source_path.startswith(tuple(_REMOTE_FS_PREFIX)):
    raise RuntimeError('Only local source file can be imported. Please make '
                       'sure the user code is built into executor container. '
                       'Got file path: %s' % source_path)

  # Note: `importlib.machinery` / `importlib.util` are submodules that a bare
  # `import importlib` does not guarantee to load; they must be imported
  # explicitly at module level.
  loader = importlib.machinery.SourceFileLoader(
      fullname=_USER_MODULE,
      path=source_path,
  )
  spec = importlib.util.spec_from_loader(
      loader.name, loader, origin=source_path)
  module = importlib.util.module_from_spec(spec)
  sys.modules[loader.name] = module
  try:
    # Reading the file happens here; this is where a missing or unreadable
    # file raises an OSError/IOError.
    loader.exec_module(module)
  except IOError as e:
    # Do not leave a half-initialized module registered in sys.modules.
    sys.modules.pop(loader.name, None)
    raise ImportError(
        'Cannot read source file {} while importing {} in '
        'import_func_from_source(): {}'.format(source_path, fn_name, e)
    ) from e
  except BaseException:
    # Same cleanup for failures raised by the user code itself (e.g. a
    # SyntaxError); the original exception is propagated unchanged.
    sys.modules.pop(loader.name, None)
    raise

  try:
    return getattr(module, fn_name)
  except AttributeError as e:
    raise ImportError('{} in {} not found in import_func_from_source()'.format(
        fn_name, source_path
    )) from e
|
|
|
|
|
|
def get_output_artifacts(
    fn: Callable, output_uris: Dict[str, str]) -> Dict[str, artifact.Artifact]:
  """Builds Python artifact objects for the artifact outputs of `fn`.

  The set of outputs comes from the component interface extracted from `fn`'s
  signature. Only outputs whose passing style is `_python_op.OutputArtifact`
  and that carry a non-empty type name get an artifact; all others are skipped.

  Args:
    fn: A user-provided function, whose signature annotates the type of output
      artifacts.
    output_uris: The mapping from output artifact name to its URI.

  Returns:
    A mapping from output artifact name to Python artifact objects, each with
    its `uri` taken from `output_uris`.

  Raises:
    KeyError: If an artifact output's name is missing from `output_uris`.
  """
  interface = _python_op._extract_component_interface(fn)  # pylint: disable=protected-access

  artifacts_by_name = {}
  for output_spec in interface.outputs:
    # Guard clauses: skip anything that is not a typed artifact output.
    if not (getattr(output_spec, '_passing_style', None) ==
            _python_op.OutputArtifact):
      continue
    type_name = getattr(output_spec, 'type', None)
    if not type_name:
      continue

    # Resolve the artifact class by name in the ontology module, falling back
    # to the generic Artifact class when it cannot be loaded.
    try:
      ontology_module = importlib.import_module(
          artifact.KFP_ARTIFACT_ONTOLOGY_MODULE)
      artifact_cls = getattr(ontology_module, type_name)
    except (AttributeError, ImportError, ValueError):
      logging.warning((
          'Could not load artifact class %s.%s; using fallback deserialization'
          ' for the relevant artifact. Please make sure that any artifact '
          'classes can be imported within your container or environment.'),
                      artifact.KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
      artifact_cls = artifact.Artifact

    # The generic Artifact class is given the default schema explicitly;
    # ontology classes are constructed without arguments.
    new_artifact = (
        artifact_cls(instance_schema=artifact.DEFAULT_ARTIFACT_SCHEMA)
        if artifact_cls == artifact.Artifact else artifact_cls())
    new_artifact.uri = output_uris[output_spec.name]
    artifacts_by_name[output_spec.name] = new_artifact

  return artifacts_by_name
|
|
|
|
|
|
def _get_pipeline_value(value: Union[int, float, str]) -> Optional[
    pipeline_spec_pb2.Value]:
  """Wraps a Python primitive in a pipeline `Value` proto message.

  Args:
    value: The primitive to convert. `None` is also accepted.

  Returns:
    `None` if `value` is `None`; otherwise a `pipeline_spec_pb2.Value` with
    exactly one of `int_value`, `double_value` or `string_value` set.

  Raises:
    TypeError: If `value` is not an int, float or str.
  """
  if value is None:
    return None
  # isinstance (not an exact type match) is used and int is checked first, so
  # int subclasses such as bool land in int_value.
  if isinstance(value, int):
    return pipeline_spec_pb2.Value(int_value=value)
  if isinstance(value, float):
    return pipeline_spec_pb2.Value(double_value=value)
  if isinstance(value, str):
    return pipeline_spec_pb2.Value(string_value=value)
  raise TypeError('Got unknown type of value: {}'.format(value))
|
|
|
|
|
|
def get_python_value(value: pipeline_spec_pb2.Value) -> Union[int, float, str]:
  """Gets Python value from pipeline value pb message.

  Args:
    value: A `pipeline_spec_pb2.Value` message.

  Returns:
    The content of whichever field of the message's `value` oneof is set.

  Raises:
    ValueError: If no field of the `value` oneof is set.
  """
  field_name = value.WhichOneof('value')
  # WhichOneof returns None for an unset oneof; getattr(value, None) would
  # otherwise fail with an unrelated "attribute name must be string" error.
  if field_name is None:
    raise ValueError(
        'Pipeline value message has no field set in its "value" oneof.')
  return getattr(value, field_name)
|
|
|
|
|
|
def get_executor_output(
    output_artifacts: Dict[str, artifact.Artifact],
    output_params: Dict[str, Union[int, float, str]]
) -> pipeline_spec_pb2.ExecutorOutput:
  """Gets the output metadata message.

  Args:
    output_artifacts: Mapping from output name to the Python artifact whose
      `runtime_artifact` is recorded under that name.
    output_params: Mapping from output parameter name to its primitive value.
      Entries whose value is `None` are omitted from the result.

  Returns:
    An `ExecutorOutput` holding a single-element artifact list for every entry
    of `output_artifacts` and a `Value` for every non-None entry of
    `output_params`.

  Raises:
    TypeError: If a parameter value has a type `_get_pipeline_value` does not
      support.
  """
  result = pipeline_spec_pb2.ExecutorOutput()

  for name, art in output_artifacts.items():
    result.artifacts[name].CopyFrom(
        pipeline_spec_pb2.ArtifactList(artifacts=[art.runtime_artifact]))

  for name, param in output_params.items():
    pipeline_value = _get_pipeline_value(param)
    # _get_pipeline_value maps None to None; passing that to CopyFrom would
    # raise an opaque protobuf TypeError, so a None value is not emitted.
    if pipeline_value is None:
      continue
    result.parameters[name].CopyFrom(pipeline_value)

  return result