107 lines
3.8 KiB
Python
107 lines
3.8 KiB
Python
# Copyright 2020 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""KFP v2 DSL compiler utility functions."""
|
|
|
|
import re
|
|
from typing import Any, Mapping, Optional, Union
|
|
|
|
from kfp.containers import _component_builder
|
|
from kfp.dsl import _container_op
|
|
from kfp.pipeline_spec import pipeline_spec_pb2
|
|
|
|
# Short name for the nested PipelineDeploymentConfig.PipelineContainerSpec
# protobuf message type.
PipelineContainerSpec = (
    pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec)
|
|
|
|
|
|
def build_runtime_config_spec(
    output_directory: str,
    pipeline_parameters: Optional[Mapping[str, Any]] = None,
) -> pipeline_spec_pb2.PipelineJob.RuntimeConfig:
  """Converts pipeline parameters to a runtime config.

  Parameters whose value is None are dropped and do not appear in the
  resulting `parameters` mapping.

  Args:
    output_directory: The root of pipeline outputs. Stored as the runtime
      config's `gcs_output_directory`.
    pipeline_parameters: The mapping from parameter names to values. Optional.
      Supported value types are int, float and str.

  Returns:
    A pipeline job RuntimeConfig object.

  Raises:
    TypeError: If a non-None parameter value is not an int, float or str.
  """

  def _get_value(value: Union[int, float, str]) -> pipeline_spec_pb2.Value:
    """Wraps a single scalar in a `pipeline_spec_pb2.Value` message."""
    # Internal invariant: the caller below filters out None values.
    assert value is not None, 'None values should be filtered out.'

    result = pipeline_spec_pb2.Value()
    # NOTE: bool is a subclass of int, so True/False take this first branch
    # and are stored as int_value 1/0.
    if isinstance(value, int):
      result.int_value = value
    elif isinstance(value, float):
      result.double_value = value
    elif isinstance(value, str):
      result.string_value = value
    else:
      raise TypeError('Got unknown type of value: {}'.format(value))

    return result

  parameter_values = pipeline_parameters or {}
  return pipeline_spec_pb2.PipelineJob.RuntimeConfig(
      gcs_output_directory=output_directory,
      parameters={
          name: _get_value(value)
          for name, value in parameter_values.items()
          if value is not None
      })
|
|
|
|
|
|
def validate_pipeline_name(name: str) -> None:
  """Validates a pipeline name.

  A valid pipeline name must match ^[a-z0-9][a-z0-9-]{0,127}$ in its
  entirety: 1 to 128 characters drawn from lowercase letters, digits and
  hyphens, not starting with a hyphen.

  Args:
    name: The pipeline name.

  Raises:
    ValueError: If the pipeline name doesn't conform to the regular
      expression.
  """
  # fullmatch instead of match + '$': in Python regexes '$' also matches just
  # before a trailing newline, which would let e.g. 'my-pipeline\n' through.
  if not re.fullmatch(r'[a-z0-9][a-z0-9-]{0,127}', name):
    raise ValueError('Invalid pipeline name: %s.\n'
                     'Please specify a pipeline name that matches the regular '
                     'expression "^[a-z0-9][a-z0-9-]{0,127}$" using '
                     '`dsl.pipeline(name=...)` decorator.' % name)
|
|
|
|
def is_v2_component(op: _container_op.ContainerOp) -> bool:
  """Determines whether a component is a KFP v2 component.

  Args:
    op: The container op to inspect.

  Returns:
    True if `op._metadata.metadata.annotations` contains a truthy value under
    the `_component_builder.V2_COMPONENT_ANNOTATION` key; False otherwise,
    including when the metadata, the annotations, or the key is missing.
  """
  if not op._metadata or not op._metadata.metadata:
    return False
  annotations = op._metadata.metadata.annotations
  # Bail out on missing/empty annotations or an absent key; otherwise the
  # lookup below would raise (TypeError for None, KeyError for a missing key).
  if (not annotations or
      _component_builder.V2_COMPONENT_ANNOTATION not in annotations):
    return False
  return bool(annotations[_component_builder.V2_COMPONENT_ANNOTATION])
|
|
|
|
|
|
def refactor_v2_container_spec(
    container_spec: PipelineContainerSpec) -> None:
  """Refactors the container spec for a v2 component, in place.

  The existing `command` and `args` are cleared and replaced with a
  `python -m kfp.container.entrypoint` invocation whose args carry the
  `{{$}}` executor-input placeholder and the function name that originally
  followed `--function_name`.

  Args:
    container_spec: The container spec to rewrite. Mutated in place.

  Raises:
    RuntimeError: If `--function_name` is not among the args, or is the last
      arg so that no function name follows it.
  """
  # Snapshot the args: the repeated field is cleared below.
  original_args = list(container_spec.args)
  if '--function_name' not in original_args:
    raise RuntimeError('V2 component is expected to have function_name as a '
                       'command line arg.')
  fn_name_idx = original_args.index('--function_name') + 1
  # Report a descriptive error rather than an IndexError when the flag is the
  # final arg and has no value.
  if fn_name_idx >= len(original_args):
    raise RuntimeError('V2 component is expected to have a value following '
                       'the --function_name command line arg.')
  fn_name = original_args[fn_name_idx]

  container_spec.ClearField('command')
  container_spec.ClearField('args')
  # NOTE(review): this module path is `kfp.container` (singular) while this
  # file imports from the `kfp.containers` package; confirm the entrypoint
  # module name.
  container_spec.command.extend(['python', '-m', 'kfp.container.entrypoint'])
  container_spec.args.extend(
      ['--executor_input_str', '{{$}}', '--function_name', fn_name])