332 lines
11 KiB
Python
332 lines
11 KiB
Python
# Copyright 2023 The Kubeflow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Private utilities for component authoring."""
|
|
|
|
import copy
|
|
import json
|
|
import re
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from google_cloud_pipeline_components import _image
|
|
from kfp import components
|
|
from kfp import dsl
|
|
# do not follow this pattern!
|
|
# we should not depend on non-public modules of the KFP SDK!
|
|
from kfp.components import placeholders
|
|
|
|
from google.protobuf import json_format
|
|
|
|
# note: this is a slight dependency on KFP SDK implementation details
|
|
# other code should not similarly depend on the stability of kfp.placeholders
|
|
# Prefix that `gcpc_output_name_converter` prepends to `new_name` to derive the
# original output name when the caller does not pass `original_name`.
DOCS_INTEGRATED_OUTPUT_RENAMING_PREFIX = "output__"
|
|
|
|
|
|
def build_serverless_customjob_container_spec(
    *,
    project: str,
    location: str,
    custom_job_payload: Dict[str, Any],
    gcp_resources: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
) -> dsl.ContainerSpec:
    """Creates the container spec of a component that launches a custom job.

    Args:
      project: Project in which the job runs.
      location: Location in which the job runs.
      custom_job_payload: Dictionary describing the custom job. It is serialized
        with `container_component_dumps` and handed to the launcher as the
        `--payload` argument.
      gcp_resources: Output path handed to the launcher as `--gcp_resources`;
        GCP resources that can be used to track the job.

    Returns:
      A container spec that runs the custom job launcher module with the given
      payload, project, location and gcp_resources arguments.
    """
    launcher_image = _image.GCPC_IMAGE_TAG
    launcher_command = [
        "python3",
        "-u",
        "-m",
        "google_cloud_pipeline_components.container.v1.custom_job.launcher",
    ]

    # Each launcher flag paired with its value, in the order they are passed.
    flag_value_pairs = (
        ("--type", "CustomJob"),
        ("--payload", container_component_dumps(custom_job_payload)),
        ("--project", project),
        ("--location", location),
        ("--gcp_resources", gcp_resources),
    )
    launcher_args = [token for pair in flag_value_pairs for token in pair]

    return dsl.ContainerSpec(
        image=launcher_image,
        command=launcher_command,
        args=launcher_args,
    )
|
|
|
|
|
|
def container_component_dumps(obj: Any) -> str:
    """Dumps an object to a JSON string, keeping KFP SDK placeholders usable.

    Placeholders found in `obj` are serialized with `str()`. For every
    `InputValuePlaceholder` whose IR type is not "STRING", the JSON quotes that
    surround the serialized placeholder are removed, so that the value
    substituted at runtime is not wrapped in a JSON string.

    Limitations:
      - Cannot handle placeholders as dictionary keys.

    Example usage:

      @dsl.container_component
      def comp(val: str, other_val: int):
        return dsl.ContainerSpec(
          image='alpine',
          command=['echo'],
          args=[utils.container_component_dumps({'key': val, 'other_key':
          other_val})],
        )

    Args:
      obj: JSON serializable object, which may contain KFP SDK placeholder
        objects.

    Returns:
      JSON string, possibly with placeholder strings.

    Raises:
      TypeError: If `obj` contains a value that is neither JSON serializable
        nor a KFP SDK placeholder.
    """

    def collect_non_string_input_names(node: Any, names: List[str]) -> None:
        """Appends to `names` the input name of every non-string input placeholder."""
        # json.dumps serializes tuples as arrays too, so they must be searched
        # exactly like lists.
        if isinstance(node, (list, tuple)):
            for element in node:
                collect_non_string_input_names(element, names)
        elif isinstance(node, dict):
            # Only values are visited: InputValuePlaceholder keys will be caught
            # at dict construction as
            # `TypeError: unhashable type: InputValuePlaceholder`.
            for value in node.values():
                collect_non_string_input_names(value, names)
        elif (
            isinstance(node, placeholders.InputValuePlaceholder)
            and node._ir_type != "STRING"  # pylint: disable=protected-access
        ):
            names.append(node.input_name)

    def encode_placeholder(value: Any) -> str:
        """`json.dumps` fallback that serializes placeholders as their string form."""
        if isinstance(value, placeholders.Placeholder):
            return str(value)
        raise TypeError(
            f"Object of type {value.__class__.__name__!r} is not JSON serializable."
        )

    def unquote_non_string_placeholders(
        json_string: str, input_names: List[str]
    ) -> str:
        """Strips the JSON quotes around the placeholders of the given inputs."""
        # dict.fromkeys drops duplicate names while keeping their order.
        for name in dict.fromkeys(input_names):
            # The name is escaped so that it is always matched literally.
            pattern = (
                r"\"\{\{\$\.inputs\.parameters\[(?:'|\")"
                + re.escape(name)
                + r"(?:'|\")]\}\}\""
            )
            unquoted = f"{{{{$.inputs.parameters['{name}']}}}}"
            # A callable replacement is inserted verbatim; a plain string would
            # be treated as a template in which backslashes are interpreted.
            json_string = re.sub(
                pattern, lambda _match, text=unquoted: text, json_string
            )
        return json_string

    non_string_input_names: List[str] = []
    collect_non_string_input_names(obj, non_string_input_names)
    json_string = json.dumps(obj, default=encode_placeholder)
    return unquote_non_string_placeholders(json_string, non_string_input_names)
|
|
|
|
|
|
def _not_primitive_component_error(kind: str) -> ValueError:
    """Builds the error raised when the decorator is applied to an unsupported `kind`."""
    return ValueError(
        f"The {gcpc_output_name_converter.__name__!r} decorator can only be"
        " used on primitive container components. You are trying to use it"
        f" on a {kind}."
    )


def _rename_proto_map_entry(proto_map, original_name: str, new_name: str) -> None:
    """Moves the entry keyed `original_name` to `new_name` in a protobuf message map."""
    if original_name not in proto_map:
        return
    # Pop before inserting, so that renaming an output to its current name
    # keeps the entry instead of dropping it.
    moved_entry = proto_map.pop(original_name)
    proto_map[new_name].CopyFrom(moved_entry)


def _rename_output_in_dag_outputs(
    component_spec, original_name: str, new_name: str
) -> None:
    """Renames a DAG output and points its selector(s) at the renamed task output."""
    dag_outputs = component_spec.dag.outputs

    if original_name in dag_outputs.parameters:
        parameter_spec = dag_outputs.parameters.pop(original_name)
        parameter_spec.value_from_parameter.output_parameter_key = new_name
        dag_outputs.parameters[new_name].CopyFrom(parameter_spec)

    # Artifacts are renamed in the component interface and in the executor
    # placeholders as well, so the DAG outputs have to follow to stay consistent.
    if original_name in dag_outputs.artifacts:
        artifact_spec = dag_outputs.artifacts.pop(original_name)
        for selector in artifact_spec.artifact_selectors:
            selector.output_artifact_key = new_name
        dag_outputs.artifacts[new_name].CopyFrom(artifact_spec)


def _rename_output_placeholders(strings, original_name: str, new_name: str) -> None:
    """Rewrites, in place, output placeholders that reference `original_name`."""
    # One pattern covers both parameter and artifact placeholders. The name is
    # escaped so that it is always matched literally.
    placeholder_pattern = re.compile(
        r"\{\{\$\.outputs\.(parameters|artifacts)\[(?:''|'|\")"
        + re.escape(original_name)
        + r"(?:''|'|\")]"
    )

    def renamed_placeholder(match) -> str:
        # Built by a callable so that `new_name` is inserted verbatim rather
        # than being interpreted as a regex replacement template.
        return f"{{{{$.outputs.{match.group(1)}['{new_name}']"

    for index, value in enumerate(strings):
        strings[index] = placeholder_pattern.sub(renamed_placeholder, value)


