# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
# Import the submodules explicitly: `import importlib` alone does not
# guarantee that `importlib.machinery` / `importlib.util` are accessible.
import importlib.machinery
import importlib.util
import sys
from typing import Callable, Dict, Optional, Union

from absl import logging
from google.protobuf import json_format
from kfp.components import _python_op
from kfp.containers import _gcs_helper
from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.dsl import artifact

# If path starts with one of those, consider files are in remote filesystem.
_REMOTE_FS_PREFIX = ['gs://', 'hdfs://', 's3://']

# Constant user module name when importing the function from a Python file.
_USER_MODULE = 'user_module'


def get_parameter_from_output(
    file_path: str, param_name: str) -> Union[int, float, str]:
  """Gets a parameter value by its name from output metadata JSON.

  Args:
    file_path: The (remote) path of the executor output metadata JSON file.
    param_name: The name of the output parameter to read.

  Returns:
    The Python value (int, float or str) held by the named parameter.
  """
  output = pipeline_spec_pb2.ExecutorOutput()
  json_format.Parse(
      text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
      message=output)
  value = output.parameters[param_name]
  # A Value pb is a oneof of int/double/string; return whichever is set.
  return getattr(value, value.WhichOneof('value'))


def get_artifact_from_output(
    file_path: str, output_name: str) -> artifact.Artifact:
  """Gets an artifact object from output metadata JSON.

  Args:
    file_path: The (remote) path of the executor output metadata JSON file.
    output_name: The name of the output artifact to read.

  Returns:
    A Python artifact object deserialized from the runtime artifact proto.
  """
  output = pipeline_spec_pb2.ExecutorOutput()
  json_format.Parse(
      text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
      message=output)
  # Currently we bear the assumption that each output contains only one
  # artifact.
  json_str = json_format.MessageToJson(
      output.artifacts[output_name].artifacts[0], sort_keys=True)
  # Convert runtime_artifact to Python artifact.
  return artifact.Artifact.deserialize(json_str)


def import_func_from_source(source_path: str, fn_name: str) -> Callable:
  """Imports a function from a Python file.

  The implementation is borrowed from
  https://github.com/tensorflow/tfx/blob/8f25a4d1cc92dfc8c3a684dfc8b82699513cafb5/tfx/utils/import_utils.py#L50

  Args:
    source_path: The local path to the Python source file.
    fn_name: The function name, which can be found in the source file.

  Return:
    A Python function object.

  Raises:
    ImportError when failed to load the source file or cannot find the
    function with the given name.
  """
  # Remote paths cannot be imported directly; the user code is expected to
  # be built into the executor container image already.
  if any(source_path.startswith(prefix) for prefix in _REMOTE_FS_PREFIX):
    raise RuntimeError('Only local source file can be imported. Please make '
                       'sure the user code is built into executor container. '
                       'Got file path: %s' % source_path)

  try:
    loader = importlib.machinery.SourceFileLoader(
        fullname=_USER_MODULE,
        path=source_path,
    )
    spec = importlib.util.spec_from_loader(
        loader.name, loader, origin=source_path)
    module = importlib.util.module_from_spec(spec)
    # Register the module in sys.modules before executing it so that
    # self-referential imports inside the user module resolve correctly.
    sys.modules[loader.name] = module
    loader.exec_module(module)
  except IOError as e:
    # Chain the original error (PEP 3134) so the root cause is preserved.
    raise ImportError('{} in {} not found in import_func_from_source()'.format(
        fn_name, source_path)) from e

  try:
    return getattr(module, fn_name)
  except AttributeError as e:
    raise ImportError('{} in {} not found in import_func_from_source()'.format(
        fn_name, source_path)) from e


def get_output_artifacts(
    fn: Callable, output_uris: Dict[str, str]) -> Dict[str, artifact.Artifact]:
  """Gets the output artifacts from function signature and provided URIs.

  Args:
    fn: A user-provided function, whose signature annotates the type of
      output artifacts.
    output_uris: The mapping from output artifact name to its URI.

  Returns:
    A mapping from output artifact name to Python artifact objects.
  """
  # Inspect the function signature to determine the set of output artifact.
  spec = _python_op._extract_component_interface(fn)

  result = {}  # Mapping from output name to artifacts.
  for output in spec.outputs:
    if getattr(output, '_passing_style', None) == _python_op.OutputArtifact:
      # Creates an artifact according to its name
      type_name = getattr(output, 'type', None)
      if not type_name:
        continue

      try:
        artifact_cls = getattr(
            importlib.import_module(artifact.KFP_ARTIFACT_ONTOLOGY_MODULE),
            type_name)
      except (AttributeError, ImportError, ValueError):
        logging.warning((
            'Could not load artifact class %s.%s; using fallback deserialization'
            ' for the relevant artifact. Please make sure that any artifact '
            'classes can be imported within your container or environment.'),
            artifact.KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
        artifact_cls = artifact.Artifact

      # Identity comparison: only the exact base Artifact class needs the
      # default schema; subclasses carry their own.
      if artifact_cls is artifact.Artifact:
        # Provide an empty schema if instantiating a bare-metal artifact.
        art = artifact_cls(instance_schema=artifact.DEFAULT_ARTIFACT_SCHEMA)
      else:
        art = artifact_cls()
      art.uri = output_uris[output.name]
      result[output.name] = art
  return result


def _get_pipeline_value(value: Union[int, float, str]) -> Optional[
    pipeline_spec_pb2.Value]:
  """Converts Python primitive value to pipeline value pb.

  Args:
    value: The Python value to convert. None is passed through unchanged.

  Returns:
    A pipeline_spec Value message holding the value, or None when the input
    is None.

  Raises:
    TypeError: If the value is not an int, float or str.
  """
  if value is None:
    return None

  result = pipeline_spec_pb2.Value()
  # NOTE: bool is a subclass of int, so booleans are stored as int_value
  # (True -> 1); the pipeline Value proto has no boolean field.
  if isinstance(value, int):
    result.int_value = value
  elif isinstance(value, float):
    result.double_value = value
  elif isinstance(value, str):
    result.string_value = value
  else:
    raise TypeError('Got unknown type of value: {}'.format(value))

  return result


def get_python_value(value: pipeline_spec_pb2.Value) -> Union[int, float, str]:
  """Gets Python value from pipeline value pb message."""
  # Value pb is a oneof; fetch whichever of int/double/string is populated.
  return getattr(value, value.WhichOneof('value'))


def get_executor_output(
    output_artifacts: Dict[str, artifact.Artifact],
    output_params: Dict[str, Union[int, float, str]]
) -> pipeline_spec_pb2.ExecutorOutput:
  """Gets the output metadata message.

  Args:
    output_artifacts: Mapping from output artifact name to Python artifact
      objects; each artifact contributes its runtime_artifact proto.
    output_params: Mapping from output parameter name to its primitive value.

  Returns:
    An ExecutorOutput message populated with both artifacts and parameters.
  """
  result = pipeline_spec_pb2.ExecutorOutput()

  for name, art in output_artifacts.items():
    result.artifacts[name].CopyFrom(pipeline_spec_pb2.ArtifactList(
        artifacts=[art.runtime_artifact]))

  for name, param in output_params.items():
    result.parameters[name].CopyFrom(_get_pipeline_value(param))

  return result