feat(sdk): Container entrypoint used for new styled KFP component authoring (#4978)

* skeleton

* add entrypoint utils to parse param

* wip: artifact parsing

* add input param artifacts passing and clean unused code

* wip

* add output artifact inspection

* add parameter output

* finish entrypoint implementation

* add entrypoint_utils_test.py

* add entrypoint test

* add entrypoint test

* get rid of tf

* fix test

* fix file location

* fix tests

* fix tests

* resolving comments

* Partially rollback

* resolve comments in entrypoint.py

* resolve comments
Jiaxiao Zheng 2021-01-14 16:01:21 -08:00 committed by GitHub
parent d629397654
commit 279694ec6d
13 changed files with 960 additions and 66 deletions

View File

@ -17,9 +17,11 @@ __all__ = [
'func_to_container_op',
'func_to_component_text',
'default_base_image_or_builder',
'InputArtifact',
'InputPath',
'InputTextFile',
'InputBinaryFile',
'OutputArtifact',
'OutputPath',
'OutputTextFile',
'OutputBinaryFile',
@ -63,6 +65,16 @@ class InputBinaryFile:
self.type = type
class InputArtifact:
"""InputArtifact function parameter annotation.
When creating a component from a function, InputArtifact indicates that the
associated input parameter should be tracked as an MLMD artifact.
"""
def __init__(self, type: Optional[str] = None):
self.type = type
class OutputPath:
'''When creating component from function, :class:`.OutputPath` should be used as function parameter annotation to tell the system that the function wants to output data by writing it into a file with the given path instead of returning the data from the function.'''
def __init__(self, type=None):
@ -81,6 +93,17 @@ class OutputBinaryFile:
self.type = type
class OutputArtifact:
"""OutputArtifact function parameter annotation.
When creating a component from a function, OutputArtifact indicates that the
associated parameter should be treated as an MLMD artifact, whose
underlying content and metadata will be updated by this component.
"""
def __init__(self, type: Optional[str] = None):
self.type = type
def _make_parent_dirs_and_return_path(file_path: str):
import os
os.makedirs(os.path.dirname(file_path), exist_ok=True)
@ -300,7 +323,10 @@ def _extract_component_interface(func: Callable) -> ComponentSpec:
parameter_annotation = parameter.annotation
passing_style = None
io_name = parameter.name
if isinstance(parameter_annotation, (InputPath, InputTextFile, InputBinaryFile, OutputPath, OutputTextFile, OutputBinaryFile)):
if isinstance(
parameter_annotation,
(InputArtifact, InputPath, InputTextFile, InputBinaryFile,
OutputArtifact, OutputPath, OutputTextFile, OutputBinaryFile)):
passing_style = type(parameter_annotation)
parameter_annotation = parameter_annotation.type
if parameter.default is not inspect.Parameter.empty and not (passing_style == InputPath and parameter.default is None):
@ -317,7 +343,9 @@ def _extract_component_interface(func: Callable) -> ComponentSpec:
type_struct = annotation_to_type_struct(parameter_annotation)
#TODO: Humanize the input/output names
if isinstance(parameter.annotation, (OutputPath, OutputTextFile, OutputBinaryFile)):
if isinstance(
parameter.annotation,
(OutputArtifact, OutputPath, OutputTextFile, OutputBinaryFile)):
io_name = _make_name_unique_by_adding_index(io_name, output_names, '_')
output_names.add(io_name)
output_spec = OutputSpec(
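
As a quick illustration of the new annotations, here is a hedged sketch of a function-based component declaring MLMD-tracked inputs and outputs in its signature. The function and parameter names are made up for illustration; the pattern follows the testdata added later in this PR.
from typing import NamedTuple
from kfp import components

def train(
    num_steps: int,
    examples: components.InputArtifact('Dataset'),
    model: components.OutputArtifact('Model'),
) -> NamedTuple('Outputs', [('accuracy', float)]):
    # `examples` and `model` are resolved to artifact objects by the container
    # entrypoint; `model.uri` points at the location where the output is written.
    from collections import namedtuple
    return namedtuple('Outputs', ['accuracy'])(0.9)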

View File

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import PurePath
import os
import pathlib
import tempfile
class GCSHelper(object):
""" GCSHelper manages the connection with the GCS storage """
@ -26,7 +28,7 @@ class GCSHelper(object):
gcs_blob: gcs blob object(https://github.com/googleapis/google-cloud-python/blob/5c9bb42cb3c9250131cfeef6e0bafe8f4b7c139f/storage/google/cloud/storage/blob.py#L105)
"""
from google.cloud import storage
pure_path = PurePath(gcs_path)
pure_path = pathlib.PurePath(gcs_path)
gcs_bucket = pure_path.parts[1]
gcs_blob = '/'.join(pure_path.parts[2:])
client = storage.Client()
@ -44,6 +46,28 @@ class GCSHelper(object):
blob = GCSHelper.get_blob_from_gcs_uri(gcs_path)
blob.upload_from_filename(local_path)
@staticmethod
def write_to_gcs_path(path: str, content: str) -> None:
"""Writes serialized content to a GCS location.
Args:
path: GCS path to write to.
content: The content to be written.
"""
fd, temp_path = tempfile.mkstemp()
try:
with os.fdopen(fd, 'w') as tmp:
tmp.write(content)
if not GCSHelper.get_blob_from_gcs_uri(path):
pure_path = pathlib.PurePath(path)
gcs_bucket = pure_path.parts[1]
GCSHelper.create_gcs_bucket_if_not_exist(gcs_bucket)
GCSHelper.upload_gcs_file(temp_path, path)
finally:
os.remove(temp_path)
@staticmethod
def remove_gcs_blob(gcs_path):
"""
@ -63,6 +87,18 @@ class GCSHelper(object):
blob = GCSHelper.get_blob_from_gcs_uri(gcs_path)
blob.download_to_filename(local_path)
@staticmethod
def read_from_gcs_path(gcs_path: str) -> str:
"""Reads the content of a file hosted on GCS."""
fd, temp_path = tempfile.mkstemp()
try:
GCSHelper.download_gcs_blob(temp_path, gcs_path)
with os.fdopen(fd, 'r') as tmp:
result = tmp.read()
finally:
os.remove(temp_path)
return result
@staticmethod
def create_gcs_bucket_if_not_exist(gcs_bucket):
"""

View File

@ -0,0 +1,324 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from absl import logging
import fire
from google.protobuf import json_format
import os
from kfp.containers import _gcs_helper
from kfp.containers import entrypoint_utils
from kfp.dsl import artifact
_FN_SOURCE = 'ml/main.py'
_FN_NAME_ARG = 'function_name'
_PARAM_METADATA_SUFFIX = '_input_param_metadata_file'
_ARTIFACT_METADATA_SUFFIX = '_input_artifact_metadata_file'
_FIELD_NAME_SUFFIX = '_input_field_name'
_ARGO_PARAM_SUFFIX = '_input_argo_param'
_INPUT_PATH_SUFFIX = '_input_path'
_OUTPUT_NAME_SUFFIX = '_input_output_name'
_OUTPUT_PARAM_PATH_SUFFIX = '_parameter_output_path'
_OUTPUT_ARTIFACT_PATH_SUFFIX = '_artifact_output_path'
_METADATA_FILE_ARG = 'executor_metadata_json_file'
class InputParam(object):
"""POD that holds an input parameter."""
def __init__(self,
value: Optional[Union[str, float, int]] = None,
metadata_file: Optional[str] = None,
field_name: Optional[str] = None):
"""Instantiates an InputParam object.
Args:
value: The actual value of the parameter.
metadata_file: The location of the metadata JSON file output by the
producer step.
field_name: The output name of the producer.
Raises:
ValueError: when neither of the following is true:
1) value is provided, and metadata_file and field_name are not; or
2) both metadata_file and field_name are provided, and value is not.
"""
if not (value is not None and not (metadata_file or field_name) or (
metadata_file and field_name and value is None)):
raise ValueError('Either value or both metadata_file and field_name '
'needs to be provided. Got value={value}, field_name='
'{field_name}, metadata_file={metadata_file}'.format(
value=value,
field_name=field_name,
metadata_file=metadata_file
))
if value is not None:
self._value = value
else:
# Parse the value by inspecting the producer's metadata JSON file.
self._value = entrypoint_utils.get_parameter_from_output(
metadata_file, field_name)
self._metadata_file = metadata_file
self._field_name = field_name
# Following properties are read-only
@property
def value(self) -> Union[float, str, int]:
return self._value
@property
def metadata_file(self) -> str:
return self._metadata_file
@property
def field_name(self) -> str:
return self._field_name
class InputArtifact(object):
"""POD that holds an input artifact."""
def __init__(self,
uri: Optional[str] = None,
metadata_file: Optional[str] = None,
output_name: Optional[str] = None
):
"""Instantiates an InputParam object.
Args:
uri: The uri holds the input artifact.
metadata_file: The location of the metadata JSON file output by the
producer step.
output_name: The output name of the artifact in producer step.
Raises:
ValueError: when neither of the following is true:
1) uri is provided, and metadata_file and output_name are not; or
2) both metadata_file and output_name are provided, and uri is not.
"""
if not ((uri and not (metadata_file or output_name) or (
metadata_file and output_name and not uri))):
raise ValueError('Either uri or both metadata_file and output_name '
'needs to be provided. Got uri={uri}, output_name='
'{output_name}, metadata_file={metadata_file}'.format(
uri=uri,
output_name=output_name,
metadata_file=metadata_file
))
self._metadata_file = metadata_file
self._output_name = output_name
if uri:
self._uri = uri
else:
self._uri = self.get_artifact().uri
# Following properties are read-only.
@property
def uri(self) -> str:
return self._uri
@property
def metadata_file(self) -> str:
return self._metadata_file
@property
def output_name(self) -> str:
return self._output_name
def get_artifact(self) -> artifact.Artifact:
"""Gets an artifact object by parsing metadata or creating one from uri."""
if self.metadata_file and self.output_name:
return entrypoint_utils.get_artifact_from_output(
self.metadata_file, self.output_name)
else:
# Provide an empty schema when returning a raw Artifact.
result = artifact.Artifact(
instance_schema=artifact.DEFAULT_ARTIFACT_SCHEMA)
result.uri = self.uri
return result
def main(**kwargs):
"""Container entrypoint used by KFP Python function based component.
This function has a dynamic signature, which is interpreted according to
the I/O and data-passing contract of KFP Python function components. The
parameters are received from the command line interface.
For each declared parameter input of the user function, three command line
arguments will be recognized:
1. {name of the parameter}_input_param_metadata_file: The metadata JSON file
path output by the producer.
2. {name of the parameter}_input_field_name: The output name of the parameter,
by which the parameter can be found in the producer metadata JSON file.
3. {name of the parameter}_input_argo_param: The actual runtime value of the
input parameter.
When the producer is a new-styled KFP Python component, 1 and 2 will be
populated, and when it's a conventional KFP Python component, 3 will be in
use.
For each declared artifact input of the user function, three command line args
will be recognized:
1. {name of the artifact}_input_path: The actual path, or uri, of the input
artifact.
2. {name of the artifact}_input_artifact_metadata_file: The metadata JSON file
path output by the producer.
3. {name of the artifact}_input_output_name: The output name of the artifact,
by which the artifact can be found in the producer metadata JSON file.
If the producer is a new-styled KFP Python component, 2 and 3 will be used to give
user code access to MLMD (custom) properties associated with this artifact;
if the producer is a conventional KFP Python component, 1 will be used to
construct an Artifact with only the URI populated.
For each declared artifact or parameter output of the user function, a command
line arg, namely, `{name of the artifact|parameter}_(artifact|parameter)_output_path`,
will be passed to specify the location where the output content is written to.
In addition, `executor_metadata_json_file` specifies the location where the
output metadata JSON file will be written.
"""
if _METADATA_FILE_ARG not in kwargs:
raise RuntimeError('Must specify executor_metadata_json_file')
# Group arguments according to suffixes.
input_params_metadata = {}
input_params_field_name = {}
input_params_value = {}
input_artifacts_metadata = {}
input_artifacts_uri = {}
input_artifacts_output_name = {}
output_artifacts_uri = {}
output_params_path = {}
for k, v in kwargs.items():
if k.endswith(_PARAM_METADATA_SUFFIX):
param_name = k[:-len(_PARAM_METADATA_SUFFIX)]
input_params_metadata[param_name] = v
elif k.endswith(_FIELD_NAME_SUFFIX):
param_name = k[:-len(_FIELD_NAME_SUFFIX)]
input_params_field_name[param_name] = v
elif k.endswith(_ARGO_PARAM_SUFFIX):
param_name = k[:-len(_ARGO_PARAM_SUFFIX)]
input_params_value[param_name] = v
elif k.endswith(_ARTIFACT_METADATA_SUFFIX):
artifact_name = k[:-len(_ARTIFACT_METADATA_SUFFIX)]
input_artifacts_metadata[artifact_name] = v
elif k.endswith(_INPUT_PATH_SUFFIX):
artifact_name = k[:-len(_INPUT_PATH_SUFFIX)]
input_artifacts_uri[artifact_name] = v
elif k.endswith(_OUTPUT_NAME_SUFFIX):
artifact_name = k[:-len(_OUTPUT_NAME_SUFFIX)]
input_artifacts_output_name[artifact_name] = v
elif k.endswith(_OUTPUT_PARAM_PATH_SUFFIX):
param_name = k[:-len(_OUTPUT_PARAM_PATH_SUFFIX)]
output_params_path[param_name] = v
elif k.endswith(_OUTPUT_ARTIFACT_PATH_SUFFIX):
artifact_name = k[:-len(_OUTPUT_ARTIFACT_PATH_SUFFIX)]
output_artifacts_uri[artifact_name] = v
elif k not in (_METADATA_FILE_ARG, _FN_NAME_ARG):
logging.warning(
'Got unexpected command line argument: %s=%s Ignoring', k, v)
# Instantiate POD objects.
input_params = {}
for param_name in (
input_params_value.keys() |
input_params_field_name.keys() | input_params_metadata.keys()):
input_param = InputParam(
value=input_params_value.get(param_name),
metadata_file=input_params_metadata.get(param_name),
field_name=input_params_field_name.get(param_name))
input_params[param_name] = input_param
input_artifacts = {}
for artifact_name in (
input_artifacts_uri.keys() |
input_artifacts_metadata.keys() |
input_artifacts_output_name.keys()
):
input_artifact = InputArtifact(
uri=input_artifacts_uri.get(artifact_name),
metadata_file=input_artifacts_metadata.get(artifact_name),
output_name=input_artifacts_output_name.get(artifact_name))
input_artifacts[artifact_name] = input_artifact
# Import and invoke the user-provided function.
# Currently the actual user code is built into container as /ml/main.py
# which is specified in
# kfp.containers._component_builder.build_python_component.
# Also, determine a way to inspect the function signature to decide the type
# of output artifacts.
fn_name = kwargs[_FN_NAME_ARG]
fn = entrypoint_utils.import_func_from_source(_FN_SOURCE, fn_name)
# Get the output artifacts and combine them with the provided URIs.
output_artifacts = entrypoint_utils.get_output_artifacts(
fn, output_artifacts_uri)
invoking_kwargs = {}
for k, v in output_artifacts.items():
invoking_kwargs[k] = v
for k, v in input_params.items():
invoking_kwargs[k] = v.value
for k, v in input_artifacts.items():
invoking_kwargs[k] = v.get_artifact()
# Execute the user function. fn_res is expected to contain output parameters
# only. It's either a namedtuple or a single primitive value.
fn_res = fn(**invoking_kwargs)
if isinstance(fn_res, (int, float, str)) and len(output_params_path) != 1:
raise RuntimeError('For a primitive output, a single output param path is '
'expected. Got %s' % output_params_path)
if isinstance(fn_res, (int, float, str)):
output_name = list(output_params_path.keys())[0]
# Write the output to the provided path.
_gcs_helper.GCSHelper.write_to_gcs_path(
path=output_params_path[output_name],
content=str(fn_res))
else:
# When there are multiple outputs, match each field to its output path.
for idx, output_name in enumerate(fn_res._fields):
path = output_params_path[output_name]
_gcs_helper.GCSHelper.write_to_gcs_path(
path=path,
content=str(fn_res[idx]))
# Write output metadata JSON file.
output_parameters = {}
if isinstance(fn_res, (int, float, str)):
output_parameters['output'] = fn_res
else:
for idx, output_name in enumerate(fn_res._fields):
output_parameters[output_name] = fn_res[idx]
executor_output = entrypoint_utils.get_executor_output(
output_artifacts=output_artifacts,
output_params=output_parameters)
_gcs_helper.GCSHelper.write_to_gcs_path(
path=kwargs[_METADATA_FILE_ARG],
content=json_format.MessageToJson(executor_output))
if __name__ == '__main__':
fire.Fire(main)
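
Mirroring testMainWithV1Producer in the test file further below, a direct Python call to the entrypoint with the suffixed keyword arguments looks like this sketch. Paths and names are placeholders; in a real run, fire.Fire invokes main() inside the component container, where /ml/main.py and GCS access are available.
from kfp.containers import entrypoint

entrypoint.main(
    executor_metadata_json_file='gs://root/consumer/executor_output_metadata.json',
    function_name='test_func',
    # Input parameter from a conventional (v1) producer: the raw value is passed.
    test_param_input_argo_param='hello from producer',
    # Input artifact from a conventional producer: only the URI is known.
    test_artifact_input_path='gs://root/producer/output',
    # Output locations for this step's artifact and parameter.
    test_output1_artifact_output_path='gs://root/consumer/output1',
    test_output2_parameter_output_path='gs://root/consumer/output2',
)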

View File

@ -0,0 +1,183 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import logging
import importlib
import sys
from typing import Callable, Dict, Optional, Union
from google.protobuf import json_format
from kfp.components import _python_op
from kfp.containers import _gcs_helper
from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.dsl import artifact
# If path starts with one of those, consider files are in remote filesystem.
_REMOTE_FS_PREFIX = ['gs://', 'hdfs://', 's3://']
# Constant user module name when importing the function from a Python file.
_USER_MODULE = 'user_module'
def get_parameter_from_output(file_path: str, param_name: str):
"""Gets a parameter value by its name from output metadata JSON."""
output = pipeline_spec_pb2.ExecutorOutput()
json_format.Parse(
text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
message=output)
value = output.parameters[param_name]
return getattr(value, value.WhichOneof('value'))
def get_artifact_from_output(
file_path: str, output_name: str) -> artifact.Artifact:
"""Gets an artifact object from output metadata JSON."""
output = pipeline_spec_pb2.ExecutorOutput()
json_format.Parse(
text=_gcs_helper.GCSHelper.read_from_gcs_path(file_path),
message=output
)
# Currently we assume that each output contains only one artifact.
json_str = json_format.MessageToJson(
output.artifacts[output_name].artifacts[0], sort_keys=True)
# Convert runtime_artifact to Python artifact
return artifact.Artifact.deserialize(json_str)
def import_func_from_source(source_path: str, fn_name: str) -> Callable:
"""Imports a function from a Python file.
The implementation is borrowed from
https://github.com/tensorflow/tfx/blob/8f25a4d1cc92dfc8c3a684dfc8b82699513cafb5/tfx/utils/import_utils.py#L50
Args:
source_path: The local path to the Python source file.
fn_name: The function name, which can be found in the source file.
Returns: A Python function object.
Raises:
ImportError: when the source file fails to load or the function with the
given name cannot be found.
"""
if any([source_path.startswith(prefix) for prefix in _REMOTE_FS_PREFIX]):
raise RuntimeError('Only local source file can be imported. Please make '
'sure the user code is built into executor container. '
'Got file path: %s' % source_path)
try:
loader = importlib.machinery.SourceFileLoader(
fullname=_USER_MODULE,
path=source_path,
)
spec = importlib.util.spec_from_loader(
loader.name, loader, origin=source_path)
module = importlib.util.module_from_spec(spec)
sys.modules[loader.name] = module
loader.exec_module(module)
except IOError:
raise ImportError('{} in {} not found in import_func_from_source()'.format(
fn_name, source_path
))
try:
return getattr(module, fn_name)
except AttributeError:
raise ImportError('{} in {} not found in import_func_from_source()'.format(
fn_name, source_path
))
def get_output_artifacts(
fn: Callable, output_uris: Dict[str, str]) -> Dict[str, artifact.Artifact]:
"""Gets the output artifacts from function signature and provided URIs.
Args:
fn: A user-provided function, whose signature annotates the type of output
artifacts.
output_uris: The mapping from output artifact name to its URI.
Returns:
A mapping from output artifact name to Python artifact objects.
"""
# Inspect the function signature to determine the set of output artifacts.
spec = _python_op._extract_component_interface(fn)
result = {} # Mapping from output name to artifacts.
for output in spec.outputs:
if (getattr(output, '_passing_style', None) == _python_op.OutputArtifact):
# Create an artifact instance according to its type name.
type_name = getattr(output, 'type', None)
if not type_name:
continue
try:
artifact_cls = getattr(
importlib.import_module(artifact.KFP_ARTIFACT_ONTOLOGY_MODULE),
type_name)
except (AttributeError, ImportError, ValueError):
logging.warning((
'Could not load artifact class %s.%s; using fallback deserialization'
' for the relevant artifact. Please make sure that any artifact '
'classes can be imported within your container or environment.'),
artifact.KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
artifact_cls = artifact.Artifact
if artifact_cls == artifact.Artifact:
# Provide an empty schema if instantiating a bare-metal artifact.
art = artifact_cls(instance_schema=artifact.DEFAULT_ARTIFACT_SCHEMA)
else:
art = artifact_cls()
art.uri = output_uris[output.name]
result[output.name] = art
return result
def _get_pipeline_value(value: Union[int, float, str]) -> Optional[
pipeline_spec_pb2.Value]:
"""Converts Python primitive value to pipeline value pb."""
if value is None:
return None
result = pipeline_spec_pb2.Value()
if isinstance(value, int):
result.int_value = value
elif isinstance(value, float):
result.double_value = value
elif isinstance(value, str):
result.string_value = value
else:
raise TypeError('Got unknown type of value: {}'.format(value))
return result
def get_executor_output(
output_artifacts: Dict[str, artifact.Artifact],
output_params: Dict[str, Union[int, float, str]]
) -> pipeline_spec_pb2.ExecutorOutput:
"""Gets the output metadata message."""
result = pipeline_spec_pb2.ExecutorOutput()
for name, art in output_artifacts.items():
result.artifacts[name].CopyFrom(pipeline_spec_pb2.ArtifactList(
artifacts=[art.runtime_artifact]
))
for name, param in output_params.items():
result.parameters[name].CopyFrom(_get_pipeline_value(param))
return result
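
For example, the helpers can assemble the output metadata proto that the entrypoint writes back. This is a sketch following testGetExecutorOutput further below; the URI and output names are placeholders.
from google.protobuf import json_format
from kfp.containers import entrypoint_utils
from kfp.dsl import ontology_artifacts

# Build an output artifact and combine it with output parameters.
model = ontology_artifacts.Model()
model.uri = 'gs://root/execution/model'
executor_output = entrypoint_utils.get_executor_output(
    output_artifacts={'model': model},
    output_params={'accuracy': 0.9})
print(json_format.MessageToJson(executor_output))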

View File

@ -1,56 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The entrypoint binary used in KFP component."""
import argparse
import sys
class ParseKwargs(argparse.Action):
"""Helper class to parse the keyword arguments.
This Python binary expects a set of kwargs, whose keys are not predefined.
"""
def __call__(
self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, dict())
assert len(values) % 2 == 0, 'Each specified arg key must have a value.'
current_key = None
for idx, value in enumerate(values):
if idx % 2 == 0:
# Parse this into a key.
current_key = value
else:
# Parse current value with the previous key.
getattr(namespace, self.dest)[current_key] = value
def main():
"""The main program of KFP container entrypoint.
This entrypoint should be called as follows:
python run_container.py -k key1 value1 key2 value2 ...
The recognized argument keys are as follows:
- {input-parameter-name}_metadata_file
- {input_parameter-name}_field_name
- {}
"""
parser = argparse.ArgumentParser()
parser.add_argument('-k', '--kwargs', nargs='*', action=ParseKwargs)
args = parser.parse_args(sys.argv[1:]) # Skip the file name.
print(args.kwargs)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,136 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kfp.containers.entrypoint module."""
import mock
import unittest
from kfp.containers import entrypoint
from kfp.containers import entrypoint_utils
# Import testdata to mock entrypoint_utils.import_func_from_source function.
from kfp.containers_tests.testdata import main
_OUTPUT_METADATA_JSON_LOCATION = 'executor_output_metadata.json'
_PRODUCER_EXECUTOR_OUTPUT = """{
"parameters": {
"param_output": {
"stringValue": "hello from producer"
}
},
"artifacts": {
"artifact_output": {
"artifacts": [
{
"type": {
"instanceSchema": "properties:\\ntitle: kfp.Dataset\\ntype: object\\n"
},
"uri": "gs://root/producer/artifact_output"
}
]
}
}
}"""
_EXPECTED_EXECUTOR_OUTPUT_1 = """{
"parameters": {
"test_output2": {
"stringValue": "bye world"
}
},
"artifacts": {
"test_output1": {
"artifacts": [
{
"type": {
"instanceSchema": "properties:\\ntitle: kfp.Model\\ntype: object\\n"
},
"uri": "gs://root/consumer/output1"
}
]
}
}
}"""
class EntrypointTest(unittest.TestCase):
def setUp(self):
# Prepare mock
self._import_func = mock.patch.object(
entrypoint_utils,
'import_func_from_source').start()
self._mock_gcs_read = mock.patch(
'kfp.containers._gcs_helper.GCSHelper.read_from_gcs_path',
).start()
self._mock_gcs_write = mock.patch(
'kfp.containers._gcs_helper.GCSHelper.write_to_gcs_path',
).start()
self.addCleanup(mock.patch.stopall)
def testMainWithV1Producer(self):
"""Tests the entrypoint with data passing with conventional KFP components.
This test case emulates the following scenario:
- User provides a function, namely `test_func`.
- The test function takes an input parameter (`test_param`) and an input
artifact (`test_artifact`), and the user code produces an output
artifact (`test_output1`) and an output parameter (`test_output2`).
- The specified metadata JSON file location is at
'executor_output_metadata.json'
- The inputs of this step are all provided by conventional KFP components.
"""
# Set mocked user function.
self._import_func.return_value = main.test_func
entrypoint.main(
executor_metadata_json_file=_OUTPUT_METADATA_JSON_LOCATION,
function_name='test_func',
test_param_input_argo_param='hello from producer',
test_artifact_input_path='gs://root/producer/output',
test_output1_artifact_output_path='gs://root/consumer/output1',
test_output2_parameter_output_path='gs://root/consumer/output2'
)
self._mock_gcs_write.assert_called_with(
path=_OUTPUT_METADATA_JSON_LOCATION,
content=_EXPECTED_EXECUTOR_OUTPUT_1)
def testMainWithV2Producer(self):
"""Tests the entrypoint with data passing with new-styled KFP components.
This test case emulates a similar scenario as testMainWithV1Producer, except
for that the inputs of this step are all provided by a new-styled KFP
component.
"""
# Set mocked user function.
self._import_func.return_value = main.test_func2
# Set the return value of the mocked GCS read function.
self._mock_gcs_read.return_value = _PRODUCER_EXECUTOR_OUTPUT
entrypoint.main(
executor_metadata_json_file=_OUTPUT_METADATA_JSON_LOCATION,
function_name='test_func2',
test_param_input_param_metadata_file='gs://root/producer/executor_output_metadata.json',
test_param_input_field_name='param_output',
test_artifact_input_artifact_metadata_file='gs://root/producer/executor_output_metadata.json',
test_artifact_input_output_name='artifact_output',
test_output1_artifact_output_path='gs://root/consumer/output1',
test_output2_parameter_output_path='gs://root/consumer/output2'
)
self._mock_gcs_write.assert_called_with(
path=_OUTPUT_METADATA_JSON_LOCATION,
content=_EXPECTED_EXECUTOR_OUTPUT_1)

View File

@ -0,0 +1,128 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kfp.containers.entrypoint_utils module."""
from google.protobuf import json_format
import mock
import os
import unittest
from kfp.dsl import artifact
from kfp import components
from kfp.containers import entrypoint_utils
from kfp.dsl import ontology_artifacts
from kfp.pipeline_spec import pipeline_spec_pb2
def _get_text_from_testdata(filename: str) -> str:
"""Reads the content of a file under testdata."""
with open(
os.path.join(os.path.dirname(__file__), 'testdata', filename), 'r') as f:
return f.read()
def _test_function(
a: components.InputArtifact('Dataset'),
b: components.OutputArtifact('Model'),
c: components.OutputArtifact('Artifact')
):
"""Function used to test signature parsing."""
pass
_OUTPUT_URIS = {
'b': 'gs://root/execution/b',
'c': 'gs://root/execution/c'
}
class EntrypointUtilsTest(unittest.TestCase):
@mock.patch('kfp.containers._gcs_helper.GCSHelper.read_from_gcs_path')
def testGetParameterFromOutput(self, mock_read):
mock_read.return_value = _get_text_from_testdata('executor_output.json')
self.assertEqual(entrypoint_utils.get_parameter_from_output(
file_path=os.path.join('testdata', 'executor_output.json'),
param_name='int_output'
), 42)
self.assertEqual(entrypoint_utils.get_parameter_from_output(
file_path=os.path.join('testdata', 'executor_output.json'),
param_name='string_output'
), 'hello world!')
self.assertEqual(entrypoint_utils.get_parameter_from_output(
file_path=os.path.join('testdata', 'executor_output.json'),
param_name='float_output'
), 12.12)
@mock.patch('kfp.containers._gcs_helper.GCSHelper.read_from_gcs_path')
def testGetArtifactFromOutput(self, mock_read):
mock_read.return_value = _get_text_from_testdata('executor_output.json')
art = entrypoint_utils.get_artifact_from_output(
file_path=os.path.join('testdata', 'executor_output.json'),
output_name='output'
)
self.assertIsInstance(art, ontology_artifacts.Model)
self.assertEqual(art.uri, 'gs://root/execution/output')
self.assertEqual(art.name, 'test-artifact')
self.assertEqual(art.get_string_custom_property('test_property'),
'test value')
def testGetOutputArtifacts(self):
outputs = entrypoint_utils.get_output_artifacts(
_test_function, _OUTPUT_URIS)
self.assertSetEqual(set(outputs.keys()), {'b', 'c'})
self.assertIsInstance(outputs['b'], ontology_artifacts.Model)
self.assertIsInstance(outputs['c'], artifact.Artifact)
self.assertEqual(outputs['b'].uri, 'gs://root/execution/b')
self.assertEqual(outputs['c'].uri, 'gs://root/execution/c')
def testGetExecutorOutput(self):
model = ontology_artifacts.Model()
model.name = 'test-artifact'
model.uri = 'gs://root/execution/output'
model.set_string_custom_property('test_property', 'test value')
executor_output = entrypoint_utils.get_executor_output(
output_artifacts={'output': model},
output_params={
'int_output': 42,
'string_output': 'hello world!',
'float_output': 12.12
})
# Renormalize the JSON proto read from testdata. Otherwise there would be
# a mismatch in how int values are treated.
expected_output = pipeline_spec_pb2.ExecutorOutput()
expected_output = json_format.Parse(
text=_get_text_from_testdata('executor_output.json'),
message=expected_output)
self.assertDictEqual(
json_format.MessageToDict(expected_output),
json_format.MessageToDict(executor_output))
def testImportFuncFromSource(self):
fn = entrypoint_utils.import_func_from_source(
source_path=os.path.join(
os.path.dirname(__file__), 'testdata', 'test_source.py'),
fn_name='test_func'
)
self.assertEqual(fn(1, 2), 3)
with self.assertRaisesRegexp(ImportError, '\D+ in \D+ not found in '):
_ = entrypoint_utils.import_func_from_source(
source_path=os.path.join('testdata', 'test_source.py'),
fn_name='non_existing_fn'
)

View File

@ -0,0 +1,13 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,31 @@
{
"artifacts": {
"output": {
"artifacts": [
{
"name": "test-artifact",
"uri": "gs://root/execution/output",
"type": {
"instanceSchema": "properties:\ntitle: kfp.Model\ntype: object\n"
},
"customProperties": {
"test_property": {
"stringValue": "test value"
}
}
}
]
}
},
"parameters": {
"int_output": {
"intValue": 42
},
"string_output": {
"stringValue": "hello world!"
},
"float_output": {
"doubleValue": 12.12
}
}
}

View File

@ -0,0 +1,53 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""User module under test"""
from typing import NamedTuple
from kfp import components
from kfp.dsl import artifact
from kfp.dsl import ontology_artifacts
def test_func(
test_param: str,
test_artifact: components.InputArtifact('Dataset'),
test_output1: components.OutputArtifact('Model')
) -> NamedTuple('Outputs', [('test_output2', str)]):
assert test_param == 'hello from producer'
# In the associated test case, the input artifact is produced by conventional
# KFP components, so no concrete artifact type can be determined.
assert isinstance(test_artifact, artifact.Artifact)
assert isinstance(test_output1, ontology_artifacts.Model)
assert test_output1.uri
from collections import namedtuple
Outputs = namedtuple('Outputs', 'test_output2')
return Outputs('bye world')
def test_func2(
test_param: str,
test_artifact: components.InputArtifact('Dataset'),
test_output1: components.OutputArtifact('Model')
) -> NamedTuple('Outputs', [('test_output2', str)]):
assert test_param == 'hello from producer'
# In the associated test case, the input artifact is produced by a new-styled
# KFP component with metadata, so it's expected to be deserialized into a
# Dataset object.
assert isinstance(test_artifact, ontology_artifacts.Dataset)
assert isinstance(test_output1, ontology_artifacts.Model)
assert test_output1.uri
from collections import namedtuple
Outputs = namedtuple('Outputs', 'test_output2')
return Outputs('bye world')

View File

@ -0,0 +1,17 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python source file under test."""
def test_func(a, b):
return a + b

View File

@ -24,7 +24,8 @@ from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.dsl import serialization_utils
_KFP_ARTIFACT_TITLE_PATTERN = 'kfp.{}'
_KFP_ARTIFACT_ONTOLOGY_MODULE = 'kfp.dsl.ontology_artifacts'
KFP_ARTIFACT_ONTOLOGY_MODULE = 'kfp.dsl.ontology_artifacts'
DEFAULT_ARTIFACT_SCHEMA = 'title: kfp.Artifact\ntype: object\nproperties:\n'
# Enum for property types.
@ -121,8 +122,7 @@ class Artifact(object):
if self.__class__ == Artifact:
if not instance_schema:
raise ValueError(
'The "instance_schema" argument must be passed to specify a '
'type for this Artifact.')
'The "instance_schema" argument must be set.')
schema_yaml = yaml.safe_load(instance_schema)
if 'properties' not in schema_yaml:
raise ValueError('Invalid instance_schema, properties must be present. '
@ -138,7 +138,7 @@ class Artifact(object):
else:
if instance_schema:
raise ValueError(
'The "mlmd_artifact_type" argument must not be passed for '
'The "instance_schema" argument must not be passed for '
'Artifact subclass %s.' % self.__class__)
instance_schema = self.get_artifact_type()
@ -326,7 +326,7 @@ class Artifact(object):
result = None
try:
artifact_cls = getattr(
importlib.import_module(_KFP_ARTIFACT_ONTOLOGY_MODULE), type_name)
importlib.import_module(KFP_ARTIFACT_ONTOLOGY_MODULE), type_name)
# TODO(numerology): Add deserialization tests for first party classes.
result = artifact_cls()
except (AttributeError, ImportError, ValueError):
@ -334,7 +334,7 @@ class Artifact(object):
'Could not load artifact class %s.%s; using fallback deserialization '
'for the relevant artifact. Please make sure that any artifact '
'classes can be imported within your container or environment.'),
_KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
KFP_ARTIFACT_ONTOLOGY_MODULE, type_name)
if not result:
# Otherwise generate a generic Artifact object.
result = Artifact(instance_schema=artifact.type.instance_schema)
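
With the ontology module constant now public, callers such as entrypoint_utils.get_output_artifacts can resolve ontology classes by type name. A hedged sketch of that lookup, falling back to the generic Artifact as the deserialization code above does (the type name is a placeholder):
import importlib
from kfp.dsl import artifact

type_name = 'Model'  # e.g. parsed from a component's output annotation
try:
    artifact_cls = getattr(
        importlib.import_module(artifact.KFP_ARTIFACT_ONTOLOGY_MODULE), type_name)
except (AttributeError, ImportError, ValueError):
    # Fall back to the generic Artifact when the ontology class cannot be loaded.
    artifact_cls = artifact.Artifact
print(artifact_cls)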

View File

@ -39,6 +39,7 @@ REQUIRES = [
'strip-hints',
'docstring-parser>=0.7.3',
'kfp-pipeline-spec>=0.1.0, <0.2.0',
'fire>=0.3.1'
]
TESTS_REQUIRE = [