feat(sdk): Container entrypoint used for new styled KFP component authoring (#4978)
* skeleton * add entrypoint utils to parse param * wip: artifact parsing * add input param artifacts passing and clean unused code * wip * add output artifact inspection * add parameter output * finish entrypoint implementation * add entrypoint_utils_test.py * add entrypoint test * add entrypoint test * get rid of tf * fix test * fix file location * fix tests * fix tests * resolving comments * Partially rollback * resolve comments in entrypoint.py * resolve comments
This commit is contained in:
parent
d629397654
commit
279694ec6d
|
|
@ -17,9 +17,11 @@ __all__ = [
|
|||
'func_to_container_op',
|
||||
'func_to_component_text',
|
||||
'default_base_image_or_builder',
|
||||
'InputArtifact',
|
||||
'InputPath',
|
||||
'InputTextFile',
|
||||
'InputBinaryFile',
|
||||
'OutputArtifact',
|
||||
'OutputPath',
|
||||
'OutputTextFile',
|
||||
'OutputBinaryFile',
|
||||
|
|
@ -63,6 +65,16 @@ class InputBinaryFile:
|
|||
self.type = type
|
||||
|
||||
|
||||
class InputArtifact:
  """InputArtifact function parameter annotation.

  When building a component from a Python function, annotating a parameter
  with InputArtifact indicates that the associated input should be tracked
  as an MLMD artifact.
  """

  def __init__(self, type: Optional[str] = None):
    # Optional artifact type name; None leaves the artifact untyped.
    self.type = type
|
||||
|
||||
|
||||
class OutputPath:
|
||||
'''When creating component from function, :class:`.OutputPath` should be used as function parameter annotation to tell the system that the function wants to output data by writing it into a file with the given path instead of returning the data from the function.'''
|
||||
def __init__(self, type=None):
|
||||
|
|
@ -81,6 +93,17 @@ class OutputBinaryFile:
|
|||
self.type = type
|
||||
|
||||
|
||||
class OutputArtifact:
  """OutputArtifact function parameter annotation.

  When building a component from a Python function, annotating a parameter
  with OutputArtifact indicates that the associated output should be treated
  as an MLMD artifact; its underlying content and metadata are updated by
  this component.
  """

  def __init__(self, type: Optional[str] = None):
    # Optional artifact type name; None leaves the artifact untyped.
    self.type = type
|
||||
|
||||
|
||||
def _make_parent_dirs_and_return_path(file_path: str):
|
||||
import os
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
|
@ -300,7 +323,10 @@ def _extract_component_interface(func: Callable) -> ComponentSpec:
|
|||
parameter_annotation = parameter.annotation
|
||||
passing_style = None
|
||||
io_name = parameter.name
|
||||
if isinstance(parameter_annotation, (InputPath, InputTextFile, InputBinaryFile, OutputPath, OutputTextFile, OutputBinaryFile)):
|
||||
if isinstance(
|
||||
parameter_annotation,
|
||||
(InputArtifact, InputPath, InputTextFile, InputBinaryFile,
|
||||
OutputArtifact, OutputPath, OutputTextFile, OutputBinaryFile)):
|
||||
passing_style = type(parameter_annotation)
|
||||
parameter_annotation = parameter_annotation.type
|
||||
if parameter.default is not inspect.Parameter.empty and not (passing_style == InputPath and parameter.default is None):
|
||||
|
|
@ -317,7 +343,9 @@ def _extract_component_interface(func: Callable) -> ComponentSpec:
|
|||
type_struct = annotation_to_type_struct(parameter_annotation)
|
||||
#TODO: Humanize the input/output names
|
||||
|
||||
if isinstance(parameter.annotation, (OutputPath, OutputTextFile, OutputBinaryFile)):
|
||||
if isinstance(
|
||||
parameter.annotation,
|
||||
(OutputArtifact, OutputPath, OutputTextFile, OutputBinaryFile)):
|
||||
io_name = _make_name_unique_by_adding_index(io_name, output_names, '_')
|
||||
output_names.add(io_name)
|
||||
output_spec = OutputSpec(
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import PurePath
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
class GCSHelper(object):
|
||||
""" GCSHelper manages the connection with the GCS storage """
|
||||
|
|
@ -26,7 +28,7 @@ class GCSHelper(object):
|
|||
gcs_blob: gcs blob object(https://github.com/googleapis/google-cloud-python/blob/5c9bb42cb3c9250131cfeef6e0bafe8f4b7c139f/storage/google/cloud/storage/blob.py#L105)
|
||||
"""
|
||||
from google.cloud import storage
|
||||
pure_path = PurePath(gcs_path)
|
||||
pure_path = pathlib.PurePath(gcs_path)
|
||||
gcs_bucket = pure_path.parts[1]
|
||||
gcs_blob = '/'.join(pure_path.parts[2:])
|
||||
client = storage.Client()
|
||||
|
|
@ -44,6 +46,28 @@ class GCSHelper(object):
|
|||
blob = GCSHelper.get_blob_from_gcs_uri(gcs_path)
|
||||
blob.upload_from_filename(local_path)
|
||||
|
||||
  @staticmethod
  def write_to_gcs_path(path: str, content: str) -> None:
    """Writes serialized content to a GCS location.

    Args:
      path: GCS path to write to.
      content: The content to be written.
    """
    # Stage the content in a local temp file first; the upload helper
    # (upload_gcs_file) works from a local file path.
    fd, temp_path = tempfile.mkstemp()
    try:
      with os.fdopen(fd, 'w') as tmp:
        tmp.write(content)

      # NOTE(review): assumes get_blob_from_gcs_uri returns a falsy value
      # when the blob (or its bucket) does not exist yet -- confirm against
      # its implementation.
      if not GCSHelper.get_blob_from_gcs_uri(path):
        pure_path = pathlib.PurePath(path)
        # The bucket name is the second path component of 'gs://bucket/...'
        # (same convention as get_blob_from_gcs_uri above).
        gcs_bucket = pure_path.parts[1]
        GCSHelper.create_gcs_bucket_if_not_exist(gcs_bucket)

      GCSHelper.upload_gcs_file(temp_path, path)
    finally:
      # Always remove the local staging file, even if the upload failed.
      os.remove(temp_path)
|
||||
|
||||
@staticmethod
|
||||
def remove_gcs_blob(gcs_path):
|
||||
"""
|
||||
|
|
@ -63,6 +87,18 @@ class GCSHelper(object):
|
|||
blob = GCSHelper.get_blob_from_gcs_uri(gcs_path)
|
||||
blob.download_to_filename(local_path)
|
||||
|
||||
  @staticmethod
  def read_from_gcs_path(gcs_path: str) -> str:
    """Reads the content of a file hosted on GCS.

    Args:
      gcs_path: The GCS path to read from.

    Returns:
      The content of the file, read back as text.
    """
    # Download into a local temp file and read it back; the download helper
    # (download_gcs_blob) writes to a local file path.
    fd, temp_path = tempfile.mkstemp()
    try:
      GCSHelper.download_gcs_blob(temp_path, gcs_path)
      with os.fdopen(fd, 'r') as tmp:
        result = tmp.read()
    finally:
      # Always remove the local staging file, even if the download failed.
      os.remove(temp_path)
    return result
|
||||
|
||||
@staticmethod
|
||||
def create_gcs_bucket_if_not_exist(gcs_bucket):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,324 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from absl import logging
|
||||
import fire
|
||||
from google.protobuf import json_format
|
||||
import os
|
||||
|
||||
from kfp.containers import _gcs_helper
|
||||
from kfp.containers import entrypoint_utils
|
||||
from kfp.dsl import artifact
|
||||
|
||||
# Location of the user-code file built into the executor container image
# (see kfp.containers._component_builder.build_python_component).
_FN_SOURCE = 'ml/main.py'
# Command line argument carrying the name of the user function to invoke.
_FN_NAME_ARG = 'function_name'

# Suffixes used by `main` to classify per-input command line arguments.
# The part of each argument name before the suffix is the declared input name.
_PARAM_METADATA_SUFFIX = '_input_param_metadata_file'
_ARTIFACT_METADATA_SUFFIX = '_input_artifact_metadata_file'
_FIELD_NAME_SUFFIX = '_input_field_name'
_ARGO_PARAM_SUFFIX = '_input_argo_param'
_INPUT_PATH_SUFFIX = '_input_path'
_OUTPUT_NAME_SUFFIX = '_input_output_name'

# Suffixes used by `main` to classify per-output command line arguments.
_OUTPUT_PARAM_PATH_SUFFIX = '_parameter_output_path'
_OUTPUT_ARTIFACT_PATH_SUFFIX = '_artifact_output_path'

# Required argument specifying where the output metadata JSON file is written.
_METADATA_FILE_ARG = 'executor_metadata_json_file'
|
||||
|
||||
|
||||
class InputParam(object):
  """POD that holds an input parameter."""

  def __init__(self,
               value: Optional[Union[str, float, int]] = None,
               metadata_file: Optional[str] = None,
               field_name: Optional[str] = None):
    """Instantiates an InputParam object.

    Args:
      value: The actual value of the parameter.
      metadata_file: The location of the metadata JSON file output by the
        producer step.
      field_name: The output name of the producer.

    Raises:
      ValueError: when neither of the following is true:
        1) value is provided, and metadata_file and field_name are not; or
        2) both metadata_file and field_name are provided, and value is not.
    """
    # Exactly one of the two mutually exclusive forms must be used.
    direct_value_given = value is not None and not (metadata_file or field_name)
    producer_ref_given = bool(metadata_file and field_name) and value is None
    if not (direct_value_given or producer_ref_given):
      raise ValueError('Either value or both metadata_file and field_name '
                       'needs to be provided. Got value={value}, field_name='
                       '{field_name}, metadata_file={metadata_file}'.format(
                           value=value,
                           field_name=field_name,
                           metadata_file=metadata_file
                       ))

    if value is None:
      # Parse the value by inspecting the producer's metadata JSON file.
      self._value = entrypoint_utils.get_parameter_from_output(
          metadata_file, field_name)
    else:
      self._value = value

    self._metadata_file = metadata_file
    self._field_name = field_name

  # Following properties are read-only
  @property
  def value(self) -> Union[float, str, int]:
    return self._value

  @property
  def metadata_file(self) -> str:
    return self._metadata_file

  @property
  def field_name(self) -> str:
    return self._field_name
|
||||
|
||||
|
||||
class InputArtifact(object):
  """POD that holds an input artifact."""

  def __init__(self,
               uri: Optional[str] = None,
               metadata_file: Optional[str] = None,
               output_name: Optional[str] = None
               ):
    """Instantiates an InputArtifact object.

    Args:
      uri: The uri holds the input artifact.
      metadata_file: The location of the metadata JSON file output by the
        producer step.
      output_name: The output name of the artifact in producer step.

    Raises:
      ValueError: when neither of the following is true:
        1) uri is provided, and metadata_file and output_name are not; or
        2) both metadata_file and output_name are provided, and uri is not.
    """
    # Valid iff exactly one of the two forms is given: a direct uri, or a
    # (metadata_file, output_name) reference into the producer's output.
    # NOTE(review): this checks truthiness of uri (so '' counts as absent),
    # while InputParam checks `value is not None` -- confirm the asymmetry
    # is intended.
    if not ((uri and not (metadata_file or output_name) or (
        metadata_file and output_name and not uri))):
      raise ValueError('Either uri or both metadata_file and output_name '
                       'needs to be provided. Got uri={uri}, output_name='
                       '{output_name}, metadata_file={metadata_file}'.format(
                           uri=uri,
                           output_name=output_name,
                           metadata_file=metadata_file
                       ))

    self._metadata_file = metadata_file
    self._output_name = output_name
    if uri:
      self._uri = uri
    else:
      # Resolve the uri from the producer's metadata (get_artifact() parses
      # the metadata file when metadata_file and output_name are set).
      self._uri = self.get_artifact().uri

  # Following properties are read-only.
  @property
  def uri(self) -> str:
    return self._uri

  @property
  def metadata_file(self) -> str:
    return self._metadata_file

  @property
  def output_name(self) -> str:
    return self._output_name

  def get_artifact(self) -> artifact.Artifact:
    """Gets an artifact object by parsing metadata or creating one from uri."""
    if self.metadata_file and self.output_name:
      return entrypoint_utils.get_artifact_from_output(
          self.metadata_file, self.output_name)
    else:
      # Provide an empty schema when returning a raw Artifact.
      result = artifact.Artifact(
          instance_schema=artifact.DEFAULT_ARTIFACT_SCHEMA)
      result.uri = self.uri
      return result
|
||||
|
||||
|
||||
def main(**kwargs):
  """Container entrypoint used by KFP Python function based component.

  This function has a dynamic signature, which will be interpreted according to
  the I/O and data-passing contract of KFP Python function components. The
  parameter will be received from command line interface.

  For each declared parameter input of the user function, three command line
  arguments will be recognized:
  1. {name of the parameter}_input_param_metadata_file: The metadata JSON file
     path output by the producer.
  2. {name of the parameter}_input_field_name: The output name of the parameter,
     by which the parameter can be found in the producer metadata JSON file.
  3. {name of the parameter}_input_argo_param: The actual runtime value of the
     input parameter.
  When the producer is a new-styled KFP Python component, 1 and 2 will be
  populated, and when it's a conventional KFP Python component, 3 will be in
  use.

  For each declared artifact input of the user function, three command line args
  will be recognized:
  1. {name of the artifact}_input_path: The actual path, or uri, of the input
     artifact.
  2. {name of the artifact}_input_artifact_metadata_file: The metadata JSON file
     path output by the producer.
  3. {name of the artifact}_input_output_name: The output name of the artifact,
     by which the artifact can be found in the producer metadata JSON file.
  If the producer is a new-styled KFP Python component, 2+3 will be used to give
  user code access to MLMD (custom) properties associated with this artifact;
  if the producer is a conventional KFP Python component, 1 will be used to
  construct an Artifact with only the URI populated.

  For each declared artifact or parameter output of the user function, a command
  line arg, namely,
  `{name of the artifact|parameter}_(artifact|parameter)_output_path`,
  will be passed to specify the location where the output content is written to.

  In addition, `executor_metadata_json_file` specifies the location where the
  output metadata JSON file will be written.

  Raises:
    RuntimeError: If executor_metadata_json_file is missing, or if the user
      function returns a primitive but the number of declared parameter output
      paths is not exactly one.
  """
  if _METADATA_FILE_ARG not in kwargs:
    raise RuntimeError('Must specify executor_metadata_json_file')

  # Group arguments according to suffixes. Each kwarg name is
  # '{input/output name}{suffix}'; stripping the suffix recovers the name.
  input_params_metadata = {}
  input_params_field_name = {}
  input_params_value = {}
  input_artifacts_metadata = {}
  input_artifacts_uri = {}
  input_artifacts_output_name = {}
  output_artifacts_uri = {}
  output_params_path = {}
  for k, v in kwargs.items():
    if k.endswith(_PARAM_METADATA_SUFFIX):
      param_name = k[:-len(_PARAM_METADATA_SUFFIX)]
      input_params_metadata[param_name] = v
    elif k.endswith(_FIELD_NAME_SUFFIX):
      param_name = k[:-len(_FIELD_NAME_SUFFIX)]
      input_params_field_name[param_name] = v
    elif k.endswith(_ARGO_PARAM_SUFFIX):
      param_name = k[:-len(_ARGO_PARAM_SUFFIX)]
      input_params_value[param_name] = v
    elif k.endswith(_ARTIFACT_METADATA_SUFFIX):
      artifact_name = k[:-len(_ARTIFACT_METADATA_SUFFIX)]
      input_artifacts_metadata[artifact_name] = v
    elif k.endswith(_INPUT_PATH_SUFFIX):
      artifact_name = k[:-len(_INPUT_PATH_SUFFIX)]
      input_artifacts_uri[artifact_name] = v
    elif k.endswith(_OUTPUT_NAME_SUFFIX):
      artifact_name = k[:-len(_OUTPUT_NAME_SUFFIX)]
      input_artifacts_output_name[artifact_name] = v
    elif k.endswith(_OUTPUT_PARAM_PATH_SUFFIX):
      param_name = k[:-len(_OUTPUT_PARAM_PATH_SUFFIX)]
      output_params_path[param_name] = v
    elif k.endswith(_OUTPUT_ARTIFACT_PATH_SUFFIX):
      artifact_name = k[:-len(_OUTPUT_ARTIFACT_PATH_SUFFIX)]
      output_artifacts_uri[artifact_name] = v
    elif k not in (_METADATA_FILE_ARG, _FN_NAME_ARG):
      # Unknown args are tolerated with a warning rather than failing hard.
      logging.warning(
          'Got unexpected command line argument: %s=%s Ignoring', k, v)

  # Instantiate POD objects. InputParam/InputArtifact validate that a
  # consistent combination of value/uri vs. metadata reference was supplied.
  input_params = {}
  for param_name in (
      input_params_value.keys() |
      input_params_field_name.keys() | input_params_metadata.keys()):
    input_param = InputParam(
        value=input_params_value.get(param_name),
        metadata_file=input_params_metadata.get(param_name),
        field_name=input_params_field_name.get(param_name))
    input_params[param_name] = input_param

  input_artifacts = {}
  for artifact_name in (
      input_artifacts_uri.keys() |
      input_artifacts_metadata.keys() |
      input_artifacts_output_name.keys()
  ):
    input_artifact = InputArtifact(
        uri=input_artifacts_uri.get(artifact_name),
        metadata_file=input_artifacts_metadata.get(artifact_name),
        output_name=input_artifacts_output_name.get(artifact_name))
    input_artifacts[artifact_name] = input_artifact

  # Import and invoke the user-provided function.
  # Currently the actual user code is built into container as /ml/main.py
  # which is specified in
  # kfp.containers._component_builder.build_python_component.

  # Also, determine a way to inspect the function signature to decide the type
  # of output artifacts.
  fn_name = kwargs[_FN_NAME_ARG]

  fn = entrypoint_utils.import_func_from_source(_FN_SOURCE, fn_name)
  # Get the output artifacts and combine them with the provided URIs.
  output_artifacts = entrypoint_utils.get_output_artifacts(
      fn, output_artifacts_uri)
  # Output artifact objects are passed INTO the user function, which mutates
  # them in place; parameters and input artifacts are passed as values.
  invoking_kwargs = {}
  for k, v in output_artifacts.items():
    invoking_kwargs[k] = v

  for k, v in input_params.items():
    invoking_kwargs[k] = v.value
  for k, v in input_artifacts.items():
    invoking_kwargs[k] = v.get_artifact()

  # Execute the user function. fn_res is expected to contain output parameters
  # only. It's either an namedtuple or a single primitive value.
  # NOTE(review): if fn returns None (e.g. a function with only artifact
  # outputs), the namedtuple branches below would fail on fn_res._fields --
  # TODO confirm user functions always return a primitive or namedtuple.
  fn_res = fn(**invoking_kwargs)

  if isinstance(fn_res, (int, float, str)) and len(output_params_path) != 1:
    raise RuntimeError('For primitive output a single output param path is '
                       'expected. Got %s' % output_params_path)

  if isinstance(fn_res, (int, float, str)):
    output_name = list(output_params_path.keys())[0]
    # Write the output to the provided path.
    _gcs_helper.GCSHelper.write_to_gcs_path(
        path=output_params_path[output_name],
        content=str(fn_res))
  else:
    # When multiple outputs, we'll need to match each field to the output paths.
    for idx, output_name in enumerate(fn_res._fields):
      path = output_params_path[output_name]
      _gcs_helper.GCSHelper.write_to_gcs_path(
          path=path,
          content=str(fn_res[idx]))

  # Write output metadata JSON file. A single primitive result is recorded
  # under the fixed name 'output'; namedtuple fields keep their own names.
  output_parameters = {}
  if isinstance(fn_res, (int, float, str)):
    output_parameters['output'] = fn_res
  else:
    for idx, output_name in enumerate(fn_res._fields):
      output_parameters[output_name] = fn_res[idx]

  executor_output = entrypoint_utils.get_executor_output(
      output_artifacts=output_artifacts,
      output_params=output_parameters)

  _gcs_helper.GCSHelper.write_to_gcs_path(
      path=kwargs[_METADATA_FILE_ARG],
      content=json_format.MessageToJson(executor_output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(main)
|
||||
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl import logging
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
from google.protobuf import json_format
|
||||
|
||||
from kfp.components import _python_op
|
||||
from kfp.containers import _gcs_helper
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
from kfp.dsl import artifact
|
||||
|
||||
# If path starts with one of those, consider files are in remote filesystem.
|
||||
_REMOTE_FS_PREFIX = ['gs://', 'hdfs://', 's3://']
|
||||
|
||||
# Constant user module name when importing the function from a Python file.
|
||||
_USER_MODULE = 'user_module'
|
||||
|
||||
|
||||
def get_parameter_from_output(file_path: str, param_name: str):
  """Gets a parameter value by its name from output metadata JSON.

  Args:
    file_path: The (possibly remote) path of the producer's metadata JSON file.
    param_name: The name under which the parameter was recorded by the
      producer.

  Returns:
    The parameter value, as whichever primitive the producer's `value` oneof
    field carries.

  Raises:
    ValueError: If param_name is not present in the metadata file.
  """
  output = pipeline_spec_pb2.ExecutorOutput()
  json_format.Parse(
      text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
      message=output)
  # Subscripting a missing key on a proto map silently inserts an empty
  # entry, whose WhichOneof() is None -- getattr(value, None) would then
  # raise a confusing TypeError. Check membership first to fail clearly.
  if param_name not in output.parameters:
    raise ValueError('Parameter {} not found in metadata file {}'.format(
        param_name, file_path))
  value = output.parameters[param_name]
  return getattr(value, value.WhichOneof('value'))
|
||||
|
||||
|
||||
def get_artifact_from_output(
    file_path: str, output_name: str) -> artifact.Artifact:
  """Gets an artifact object from output metadata JSON."""
  executor_output = pipeline_spec_pb2.ExecutorOutput()
  json_format.Parse(
      text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
      message=executor_output
  )
  # Currently we bear the assumption that each output contains only one
  # artifact, so only the first entry of the list is inspected.
  runtime_artifact = executor_output.artifacts[output_name].artifacts[0]
  json_str = json_format.MessageToJson(runtime_artifact, sort_keys=True)

  # Convert runtime_artifact to Python artifact
  return artifact.Artifact.deserialize(json_str)
|
||||
|
||||
|
||||
def import_func_from_source(source_path: str, fn_name: str) -> Callable:
  """Imports a function from a Python file.

  The implementation is borrowed from
  https://github.com/tensorflow/tfx/blob/8f25a4d1cc92dfc8c3a684dfc8b82699513cafb5/tfx/utils/import_utils.py#L50

  Args:
    source_path: The local path to the Python source file.
    fn_name: The function name, which can be found in the source file.

  Return: A Python function object.

  Raises:
    RuntimeError: When the source path points at a remote filesystem.
    ImportError: When failed to load the source file or cannot find the
      function with the given name.
  """
  if any(source_path.startswith(prefix) for prefix in _REMOTE_FS_PREFIX):
    raise RuntimeError('Only local source file can be imported. Please make '
                       'sure the user code is built into executor container. '
                       'Got file path: %s' % source_path)

  # `import importlib` alone does not reliably make the `machinery` and
  # `util` submodules available; import them explicitly so this function
  # does not depend on another module having imported them first.
  import importlib.machinery
  import importlib.util

  try:
    loader = importlib.machinery.SourceFileLoader(
        fullname=_USER_MODULE,
        path=source_path,
    )
    spec = importlib.util.spec_from_loader(
        loader.name, loader, origin=source_path)
    module = importlib.util.module_from_spec(spec)
    # Register the module before executing it so user code that consults
    # sys.modules during import keeps working.
    sys.modules[loader.name] = module
    loader.exec_module(module)
  except IOError as e:
    # The previous message claimed the function was not found, but an IOError
    # here means the source file itself could not be read.
    raise ImportError(
        'Failed to load source file {} in import_func_from_source()'.format(
            source_path)) from e
  try:
    return getattr(module, fn_name)
  except AttributeError:
    raise ImportError('{} in {} not found in import_func_from_source()'.format(
        fn_name, source_path
    ))
|
||||
|
||||
|
||||
def get_output_artifacts(
    fn: Callable, output_uris: Dict[str, str]) -> Dict[str, artifact.Artifact]:
  """Gets the output artifacts from function signature and provided URIs.

  Args:
    fn: A user-provided function, whose signature annotates the type of output
      artifacts.
    output_uris: The mapping from output artifact name to its URI.

  Returns:
    A mapping from output artifact name to Python artifact objects.
  """
  # Inspect the function signature to determine the set of output artifact.
  spec = _python_op._extract_component_interface(fn)

  result = {}  # Mapping from output name to artifacts.
  for output in spec.outputs:
    # Only outputs annotated with OutputArtifact are materialized here;
    # parameter-style outputs are handled elsewhere by the entrypoint.
    if (getattr(output, '_passing_style', None) == _python_op.OutputArtifact):
      # Creates an artifact according to its name
      type_name = getattr(output, 'type', None)
      if not type_name:
        continue

      try:
        # Resolve the concrete artifact class from the KFP ontology module
        # by its type name.
        artifact_cls = getattr(
            importlib.import_module(artifact.KFP_ARTIFACT_ONTOLOGY_MODULE),
            type_name)

      except (AttributeError, ImportError, ValueError):
        # Deliberate best-effort: fall back to the generic Artifact class
        # instead of failing the whole step.
        logging.warning((
            'Could not load artifact class %s.%s; using fallback deserialization'
            ' for the relevant artifact. Please make sure that any artifact '
            'classes can be imported within your container or environment.'),
            artifact.KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
        artifact_cls = artifact.Artifact

      if artifact_cls == artifact.Artifact:
        # Provide an empty schema if instantiating an bare-metal artifact.
        art = artifact_cls(instance_schema=artifact.DEFAULT_ARTIFACT_SCHEMA)
      else:
        art = artifact_cls()

      # NOTE(review): raises KeyError if no URI was provided for a declared
      # output artifact -- confirm the caller always passes a complete map.
      art.uri = output_uris[output.name]
      result[output.name] = art

  return result
|
||||
|
||||
|
||||
def _get_pipeline_value(value: Union[int, float, str]) -> Optional[
    pipeline_spec_pb2.Value]:
  """Converts Python primitive value to pipeline value pb."""
  # None passes through unconverted.
  if value is None:
    return None

  pb_value = pipeline_spec_pb2.Value()
  # int is tested first; note bool is a subclass of int and lands here too.
  if isinstance(value, int):
    pb_value.int_value = value
  elif isinstance(value, float):
    pb_value.double_value = value
  elif isinstance(value, str):
    pb_value.string_value = value
  else:
    raise TypeError('Got unknown type of value: {}'.format(value))

  return pb_value
|
||||
|
||||
|
||||
def get_executor_output(
    output_artifacts: Dict[str, artifact.Artifact],
    output_params: Dict[str, Union[int, float, str]]
) -> pipeline_spec_pb2.ExecutorOutput:
  """Gets the output metadata message."""
  executor_output = pipeline_spec_pb2.ExecutorOutput()

  # Each named output artifact is wrapped in a single-element ArtifactList.
  for name, art in output_artifacts.items():
    artifact_list = pipeline_spec_pb2.ArtifactList(
        artifacts=[art.runtime_artifact])
    executor_output.artifacts[name].CopyFrom(artifact_list)

  # Output parameters are converted to Value protos before being recorded.
  for name, param in output_params.items():
    executor_output.parameters[name].CopyFrom(_get_pipeline_value(param))

  return executor_output
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The entrypoint binary used in KFP component."""
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
class ParseKwargs(argparse.Action):
  """Helper class to parse the keyword arguments.

  This Python binary expects a set of kwargs, whose keys are not predefined.
  """

  def __call__(
      self, parser, namespace, values, option_string=None):
    """Parses alternating key/value tokens into a dict on the namespace.

    Raises:
      ValueError: If some key has no matching value. (An explicit exception
        replaces the previous `assert`, which is silently stripped when
        running under `python -O`.)
    """
    # nargs='*' may yield an empty list; normalize None defensively.
    values = values or []
    if len(values) % 2 != 0:
      raise ValueError('Each specified arg key must have a value.')
    result = {}
    # Pair each even-indexed token (key) with the following token (value).
    tokens = iter(values)
    for key, value in zip(tokens, tokens):
      result[key] = value
    setattr(namespace, self.dest, result)
|
||||
|
||||
|
||||
def main():
  """The main program of KFP container entrypoint.

  This entrypoint should be called as follows:
    python run_container.py -k key1 value1 key2 value2 ...

  The keys are not predefined; downstream processing interprets them, e.g.
  `{input-parameter-name}_metadata_file` or
  `{input-parameter-name}_field_name`.
  """
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('-k', '--kwargs', nargs='*', action=ParseKwargs)
  # Skip argv[0] (the file name) explicitly.
  parsed = arg_parser.parse_args(sys.argv[1:])
  print(parsed.kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for kfp.containers.entrypoint module."""
|
||||
import mock
|
||||
import unittest
|
||||
|
||||
from kfp.containers import entrypoint
|
||||
from kfp.containers import entrypoint_utils
|
||||
# Import testdata to mock entrypoint_utils.import_func_from_source function.
|
||||
from kfp.containers_tests.testdata import main
|
||||
|
||||
_OUTPUT_METADATA_JSON_LOCATION = 'executor_output_metadata.json'
|
||||
|
||||
_PRODUCER_EXECUTOR_OUTPUT = """{
|
||||
"parameters": {
|
||||
"param_output": {
|
||||
"stringValue": "hello from producer"
|
||||
}
|
||||
},
|
||||
"artifacts": {
|
||||
"artifact_output": {
|
||||
"artifacts": [
|
||||
{
|
||||
"type": {
|
||||
"instanceSchema": "properties:\\ntitle: kfp.Dataset\\ntype: object\\n"
|
||||
},
|
||||
"uri": "gs://root/producer/artifact_output"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}"""
|
||||
|
||||
_EXPECTED_EXECUTOR_OUTPUT_1 = """{
|
||||
"parameters": {
|
||||
"test_output2": {
|
||||
"stringValue": "bye world"
|
||||
}
|
||||
},
|
||||
"artifacts": {
|
||||
"test_output1": {
|
||||
"artifacts": [
|
||||
{
|
||||
"type": {
|
||||
"instanceSchema": "properties:\\ntitle: kfp.Model\\ntype: object\\n"
|
||||
},
|
||||
"uri": "gs://root/consumer/output1"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}"""
|
||||
|
||||
|
||||
class EntrypointTest(unittest.TestCase):
  """Tests for the container entrypoint of new-styled KFP Python components."""

  def setUp(self):
    # Prepare mock
    # Replace the user-code importer so each test can inject its own test
    # function instead of loading /ml/main.py from a container image.
    self._import_func = mock.patch.object(
        entrypoint_utils,
        'import_func_from_source').start()
    # Stub out GCS I/O so the tests run without network access.
    self._mock_gcs_read = mock.patch(
        'kfp.containers._gcs_helper.GCSHelper.read_from_gcs_path',
    ).start()
    self._mock_gcs_write = mock.patch(
        'kfp.containers._gcs_helper.GCSHelper.write_to_gcs_path',
    ).start()

    # Undo all of the above patches after each test.
    self.addCleanup(mock.patch.stopall)


  def testMainWithV1Producer(self):
    """Tests the entrypoint with data passing with conventional KFP components.

    This test case emulates the following scenario:
    - User provides a function, namely `test_func`.
    - In test function, there are an input parameter (`test_param`) and an input
      artifact (`test_artifact`). And the user code generates an output
      artifact (`test_output1`) and an output parameter (`test_output2`).
    - The specified metadata JSON file location is at
      'executor_output_metadata.json'
    - The inputs of this step are all provided by conventional KFP components.
    """
    # Set mocked user function.
    self._import_func.return_value = main.test_func

    entrypoint.main(
        executor_metadata_json_file=_OUTPUT_METADATA_JSON_LOCATION,
        function_name='test_func',
        test_param_input_argo_param='hello from producer',
        test_artifact_input_path='gs://root/producer/output',
        test_output1_artifact_output_path='gs://root/consumer/output1',
        test_output2_parameter_output_path='gs://root/consumer/output2'
    )

    # The entrypoint's final write is the executor output metadata JSON;
    # assert_called_with inspects that last call.
    self._mock_gcs_write.assert_called_with(
        path=_OUTPUT_METADATA_JSON_LOCATION,
        content=_EXPECTED_EXECUTOR_OUTPUT_1)

  def testMainWithV2Producer(self):
    """Tests the entrypoint with data passing with new-styled KFP components.

    This test case emulates a similar scenario as testMainWithV1Producer, except
    for that the inputs of this step are all provided by a new-styled KFP
    component.
    """
    # Set mocked user function.
    self._import_func.return_value = main.test_func2
    # Set GFile read function
    # Every metadata read returns the canned producer output defined above.
    self._mock_gcs_read.return_value = _PRODUCER_EXECUTOR_OUTPUT

    entrypoint.main(
        executor_metadata_json_file=_OUTPUT_METADATA_JSON_LOCATION,
        function_name='test_func2',
        test_param_input_param_metadata_file='gs://root/producer/executor_output_metadata.json',
        test_param_input_field_name='param_output',
        test_artifact_input_artifact_metadata_file='gs://root/producer/executor_output_metadata.json',
        test_artifact_input_output_name='artifact_output',
        test_output1_artifact_output_path='gs://root/consumer/output1',
        test_output2_parameter_output_path='gs://root/consumer/output2'
    )

    self._mock_gcs_write.assert_called_with(
        path=_OUTPUT_METADATA_JSON_LOCATION,
        content=_EXPECTED_EXECUTOR_OUTPUT_1)
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for kfp.containers.entrypoint_utils module."""
|
||||
from google.protobuf import json_format
|
||||
import mock
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from kfp.dsl import artifact
|
||||
from kfp import components
|
||||
from kfp.containers import entrypoint_utils
|
||||
from kfp.dsl import ontology_artifacts
|
||||
from kfp.pipeline_spec import pipeline_spec_pb2
|
||||
|
||||
|
||||
def _get_text_from_testdata(filename: str) -> str:
  """Returns the content of `filename` under the local testdata directory."""
  testdata_path = os.path.join(
      os.path.dirname(__file__), 'testdata', filename)
  with open(testdata_path, 'r') as f:
    return f.read()
|
||||
|
||||
|
||||
def _test_function(
    a: components.InputArtifact('Dataset'),
    b: components.OutputArtifact('Model'),
    c: components.OutputArtifact('Artifact')
):
  """Function used to test signature parsing.

  The body is intentionally empty: only the annotated signature matters.
  `a` is an input artifact, while `b` and `c` are output artifacts whose
  URIs are expected to be filled in from _OUTPUT_URIS by
  entrypoint_utils.get_output_artifacts.
  """
  pass
|
||||
|
||||
|
||||
# Output-name -> URI mapping fed to get_output_artifacts; the keys match the
# output parameter names (`b`, `c`) in _test_function's signature.
_OUTPUT_URIS = {
    'b': 'gs://root/execution/b',
    'c': 'gs://root/execution/c'
}
|
||||
|
||||
|
||||
class EntrypointUtilsTest(unittest.TestCase):
  """Tests for kfp.containers.entrypoint_utils."""

  @mock.patch('kfp.containers._gcs_helper.GCSHelper.read_from_gcs_path')
  def testGetParameterFromOutput(self, mock_read):
    """Output parameters of int/string/float types are parsed correctly."""
    mock_read.return_value = _get_text_from_testdata('executor_output.json')

    self.assertEqual(entrypoint_utils.get_parameter_from_output(
        file_path=os.path.join('testdata', 'executor_output.json'),
        param_name='int_output'
    ), 42)
    self.assertEqual(entrypoint_utils.get_parameter_from_output(
        file_path=os.path.join('testdata', 'executor_output.json'),
        param_name='string_output'
    ), 'hello world!')
    self.assertEqual(entrypoint_utils.get_parameter_from_output(
        file_path=os.path.join('testdata', 'executor_output.json'),
        param_name='float_output'
    ), 12.12)

  @mock.patch('kfp.containers._gcs_helper.GCSHelper.read_from_gcs_path')
  def testGetArtifactFromOutput(self, mock_read):
    """An output artifact is deserialized to its concrete ontology type."""
    mock_read.return_value = _get_text_from_testdata('executor_output.json')

    art = entrypoint_utils.get_artifact_from_output(
        file_path=os.path.join('testdata', 'executor_output.json'),
        output_name='output'
    )
    self.assertIsInstance(art, ontology_artifacts.Model)
    self.assertEqual(art.uri, 'gs://root/execution/output')
    self.assertEqual(art.name, 'test-artifact')
    self.assertEqual(art.get_string_custom_property('test_property'),
                     'test value')

  def testGetOutputArtifacts(self):
    """Output artifacts are instantiated per the function's annotations."""
    outputs = entrypoint_utils.get_output_artifacts(
        _test_function, _OUTPUT_URIS)
    self.assertSetEqual(set(outputs.keys()), {'b', 'c'})
    self.assertIsInstance(outputs['b'], ontology_artifacts.Model)
    self.assertIsInstance(outputs['c'], artifact.Artifact)
    self.assertEqual(outputs['b'].uri, 'gs://root/execution/b')
    self.assertEqual(outputs['c'].uri, 'gs://root/execution/c')

  def testGetExecutorOutput(self):
    """The generated ExecutorOutput message matches the golden testdata."""
    model = ontology_artifacts.Model()
    model.name = 'test-artifact'
    model.uri = 'gs://root/execution/output'
    model.set_string_custom_property('test_property', 'test value')

    executor_output = entrypoint_utils.get_executor_output(
        output_artifacts={'output': model},
        output_params={
            'int_output': 42,
            'string_output': 'hello world!',
            'float_output': 12.12
        })

    # Renormalize the JSON proto read from testdata. Otherwise there'll be
    # mismatch in the way treating int value.
    expected_output = pipeline_spec_pb2.ExecutorOutput()
    expected_output = json_format.Parse(
        text=_get_text_from_testdata('executor_output.json'),
        message=expected_output)

    self.assertDictEqual(
        json_format.MessageToDict(expected_output),
        json_format.MessageToDict(executor_output))

  def testImportFuncFromSource(self):
    """Functions are importable from a source path; missing names raise."""
    fn = entrypoint_utils.import_func_from_source(
        source_path=os.path.join(
            os.path.dirname(__file__), 'testdata', 'test_source.py'),
        fn_name='test_func'
    )
    self.assertEqual(fn(1, 2), 3)

    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias;
    # the pattern is a raw string so `\D` is a regex escape, not a (invalid)
    # string escape.
    with self.assertRaisesRegex(ImportError, r'\D+ in \D+ not found in '):
      _ = entrypoint_utils.import_func_from_source(
          source_path=os.path.join('testdata', 'test_source.py'),
          fn_name='non_existing_fn'
      )
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
{
|
||||
"artifacts": {
|
||||
"output": {
|
||||
"artifacts": [
|
||||
{
|
||||
"name": "test-artifact",
|
||||
"uri": "gs://root/execution/output",
|
||||
"type": {
|
||||
"instanceSchema": "properties:\ntitle: kfp.Model\ntype: object\n"
|
||||
},
|
||||
"customProperties": {
|
||||
"test_property": {
|
||||
"stringValue": "test value"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"parameters": {
|
||||
"int_output": {
|
||||
"intValue": 42
|
||||
},
|
||||
"string_output": {
|
||||
"stringValue": "hello world!"
|
||||
},
|
||||
"float_output": {
|
||||
"doubleValue": 12.12
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""User module under test"""
|
||||
from typing import NamedTuple
|
||||
from kfp import components
|
||||
from kfp.dsl import artifact
|
||||
from kfp.dsl import ontology_artifacts
|
||||
|
||||
|
||||
def test_func(
    test_param: str,
    test_artifact: components.InputArtifact('Dataset'),
    test_output1: components.OutputArtifact('Model')
) -> NamedTuple('Outputs', [('test_output2', str)]):
  """User function exercised by the V1-producer entrypoint test."""
  from collections import namedtuple

  assert test_param == 'hello from producer'
  # In the associated test case the input artifact is produced by a
  # conventional KFP component, so no concrete artifact type can be
  # determined -- only the generic Artifact base class is expected.
  assert isinstance(test_artifact, artifact.Artifact)
  assert isinstance(test_output1, ontology_artifacts.Model)
  assert test_output1.uri

  outputs_cls = namedtuple('Outputs', 'test_output2')
  return outputs_cls('bye world')
|
||||
|
||||
|
||||
def test_func2(
    test_param: str,
    test_artifact: components.InputArtifact('Dataset'),
    test_output1: components.OutputArtifact('Model')
) -> NamedTuple('Outputs', [('test_output2', str)]):
  """User function exercised by the V2-producer entrypoint test."""
  from collections import namedtuple

  assert test_param == 'hello from producer'
  # In the associated test case the input artifact is produced by a
  # new-styled KFP component with metadata, so it is expected to be
  # deserialized to a concrete Dataset object.
  assert isinstance(test_artifact, ontology_artifacts.Dataset)
  assert isinstance(test_output1, ontology_artifacts.Model)
  assert test_output1.uri

  outputs_cls = namedtuple('Outputs', 'test_output2')
  return outputs_cls('bye world')
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright 2021 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Python source file under test."""
|
||||
|
||||
def test_func(a, b):
|
||||
return a + b
|
||||
|
|
@ -24,7 +24,8 @@ from kfp.pipeline_spec import pipeline_spec_pb2
|
|||
from kfp.dsl import serialization_utils
|
||||
|
||||
_KFP_ARTIFACT_TITLE_PATTERN = 'kfp.{}'
|
||||
_KFP_ARTIFACT_ONTOLOGY_MODULE = 'kfp.dsl.ontology_artifacts'
|
||||
KFP_ARTIFACT_ONTOLOGY_MODULE = 'kfp.dsl.ontology_artifacts'
|
||||
DEFAULT_ARTIFACT_SCHEMA = 'title: kfp.Artifact\ntype: object\nproperties:\n'
|
||||
|
||||
|
||||
# Enum for property types.
|
||||
|
|
@ -121,8 +122,7 @@ class Artifact(object):
|
|||
if self.__class__ == Artifact:
|
||||
if not instance_schema:
|
||||
raise ValueError(
|
||||
'The "instance_schema" argument must be passed to specify a '
|
||||
'type for this Artifact.')
|
||||
'The "instance_schema" argument must be set.')
|
||||
schema_yaml = yaml.safe_load(instance_schema)
|
||||
if 'properties' not in schema_yaml:
|
||||
raise ValueError('Invalid instance_schema, properties must be present. '
|
||||
|
|
@ -138,7 +138,7 @@ class Artifact(object):
|
|||
else:
|
||||
if instance_schema:
|
||||
raise ValueError(
|
||||
'The "mlmd_artifact_type" argument must not be passed for '
|
||||
'The "instance_schema" argument must not be passed for '
|
||||
'Artifact subclass %s.' % self.__class__)
|
||||
instance_schema = self.get_artifact_type()
|
||||
|
||||
|
|
@ -326,7 +326,7 @@ class Artifact(object):
|
|||
result = None
|
||||
try:
|
||||
artifact_cls = getattr(
|
||||
importlib.import_module(_KFP_ARTIFACT_ONTOLOGY_MODULE), type_name)
|
||||
importlib.import_module(KFP_ARTIFACT_ONTOLOGY_MODULE), type_name)
|
||||
# TODO(numerology): Add deserialization tests for first party classes.
|
||||
result = artifact_cls()
|
||||
except (AttributeError, ImportError, ValueError):
|
||||
|
|
@ -334,7 +334,7 @@ class Artifact(object):
|
|||
'Could not load artifact class %s.%s; using fallback deserialization '
|
||||
'for the relevant artifact. Please make sure that any artifact '
|
||||
'classes can be imported within your container or environment.'),
|
||||
_KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
|
||||
KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
|
||||
if not result:
|
||||
# Otherwise generate a generic Artifact object.
|
||||
result = Artifact(instance_schema=artifact.type.instance_schema)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ REQUIRES = [
|
|||
'strip-hints',
|
||||
'docstring-parser>=0.7.3',
|
||||
'kfp-pipeline-spec>=0.1.0, <0.2.0',
|
||||
'fire>=0.3.1'
|
||||
]
|
||||
|
||||
TESTS_REQUIRE = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue