Feature: sidecar for ContainerOp (#879)
* Feature: sidecar for ContainerOp
* replace f-string with string format for compatibility with py3.5
* ContainerOp can now be updated with any k8s V1Container attributes as well as sidecars via the Sidecar class. ContainerOp accepts PipelineParam in any valid k8s property.
* WIP: fix conflicts and bugs with recent master. TODO: more complex template with pipeline params
* fix proxy args
* Fixed to work with latest master head
* Added container_kwargs to ContainerOp to pass in k8s container kwargs
* Fix comment bug, updated with example in ContainerOp docstring
* fix copyright year
* expose match_serialized_pipelineparam as public for compiler to process serialized pipeline params
* fixed pydoc example and removed unnecessary ContainerOp.container.parent
* Fix conflicts in compiler tests
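In short: ops can now carry sidecar containers and arbitrary Kubernetes container settings. A minimal usage sketch, assuming only the kfp.dsl API added by this commit (names and images are illustrative, mirroring the test pipelines below):

import kfp.dsl as dsl
from kubernetes.client.models import V1EnvVar

@dsl.pipeline(name='SidecarDemo', description='Sketch of the new sidecar API.')
def sidecar_demo():
  # a Sidecar runs alongside the main container in the op's pod
  echo = dsl.Sidecar(
      name='echo',
      image='hashicorp/http-echo',
      args=['-text="hello world"'])

  # sidecars and raw k8s V1Container kwargs are passed straight to ContainerOp
  op = dsl.ContainerOp(
      name='download',
      image='busybox',
      command=['sh', '-c'],
      arguments=['sleep 10; wget localhost:5678 -O /tmp/results.txt'],
      sidecars=[echo],
      container_kwargs={'env': [V1EnvVar(name='MSG', value='hello')]},
      file_outputs={'downloaded': '/tmp/results.txt'})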
This commit is contained in:
parent
4b56e7b425
commit
825f64d672
@@ -19,6 +19,9 @@ import time
import logging
import re

from .. import dsl


class K8sHelper(object):
  """ Kubernetes Helper """
@@ -159,7 +162,11 @@ class K8sHelper(object):
                  for sub_obj in obj)
    elif isinstance(k8s_obj, (datetime, date)):
      return k8s_obj.isoformat()

    elif isinstance(k8s_obj, dsl.PipelineParam):
      if isinstance(k8s_obj.value, str):
        return k8s_obj.value
      return '{{inputs.parameters.%s}}' % k8s_obj.full_name

    if isinstance(k8s_obj, dict):
      obj_dict = k8s_obj
    else:
@@ -0,0 +1,218 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Union, List, Any, Callable, TypeVar, Dict

from ._k8s_helper import K8sHelper
from .. import dsl

# generics
T = TypeVar('T')


def _process_obj(obj: Any, map_to_tmpl_var: dict):
  """Recursively sanitize and replace any PipelineParam (instances and serialized strings)
  in the object with the corresponding template variables
  (i.e. '{{inputs.parameters.<PipelineParam.full_name>}}').

  Args:
    obj: any obj that may have PipelineParam
    map_to_tmpl_var: a dict that maps an unsanitized pipeline
                     params signature into a template var
  """
  # serialized str might be unsanitized
  if isinstance(obj, str):
    # get signature
    param_tuples = dsl.match_serialized_pipelineparam(obj)
    if not param_tuples:
      return obj
    # replace all unsanitized signatures with template vars
    for param_tuple in param_tuples:
      obj = re.sub(param_tuple.pattern, map_to_tmpl_var[param_tuple.pattern], obj)

  # list
  if isinstance(obj, list):
    return [_process_obj(item, map_to_tmpl_var) for item in obj]

  # tuple
  if isinstance(obj, tuple):
    return tuple((_process_obj(item, map_to_tmpl_var) for item in obj))

  # dict
  if isinstance(obj, dict):
    return {
        key: _process_obj(value, map_to_tmpl_var)
        for key, value in obj.items()
    }

  # pipelineparam
  if isinstance(obj, dsl.PipelineParam):
    # if not found in unsanitized map, then likely to be sanitized
    return map_to_tmpl_var.get(
        str(obj), '{{inputs.parameters.%s}}' % obj.full_name)

  # k8s_obj
  if hasattr(obj, 'swagger_types') and isinstance(obj.swagger_types, dict):
    # process everything inside recursively
    for key in obj.swagger_types.keys():
      setattr(obj, key, _process_obj(getattr(obj, key), map_to_tmpl_var))
    # return json representation of the k8s obj
    return K8sHelper.convert_k8s_obj_to_json(obj)

  # do nothing
  return obj


def _process_container_ops(op: dsl.ContainerOp):
  """Recursively go through the attrs listed in `attrs_with_pipelineparams`
  and sanitize and replace pipeline params with template var strings.

  Returns a processed `ContainerOp`.

  NOTE this is an in-place update to `ContainerOp`'s attributes (i.e. other than
  `file_outputs` and `outputs`, all `PipelineParam` are replaced with the
  corresponding template variable strings).

  Args:
    op {dsl.ContainerOp}: class that inherits from dsl.ContainerOp

  Returns:
    dsl.ContainerOp
  """

  # map param's (unsanitized pattern or serialized str pattern) -> input param var str
  map_to_tmpl_var = {
      (param.pattern or str(param)): '{{inputs.parameters.%s}}' % param.full_name
      for param in op.inputs
  }

  # process all attrs with pipelineParams except input and output parameters
  for key in op.attrs_with_pipelineparams:
    setattr(op, key, _process_obj(getattr(op, key), map_to_tmpl_var))

  return op


def _parameters_to_json(params: List[dsl.PipelineParam]):
  """Converts a list of PipelineParam into an argo `parameter` JSON obj."""
  _to_json = (lambda param: dict(name=param.full_name, value=param.value)
              if param.value else dict(name=param.full_name))
  params = [_to_json(param) for param in params]
  # Sort to make the results deterministic.
  params.sort(key=lambda x: x['name'])
  return params


# TODO: artifacts?
def _inputs_to_json(inputs_params: List[dsl.PipelineParam], _artifacts=None):
  """Converts a list of PipelineParam into an argo `inputs` JSON obj."""
  parameters = _parameters_to_json(inputs_params)
  return {'parameters': parameters} if parameters else None


def _outputs_to_json(outputs: Dict[str, dsl.PipelineParam],
                     file_outputs: Dict[str, str],
                     output_artifacts: List[dict]):
  """Creates an argo `outputs` JSON obj."""
  output_parameters = []
  for param in outputs.values():
    output_parameters.append({
        'name': param.full_name,
        'valueFrom': {
            'path': file_outputs[param.name]
        }
    })
  output_parameters.sort(key=lambda x: x['name'])
  ret = {}
  if output_parameters:
    ret['parameters'] = output_parameters
  if output_artifacts:
    ret['artifacts'] = output_artifacts

  return ret


def _build_conventional_artifact(name):
  return {
      'name': name,
      'path': '/' + name + '.json',
      's3': {
          # TODO: parameterize namespace for minio service
          'endpoint': 'minio-service.kubeflow:9000',
          'bucket': 'mlpipeline',
          'key': 'runs/{{workflow.uid}}/{{pod.name}}/' + name + '.tgz',
          'insecure': True,
          'accessKeySecret': {
              'name': 'mlpipeline-minio-artifact',
              'key': 'accesskey',
          },
          'secretKeySecret': {
              'name': 'mlpipeline-minio-artifact',
              'key': 'secretkey'
          }
      },
  }


# TODO: generate argo python classes from swagger and use convert_k8s_obj_to_json??
def _op_to_template(op: dsl.ContainerOp):
  """Generate template given an operator inherited from dsl.ContainerOp."""

  # NOTE in-place update to ContainerOp
  # replace all PipelineParams (except in `file_outputs`, `outputs`, `inputs`)
  # with template var strings
  processed_op = _process_container_ops(op)

  # default output artifacts
  output_artifacts = [
      _build_conventional_artifact(name)
      for name in ['mlpipeline-ui-metadata', 'mlpipeline-metrics']
  ]

  # workflow template
  template = {
      'name': op.name,
      'container': K8sHelper.convert_k8s_obj_to_json(op.container)
  }

  # inputs
  inputs = _inputs_to_json(processed_op.inputs)
  if inputs:
    template['inputs'] = inputs

  # outputs
  template['outputs'] = _outputs_to_json(op.outputs, op.file_outputs,
                                         output_artifacts)

  # node selector
  if processed_op.node_selector:
    template['nodeSelector'] = processed_op.node_selector

  # metadata
  if processed_op.pod_annotations or processed_op.pod_labels:
    template['metadata'] = {}
    if processed_op.pod_annotations:
      template['metadata']['annotations'] = processed_op.pod_annotations
    if processed_op.pod_labels:
      template['metadata']['labels'] = processed_op.pod_labels
  # retries
  if processed_op.num_retries:
    template['retryStrategy'] = {'limit': processed_op.num_retries}

  # sidecars
  if processed_op.sidecars:
    template['sidecars'] = processed_op.sidecars

  return template
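To make the substitution above concrete, a rough sketch of what _process_obj does to a string holding a serialized parameter, with the map built the same way _process_container_ops builds it (a hypothetical 'tag' parameter; the serialized form follows _pipeline_param.py in this same commit):

from kfp import dsl

tag = dsl.PipelineParam(name='tag')
# inside a string, the param serializes to '{{pipelineparam:op=;name=tag;value=;type=;}}'
image = 'busybox:%s' % tag

# map each input param's serialized pattern to its argo template variable
map_to_tmpl_var = {
    (tag.pattern or str(tag)): '{{inputs.parameters.%s}}' % tag.full_name
}
# expected result: 'busybox:{{inputs.parameters.tag}}'
processed_image = _process_obj(image, map_to_tmpl_var)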
@@ -22,7 +22,8 @@ import yaml

from .. import dsl
from ._k8s_helper import K8sHelper
from ..dsl._pipeline_param import _match_serialized_pipelineparam
from ._op_to_template import _op_to_template

from ..dsl._metadata import TypeMeta
from ..dsl._ops_group import OpsGroup
@@ -129,7 +130,6 @@ class Compiler(object):
        # it as input for its parent groups.
        if param.value:
          continue

        full_name = self._pipelineparam_full_name(param)
        if param.op_name:
          upstream_op = pipeline.ops[param.op_name]
@@ -285,125 +285,8 @@ class Compiler(object):
    else:
      return str(value_or_reference)

  def _process_args(self, raw_args, argument_inputs):
    if not raw_args:
      return []
    processed_args = list(map(str, raw_args))
    for i, _ in enumerate(processed_args):
      # unsanitized_argument_inputs stores a dict: string of sanitized param -> string of unsanitized param
      param_tuples = []
      param_tuples += _match_serialized_pipelineparam(str(processed_args[i]))
      unsanitized_argument_inputs = {}
      for param_tuple in list(set(param_tuples)):
        sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(param_tuple.name), K8sHelper.sanitize_k8s_name(param_tuple.op), param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
        unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
      if argument_inputs:
        for param in argument_inputs:
          if str(param) in unsanitized_argument_inputs:
            full_name = self._pipelineparam_full_name(param)
            processed_args[i] = re.sub(unsanitized_argument_inputs[str(param)], '{{inputs.parameters.%s}}' % full_name,
                                       processed_args[i])
    return processed_args

  def _op_to_template(self, op):
    """Generate template given an operator inherited from dsl.ContainerOp."""

    def _build_conventional_artifact(name, path):
      return {
          'name': name,
          'path': path,
          's3': {
              # TODO: parameterize namespace for minio service
              'endpoint': 'minio-service.kubeflow:9000',
              'bucket': 'mlpipeline',
              'key': 'runs/{{workflow.uid}}/{{pod.name}}/' + name + '.tgz',
              'insecure': True,
              'accessKeySecret': {
                  'name': 'mlpipeline-minio-artifact',
                  'key': 'accesskey',
              },
              'secretKeySecret': {
                  'name': 'mlpipeline-minio-artifact',
                  'key': 'secretkey'
              }
          },
      }
    processed_arguments = self._process_args(op.arguments, op.argument_inputs)
    processed_command = self._process_args(op.command, op.argument_inputs)

    input_parameters = []
    for param in op.inputs:
      one_parameter = {'name': self._pipelineparam_full_name(param)}
      if param.value:
        one_parameter['value'] = str(param.value)
      input_parameters.append(one_parameter)
    # Sort to make the results deterministic.
    input_parameters.sort(key=lambda x: x['name'])

    output_parameters = []
    for param in op.outputs.values():
      output_parameters.append({
          'name': self._pipelineparam_full_name(param),
          'valueFrom': {'path': op.file_outputs[param.name]}
      })
    output_parameters.sort(key=lambda x: x['name'])

    template = {
        'name': op.name,
        'container': {
            'image': op.image,
        }
    }
    if processed_arguments:
      template['container']['args'] = processed_arguments
    if processed_command:
      template['container']['command'] = processed_command
    if input_parameters:
      template['inputs'] = {'parameters': input_parameters}

    template['outputs'] = {}
    if output_parameters:
      template['outputs'] = {'parameters': output_parameters}

    # Generate artifact for metadata output
    # The motivation of appending the minio info in the yaml
    # is to specify a unique path for the metadata.
    # TODO: after argo addresses the issue that configures a unique path
    # for the artifact output when default artifact repository is configured,
    # this part needs to be updated to use the default artifact repository.
    output_artifacts = []
    output_artifacts.append(_build_conventional_artifact('mlpipeline-ui-metadata', '/mlpipeline-ui-metadata.json'))
    output_artifacts.append(_build_conventional_artifact('mlpipeline-metrics', '/mlpipeline-metrics.json'))
    template['outputs']['artifacts'] = output_artifacts

    # Set resources.
    if op.resource_limits or op.resource_requests:
      template['container']['resources'] = {}
    if op.resource_limits:
      template['container']['resources']['limits'] = op.resource_limits
    if op.resource_requests:
      template['container']['resources']['requests'] = op.resource_requests

    # Set nodeSelector.
    if op.node_selector:
      template['nodeSelector'] = op.node_selector

    if op.env_variables:
      template['container']['env'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.env_variables))
    if op.volume_mounts:
      template['container']['volumeMounts'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.volume_mounts))

    if op.pod_annotations or op.pod_labels:
      template['metadata'] = {}
      if op.pod_annotations:
        template['metadata']['annotations'] = op.pod_annotations
      if op.pod_labels:
        template['metadata']['labels'] = op.pod_labels

    if op.num_retries:
      template['retryStrategy'] = {'limit': op.num_retries}

    return template
    return _op_to_template(op)

  def _group_to_template(self, group, inputs, outputs, dependencies):
    """Generate template given an OpsGroup.
@@ -419,15 +302,14 @@ class Compiler(object):
      template['inputs'] = {
          'parameters': template_inputs
      }

    # Generate outputs section.
    if outputs.get(group.name, None):
      template_outputs = []
      for param_name, depentent_name in outputs[group.name]:
      for param_name, dependent_name in outputs[group.name]:
        template_outputs.append({
            'name': param_name,
            'valueFrom': {
                'parameter': '{{tasks.%s.outputs.parameters.%s}}' % (depentent_name, param_name)
                'parameter': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
            }
        })
      template_outputs.sort(key=lambda x: x['name'])
@@ -520,7 +402,7 @@ class Compiler(object):
      templates.append(self._group_to_template(g, inputs, outputs, dependencies))

    for op in pipeline.ops.values():
      templates.append(self._op_to_template(op))
      templates.append(_op_to_template(op))
    return templates

  def _create_volumes(self, pipeline):
@@ -532,9 +414,9 @@ class Compiler(object):
        for v in op.volumes:
          # Remove volume duplicates which have the same name
          #TODO: check for duplicity based on the serialized volumes instead of just name.
          if v.name not in volume_name_set:
            volume_name_set.add(v.name)
            volumes.append(K8sHelper.convert_k8s_obj_to_json(v))
          if v['name'] not in volume_name_set:
            volume_name_set.add(v['name'])
            volumes.append(v)
    volumes.sort(key=lambda x: x['name'])
    return volumes
@@ -649,10 +531,6 @@ class Compiler(object):
    for op in p.ops.values():
      sanitized_name = K8sHelper.sanitize_k8s_name(op.name)
      op.name = sanitized_name
      for param in op.inputs + op.argument_inputs:
        param.name = K8sHelper.sanitize_k8s_name(param.name)
        if param.op_name:
          param.op_name = K8sHelper.sanitize_k8s_name(param.op_name)
      for param in op.outputs.values():
        param.name = K8sHelper.sanitize_k8s_name(param.name)
        if param.op_name:
@@ -669,7 +547,6 @@ class Compiler(object):
      op.file_outputs = sanitized_file_outputs
      sanitized_ops[sanitized_name] = op
    p.ops = sanitized_ops

    workflow = self._create_pipeline_workflow(args_list_with_defaults, p)
    return workflow
@@ -169,7 +169,7 @@ def _create_container_op_from_resolved_task(name:str, container_image:str, comma
  if env:
    from kubernetes import client as k8s_client
    for name, value in env.items():
      task.add_env_variable(k8s_client.V1EnvVar(name=name, value=value))
      task.container.add_env_variable(k8s_client.V1EnvVar(name=name, value=value))

  if need_dummy:
    _dummy_pipeline.__exit__()
@@ -13,8 +13,8 @@
# limitations under the License.


from ._pipeline_param import PipelineParam
from ._pipeline_param import PipelineParam, match_serialized_pipelineparam
from ._pipeline import Pipeline, pipeline, get_pipeline_conf
from ._container_op import ContainerOp
from ._container_op import ContainerOp, Sidecar
from ._ops_group import OpsGroup, ExitHandler, Condition
from ._component import python_component, graph_component, component

(File diff suppressed because it is too large)
@@ -15,16 +15,25 @@

import re
from collections import namedtuple
from typing import List
from ._metadata import TypeMeta


# TODO: Move this to a separate class
# For now, this identifies a condition with only "==" operator supported.
ConditionOperator = namedtuple('ConditionOperator', 'operator operand1 operand2')
PipelineParamTuple = namedtuple('PipelineParamTuple', 'name op value type')
PipelineParamTuple = namedtuple('PipelineParamTuple', 'name op value type pattern')

def _match_serialized_pipelineparam(payload: str):
  """_match_serialized_pipelineparam matches the serialized pipelineparam.

def sanitize_k8s_name(name):
  """From _make_kubernetes_name
  sanitize_k8s_name cleans and converts the names in the workflow.
  """
  return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-')


def match_serialized_pipelineparam(payload: str):
  """match_serialized_pipelineparam matches the serialized pipelineparam.
  Args:
    payload (str): a string that contains the serialized pipelineparam.

@@ -37,12 +46,24 @@ def _match_serialized_pipelineparam(payload: str):
  param_tuples = []
  for match in matches:
    if len(match) == 3:
      param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type=''))
      pattern = '{{pipelineparam:op=%s;name=%s;value=%s}}' % (match[0], match[1], match[2])
      param_tuples.append(PipelineParamTuple(
          name=sanitize_k8s_name(match[1]),
          op=sanitize_k8s_name(match[0]),
          value=match[2],
          type='',
          pattern=pattern))
    elif len(match) == 4:
      param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type=match[3]))
      pattern = '{{pipelineparam:op=%s;name=%s;value=%s;type=%s;}}' % (match[0], match[1], match[2], match[3])
      param_tuples.append(PipelineParamTuple(
          name=sanitize_k8s_name(match[1]),
          op=sanitize_k8s_name(match[0]),
          value=match[2],
          type=match[3],
          pattern=pattern))
  return param_tuples

def _extract_pipelineparams(payloads: str or list[str]):
def _extract_pipelineparams(payloads: str or List[str]):
  """_extract_pipelineparams extracts a list of PipelineParam instances from the payload string.
  Note: this function removes all duplicate matches.

@@ -55,12 +76,64 @@ def _extract_pipelineparams(payloads: str or list[str]):
    payloads = [payloads]
  param_tuples = []
  for payload in payloads:
    param_tuples += _match_serialized_pipelineparam(payload)
    param_tuples += match_serialized_pipelineparam(payload)
  pipeline_params = []
  for param_tuple in list(set(param_tuples)):
    pipeline_params.append(PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
    pipeline_params.append(PipelineParam(param_tuple.name,
                                         param_tuple.op,
                                         param_tuple.value,
                                         TypeMeta.deserialize(param_tuple.type),
                                         pattern=param_tuple.pattern))
  return pipeline_params


def extract_pipelineparams_from_any(payload) -> List['PipelineParam']:
  """Recursively extract PipelineParam instances or serialized string from any object or list of objects.

  Args:
    payload (str or k8_obj or list[str or k8_obj]): a string/a list
      of strings that contains serialized pipelineparams or a k8 definition
      object.
  Return:
    List[PipelineParam]
  """
  if not payload:
    return []

  # PipelineParam
  if isinstance(payload, PipelineParam):
    return [payload]

  # str
  if isinstance(payload, str):
    return list(set(_extract_pipelineparams(payload)))

  # list or tuple
  if isinstance(payload, list) or isinstance(payload, tuple):
    pipeline_params = []
    for item in payload:
      pipeline_params += extract_pipelineparams_from_any(item)
    return list(set(pipeline_params))

  # dict
  if isinstance(payload, dict):
    pipeline_params = []
    for item in payload.values():
      pipeline_params += extract_pipelineparams_from_any(item)
    return list(set(pipeline_params))

  # k8s object
  if hasattr(payload, 'swagger_types') and isinstance(payload.swagger_types, dict):
    pipeline_params = []
    for key in payload.swagger_types.keys():
      pipeline_params += extract_pipelineparams_from_any(getattr(payload, key))

    return list(set(pipeline_params))

  # return empty list
  return []


class PipelineParam(object):
  """Representing a future value that is passed between pipeline components.

@@ -69,7 +142,7 @@ class PipelineParam(object):
  value passed between components.
  """

  def __init__(self, name: str, op_name: str=None, value: str=None, param_type: TypeMeta=TypeMeta()):
  def __init__(self, name: str, op_name: str=None, value: str=None, param_type: TypeMeta=TypeMeta(), pattern: str=None):
    """Create a new instance of PipelineParam.
    Args:
      name: name of the pipeline parameter.
@@ -80,6 +153,7 @@ class PipelineParam(object):
      value: The actual value of the PipelineParam. If provided, the PipelineParam is
        "resolved" immediately. For now, we support string only.
      param_type: the type of the PipelineParam.
      pattern: the serialized string regex pattern this pipeline parameter was created from.
    Raises: ValueError if name or op_name contains invalid characters, or if both op_name
      and value are set.
    """
@@ -91,10 +165,21 @@ class PipelineParam(object):
    if op_name and value:
      raise ValueError('op_name and value cannot be both set.')

    self.op_name = op_name
    self.name = name
    self.value = value
    # ensure value is None even if empty string or empty list
    # so that serialization and deserialization remain consistent
    # (i.e. None => '' => None)
    self.op_name = op_name if op_name else None
    self.value = value if value else None
    self.param_type = param_type
    self.pattern = pattern

  @property
  def full_name(self):
    """Unique name in the argo yaml for the PipelineParam"""
    if self.op_name:
      return self.op_name + '-' + self.name
    return self.name

  def __str__(self):
    """String representation.
@@ -230,7 +230,6 @@ class TestCompiler(unittest.TestCase):
      with open(os.path.join(test_data_dir, file_base_name + '.yaml'), 'r') as f:
        golden = yaml.load(f)
      compiled = self._get_yaml_from_tar(target_tar)

      self.maxDiff = None
      self.assertEqual(golden, compiled)
    finally:
@@ -259,6 +258,14 @@ class TestCompiler(unittest.TestCase):
    """Test basic sequential pipeline."""
    self._test_py_compile_zip('basic')

  def test_py_compile_with_sidecar(self):
    """Test pipeline with sidecar."""
    self._test_py_compile_yaml('sidecar')

  def test_py_compile_with_pipelineparams(self):
    """Test pipeline with multiple pipeline params."""
    self._test_py_compile_yaml('pipelineparams')

  def test_py_compile_condition(self):
    """Test a pipeline with conditions."""
    self._test_py_compile_zip('coin')
@@ -0,0 +1,42 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import kfp.dsl as dsl
from kubernetes import client as k8s_client
from kubernetes.client.models import V1EnvVar


@dsl.pipeline(name='PipelineParams', description='A pipeline with multiple pipeline params.')
def pipelineparams_pipeline(tag: str = 'latest', sleep_ms: int = 10):

  echo = dsl.Sidecar(
      name='echo',
      image='hashicorp/http-echo:%s' % tag,
      args=['-text="hello world"'])

  op1 = dsl.ContainerOp(
      name='download',
      image='busybox:%s' % tag,
      command=['sh', '-c'],
      arguments=['sleep %s; wget localhost:5678 -O /tmp/results.txt' % sleep_ms],
      sidecars=[echo],
      file_outputs={'downloaded': '/tmp/results.txt'})

  op2 = dsl.ContainerOp(
      name='echo',
      image='library/bash',
      command=['sh', '-c'],
      arguments=['echo $MSG %s' % op1.output])

  op2.container.add_env_variable(V1EnvVar(name='MSG', value='pipelineParams: '))
@@ -0,0 +1,142 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
---
apiVersion: argoproj.io/v1alpha1
metadata:
  generateName: pipelineparams-
spec:
  entrypoint: pipelineparams
  arguments:
    parameters:
    - name: tag
      value: latest
    - name: sleep-ms
      value: '10'
  templates:
  - name: download
    inputs:
      parameters:
      - name: sleep-ms
      - name: tag
    container:
      image: busybox:{{inputs.parameters.tag}}
      args:
      - sleep {{inputs.parameters.sleep-ms}}; wget localhost:5678 -O /tmp/results.txt
      command:
      - sh
      - "-c"
    outputs:
      artifacts:
      - name: mlpipeline-ui-metadata
        path: "/mlpipeline-ui-metadata.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          insecure: true
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
      - name: mlpipeline-metrics
        path: "/mlpipeline-metrics.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          insecure: true
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
      parameters:
      - name: download-downloaded
        valueFrom:
          path: "/tmp/results.txt"
    sidecars:
    - image: hashicorp/http-echo:{{inputs.parameters.tag}}
      name: echo
      args:
      - -text="hello world"
  - name: echo
    inputs:
      parameters:
      - name: download-downloaded
    container:
      image: library/bash
      args:
      - echo $MSG {{inputs.parameters.download-downloaded}}
      command:
      - sh
      - "-c"
      env:
      - name: MSG
        value: 'pipelineParams: '
    outputs:
      artifacts:
      - name: mlpipeline-ui-metadata
        path: "/mlpipeline-ui-metadata.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          insecure: true
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
      - name: mlpipeline-metrics
        path: "/mlpipeline-metrics.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          insecure: true
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
  - name: pipelineparams
    inputs:
      parameters:
      - name: sleep-ms
      - name: tag
    dag:
      tasks:
      - name: download
        arguments:
          parameters:
          - name: sleep-ms
            value: "{{inputs.parameters.sleep-ms}}"
          - name: tag
            value: "{{inputs.parameters.tag}}"
        template: download
      - dependencies:
        - download
        arguments:
          parameters:
          - name: download-downloaded
            value: "{{tasks.download.outputs.parameters.download-downloaded}}"
        name: echo
        template: echo
  serviceAccountName: pipeline-runner
kind: Workflow
@@ -0,0 +1,39 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import kfp.dsl as dsl
from kubernetes import client as k8s_client


@dsl.pipeline(name='Sidecar', description='A pipeline with sidecars.')
def sidecar_pipeline():

  echo = dsl.Sidecar(
      name='echo',
      image='hashicorp/http-echo',
      args=['-text="hello world"'])

  op1 = dsl.ContainerOp(
      name='download',
      image='busybox',
      command=['sh', '-c'],
      arguments=['sleep 10; wget localhost:5678 -O /tmp/results.txt'],
      sidecars=[echo],
      file_outputs={'downloaded': '/tmp/results.txt'})

  op2 = dsl.ContainerOp(
      name='echo',
      image='library/bash',
      command=['sh', '-c'],
      arguments=['echo %s' % op1.output])
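For reference, the compiler test above exercises this sample through _test_py_compile_yaml('sidecar'); compiling it directly should look roughly like the sketch below (the output filename is arbitrary):

import kfp.compiler as compiler

# produces the workflow checked in as the sidecar.yaml golden file below
compiler.Compiler().compile(sidecar_pipeline, 'sidecar.tar.gz')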
@@ -0,0 +1,120 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
kind: Workflow
metadata:
  generateName: sidecar-
apiVersion: argoproj.io/v1alpha1
spec:
  arguments:
    parameters: []
  templates:
  - outputs:
      artifacts:
      - name: mlpipeline-ui-metadata
        path: "/mlpipeline-ui-metadata.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          insecure: true
      - name: mlpipeline-metrics
        path: "/mlpipeline-metrics.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          insecure: true
      parameters:
      - name: download-downloaded
        valueFrom:
          path: "/tmp/results.txt"
    name: download
    sidecars:
    - image: hashicorp/http-echo
      name: echo
      args:
      - -text="hello world"
    container:
      image: busybox
      args:
      - sleep 10; wget localhost:5678 -O /tmp/results.txt
      command:
      - sh
      - "-c"
  - outputs:
      artifacts:
      - name: mlpipeline-ui-metadata
        path: "/mlpipeline-ui-metadata.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          insecure: true
      - name: mlpipeline-metrics
        path: "/mlpipeline-metrics.json"
        s3:
          endpoint: minio-service.kubeflow:9000
          secretKeySecret:
            name: mlpipeline-minio-artifact
            key: secretkey
          key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
          bucket: mlpipeline
          accessKeySecret:
            name: mlpipeline-minio-artifact
            key: accesskey
          insecure: true
    name: echo
    inputs:
      parameters:
      - name: download-downloaded
    container:
      image: library/bash
      args:
      - echo {{inputs.parameters.download-downloaded}}
      command:
      - sh
      - "-c"
  - name: sidecar
    dag:
      tasks:
      - name: download
        template: download
      - arguments:
          parameters:
          - name: download-downloaded
            value: "{{tasks.download.outputs.parameters.download-downloaded}}"
        name: echo
        dependencies:
        - download
        template: echo
  serviceAccountName: pipeline-runner
  entrypoint: sidecar
@@ -34,9 +34,9 @@ class LoadComponentTestCase(unittest.TestCase):

    self.assertEqual(task1.human_name, 'Add')
    self.assertEqual(task_factory1.__doc__.strip(), 'Add\nReturns sum of two arguments')
    self.assertEqual(task1.image, 'python:3.5')
    self.assertEqual(task1.arguments[0], str(arg1))
    self.assertEqual(task1.arguments[1], str(arg2))
    self.assertEqual(task1.container.image, 'python:3.5')
    self.assertEqual(task1.container.args[0], str(arg1))
    self.assertEqual(task1.container.args[1], str(arg2))

  def test_load_component_from_yaml_file(self):
    _this_file = Path(__file__).resolve()
@@ -68,7 +68,7 @@ class LoadComponentTestCase(unittest.TestCase):
    arg2 = 5
    task1 = task_factory1(arg1, arg2)
    assert task1.human_name == component_dict['name']
    assert task1.image == component_dict['implementation']['container']['image']
    assert task1.container.image == component_dict['implementation']['container']['image']

    assert task1.arguments[0] == str(arg1)
    assert task1.arguments[1] == str(arg2)
@@ -83,7 +83,7 @@ implementation:
    task_factory1 = comp.load_component(text=component_text)

    task1 = task_factory1()
    assert task1.image == component_dict['implementation']['container']['image']
    assert task1.container.image == component_dict['implementation']['container']['image']

  @unittest.expectedFailure
  def test_fail_on_duplicate_input_names(self):
@@ -525,7 +525,7 @@ implementation:
    import kfp
    with kfp.dsl.Pipeline('Dummy'): #Forcing the TaskSpec conversion to ContainerOp
      task1 = task_factory1()
    actual_env = {env_var.name: env_var.value for env_var in task1.env_variables}
    actual_env = {env_var.name: env_var.value for env_var in task1.container.env}
    expected_env = {'key1': 'value 1', 'key2': 'value 2'}
    self.assertDictEqual(expected_env, actual_env)
@@ -49,7 +49,6 @@ class PythonOpTestCase(unittest.TestCase):
    task = op(arg1, arg2)

    full_command = task.command + task.arguments

    process = subprocess.run(full_command)

    output_path = list(task.file_outputs.values())[0]
@@ -13,8 +13,12 @@
# limitations under the License.


from kfp.dsl import Pipeline, PipelineParam, ContainerOp
import warnings
import unittest
from kubernetes.client.models import V1EnvVar, V1VolumeMount

from kfp.dsl import Pipeline, PipelineParam, ContainerOp, Sidecar


class TestContainerOp(unittest.TestCase):
@@ -23,14 +27,21 @@ class TestContainerOp(unittest.TestCase):
    with Pipeline('somename') as p:
      param1 = PipelineParam('param1')
      param2 = PipelineParam('param2')
      op1 = ContainerOp(name='op1', image='image',
      op1 = (ContainerOp(name='op1', image='image',
                        arguments=['%s hello %s %s' % (param1, param2, param1)],
                        sidecars=[Sidecar(name='sidecar0', image='image0')],
                        container_kwargs={'env': [V1EnvVar(name='env1', value='value1')]},
                        file_outputs={'out1': '/tmp/b'})
             .add_sidecar(Sidecar(name='sidecar1', image='image1'))
             .add_sidecar(Sidecar(name='sidecar2', image='image2')))

    self.assertCountEqual([x.name for x in op1.inputs], ['param1', 'param2'])
    self.assertCountEqual(list(op1.outputs.keys()), ['out1'])
    self.assertCountEqual([x.op_name for x in op1.outputs.values()], ['op1'])
    self.assertEqual(op1.output.name, 'out1')
    self.assertCountEqual([sidecar.name for sidecar in op1.sidecars], ['sidecar0', 'sidecar1', 'sidecar2'])
    self.assertCountEqual([sidecar.image for sidecar in op1.sidecars], ['image0', 'image1', 'image2'])
    self.assertCountEqual([env.name for env in op1.container.env], ['env1'])

  def test_after_op(self):
    """Test duplicate ops."""
@@ -39,3 +50,38 @@ class TestContainerOp(unittest.TestCase):
      op2 = ContainerOp(name='op2', image='image')
      op2.after(op1)
    self.assertCountEqual(op2.dependent_op_names, [op1.name])


  def test_deprecation_warnings(self):
    """Test deprecation warnings."""
    with Pipeline('somename') as p:
      op = ContainerOp(name='op1', image='image')

      with self.assertWarns(PendingDeprecationWarning):
        op.env_variables = [V1EnvVar(name="foo", value="bar")]

      with self.assertWarns(PendingDeprecationWarning):
        op.image = 'image2'

      with self.assertWarns(PendingDeprecationWarning):
        op.set_memory_request('10M')

      with self.assertWarns(PendingDeprecationWarning):
        op.set_memory_limit('10M')

      with self.assertWarns(PendingDeprecationWarning):
        op.set_cpu_request('100m')

      with self.assertWarns(PendingDeprecationWarning):
        op.set_cpu_limit('1')

      with self.assertWarns(PendingDeprecationWarning):
        op.set_gpu_limit('1')

      with self.assertWarns(PendingDeprecationWarning):
        op.add_env_variable(V1EnvVar(name="foo", value="bar"))

      with self.assertWarns(PendingDeprecationWarning):
        op.add_volume_mount(V1VolumeMount(
            mount_path='/secret/gcp-credentials',
            name='gcp-credentials'))
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from kubernetes.client.models import V1Container, V1EnvVar
from kfp.dsl import PipelineParam
from kfp.dsl._pipeline_param import _extract_pipelineparams
from kfp.dsl._pipeline_param import _extract_pipelineparams, extract_pipelineparams_from_any
from kfp.dsl._metadata import TypeMeta
import unittest
@@ -38,8 +38,8 @@ class TestPipelineParam(unittest.TestCase):
    p = PipelineParam(name='param3', value='value3')
    self.assertEqual('{{pipelineparam:op=;name=param3;value=value3;type=;}}', str(p))

  def test_extract_pipelineparam(self):
    """Test _extract_pipelineparam."""
  def test_extract_pipelineparams(self):
    """Test _extract_pipelineparams."""

    p1 = PipelineParam(name='param1', op_name='op1')
    p2 = PipelineParam(name='param2')
@@ -52,6 +52,21 @@ class TestPipelineParam(unittest.TestCase):
    params = _extract_pipelineparams(payload)
    self.assertListEqual([p1, p2, p3], params)

  def test_extract_pipelineparams_from_any(self):
    """Test extract_pipelineparams_from_any."""
    p1 = PipelineParam(name='param1', op_name='op1')
    p2 = PipelineParam(name='param2')
    p3 = PipelineParam(name='param3', value='value3')
    stuff_chars = ' between '
    payload = str(p1) + stuff_chars + str(p2) + stuff_chars + str(p3)

    container = V1Container(name=p1,
                            image=p2,
                            env=[V1EnvVar(name="foo", value=payload)])

    params = extract_pipelineparams_from_any(container)
    self.assertListEqual(sorted([p1, p2, p3]), sorted(params))

  def test_extract_pipelineparam_with_types(self):
    """Test _extract_pipelineparams."""
    p1 = PipelineParam(name='param1', op_name='op1', param_type=TypeMeta(name='customized_type_a', properties={'property_a': 'value_a'}))
@@ -64,4 +79,4 @@ class TestPipelineParam(unittest.TestCase):
    # Expecting the _extract_pipelineparam to dedup the pipelineparams among all the payloads.
    payload = [str(p1) + stuff_chars + str(p2), str(p2) + stuff_chars + str(p3)]
    params = _extract_pipelineparams(payload)
    self.assertListEqual([p1, p2, p3], params)
    self.assertListEqual([p1, p2, p3], params)