def _rename_output_in_pipeline_spec(pipeline_spec, original_name: str, new_name: str):
    """Renames an output everywhere in the pipeline spec of a primitive container component.

    The spec is modified in place and returned.

    Raises:
      ValueError: If the spec does not describe a primitive container component.
    """
    root_component_spec = pipeline_spec.root

    # A primitive component compiles to exactly one inner component, named after
    # the pipeline, whose interface is identical to the root's.
    component_keys = list(pipeline_spec.components.keys())
    if len(component_keys) != 1:
        raise _not_primitive_component_error("pipeline")
    component_spec_key = component_keys[0]
    inner_component_spec = pipeline_spec.components[component_spec_key]
    is_primitive_component = (
        "comp-" + pipeline_spec.pipeline_info.name == component_spec_key
        and root_component_spec.input_definitions
        == inner_component_spec.input_definitions
        and root_component_spec.output_definitions
        == inner_component_spec.output_definitions
    )
    if not is_primitive_component:
        raise _not_primitive_component_error("pipeline")

    executors = pipeline_spec.deployment_spec["executors"]
    executor_key, _ = dict(executors).popitem()
    container_spec = executors[executor_key]["container"]
    command = container_spec.get_or_create_list("command")
    args = container_spec.get_or_create_list("args")
    if "--executor_input" in args and "--function_to_execute" in args:
        raise _not_primitive_component_error("Python component")

    for component_spec in (root_component_spec, inner_component_spec):
        output_definitions = component_spec.output_definitions
        _rename_proto_map_entry(
            output_definitions.parameters, original_name, new_name
        )
        _rename_proto_map_entry(
            output_definitions.artifacts, original_name, new_name
        )

    _rename_output_in_dag_outputs(root_component_spec, original_name, new_name)
    _rename_output_placeholders(command, original_name, new_name)
    _rename_output_placeholders(args, original_name, new_name)

    return pipeline_spec


def gcpc_output_name_converter(
    new_name: str,
    original_name: Optional[str] = None,
) -> Callable[["BaseComponent"], "BaseComponent"]:  # pytype: disable=name-error
    """Renames the output `original_name` to `new_name` in a component decorated with @dsl.container_component.

    Enables authoring components that have an input and output with the same
    key/name.

    Args:
      new_name: The new name for the output.
      original_name: The original name of the output. Defaults to `new_name`
        prefixed with `DOCS_INTEGRATED_OUTPUT_RENAMING_PREFIX`.

    Returns:
      A decorator that takes a component and returns a new component, reloaded
      from the component's pipeline spec with the output renamed. The docstring
      and annotations of the original component function are carried over.

    Raises:
      ValueError: Raised by the returned decorator when it is applied to
        something other than a primitive container component (for example a
        pipeline or a Python component).

    Example usage:

      @utils.gcpc_output_name_converter('param', 'output__param')
      @dsl.container_component
      def my_component(
        param: str,
        output__param: dsl.OutputPath(str),
      ):
        '''Has an input `param` and creates an output `param`'''
        return dsl.ContainerSpec(
          image='alpine',
          command=['echo'],
          args=[output__param],
        )
    """
    if original_name is None:
        original_name = DOCS_INTEGRATED_OUTPUT_RENAMING_PREFIX + new_name

    def converter(comp):
        modified_pipeline_spec = _rename_output_in_pipeline_spec(
            comp.pipeline_spec,
            original_name,
            new_name,
        )
        reloaded_component = components.load_component_from_text(
            json_format.MessageToJson(modified_pipeline_spec)
        )
        reloaded_component.__doc__ = comp.pipeline_func.__doc__
        reloaded_component.__annotations__ = comp.pipeline_func.__annotations__
        return reloaded_component

    return converter
|