Expose an API for appending params/names/descriptions in a programmable way. (#2082)
* Refactor: expose a public API to append pipeline params without interacting with the dsl.Pipeline object.
* Add unit test and fix.
* Fix docstring.
* Fix tests.
* Fix two nit problems.
* Refactor.
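A minimal usage sketch of the new Compiler.create_workflow API introduced below (the pipeline function, image, and param name here are illustrative and not part of this commit): the function declares no signature params, so the param is appended via params_list; supplying both would raise the ValueError added by this change.

import kfp.dsl as dsl
from kfp import compiler

# Illustrative param defined outside the pipeline function, appended via params_list.
message_param = dsl.PipelineParam(name='message')

def echo_pipeline():
  # Ops inside the function can consume the appended param.
  dsl.ContainerOp(
      name='echo',
      image='alpine:3.9',
      command=['sh', '-c'],
      arguments=['echo %s' % message_param])

# Build the Argo workflow dict programmatically, without the @dsl.pipeline decorator.
workflow = compiler.Compiler().create_workflow(
    pipeline_func=echo_pipeline,
    pipeline_name='Echo',
    pipeline_description='Echo a message',
    params_list=[message_param])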
This commit is contained in:
parent a4fa1edb42
commit 497d016e85
@@ -17,7 +17,7 @@ from collections import defaultdict
 import inspect
 import tarfile
 import zipfile
-from typing import Set, List, Text, Dict
+from typing import Any, Callable, Set, List, Text, Dict
 
 import yaml
 from kfp.dsl import _container_op, _for_loop
@@ -27,7 +27,7 @@ from ._k8s_helper import K8sHelper
 from ._op_to_template import _op_to_template
 from ._default_transformers import add_pod_env
 
-from ..dsl._metadata import _extract_pipeline_metadata
+from ..dsl._metadata import ParameterMeta, _extract_pipeline_metadata
 from ..dsl._ops_group import OpsGroup
 
 
@@ -714,51 +714,84 @@ class Compiler(object):
       sanitized_ops[sanitized_name] = op
     pipeline.ops = sanitized_ops
 
-  def _compile(self, pipeline_func):
-    """Compile the given pipeline function into workflow."""
-    # Step 1: extract param value, name and description from pipeline_func signature and decoration.
+  def create_workflow(self,
+                      pipeline_func: Callable,
+                      pipeline_name: Text=None,
+                      pipeline_description: Text=None,
+                      params_list: List[dsl.PipelineParam]=()) -> Dict[Text, Any]:
+    """Create workflow spec from pipeline function and specified pipeline
+    params/metadata. Currently, the pipeline params are either specified in
+    the signature of the pipeline function or by passing a list of
+    dsl.PipelineParam. Conflict will cause ValueError.
+
+    :param pipeline_func: pipeline function where ContainerOps are invoked.
+    :param pipeline_name: optional; overrides the pipeline name from the decorator.
+    :param pipeline_description: optional; overrides the description from the decorator.
+    :param params_list: list of pipeline params to append to the pipeline.
+    :return: workflow dict.
+    """
     argspec = inspect.getfullargspec(pipeline_func)
 
     # Create the arg list with no default values and call pipeline function.
     # Assign type information to the PipelineParam
     pipeline_meta = _extract_pipeline_metadata(pipeline_func)
+    pipeline_meta.name = pipeline_name or pipeline_meta.name
+    pipeline_meta.description = pipeline_description or pipeline_meta.description
     pipeline_name = K8sHelper.sanitize_k8s_name(pipeline_meta.name)
 
+    # Currently only allow specifying pipeline params at one place.
+    if params_list and pipeline_meta.inputs:
+      raise ValueError('Either specify pipeline params in the pipeline function, or in "params_list", but not both.')
+
     args_list = []
-    for arg_name in argspec.args:
-      arg_type = None
-      for input in pipeline_meta.inputs:
-        if arg_name == input.name:
-          arg_type = input.param_type
-          break
-      args_list.append(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name), param_type=arg_type))
+    if pipeline_meta.inputs:
+      input_types = {
+        input.name : input.param_type for input in pipeline_meta.inputs }
+
+      for arg_name in argspec.args:
+        arg_type = input_types.get(arg_name, None)
+        args_list.append(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name), param_type=arg_type))
 
     # Step 2: Inflate pipeline obj with ContainerOps.
-    with dsl.Pipeline(pipeline_name) as p:
+    with dsl.Pipeline(pipeline_name) as dsl_pipeline:
       pipeline_func(*args_list)
 
     # Step 3: post process. Fill in the default value for pipeline params.
     # Remove when argo supports local exit handler.
-    self._validate_exit_handler(p)
+    self._validate_exit_handler(dsl_pipeline)
+    self._sanitize_and_inject_artifact(dsl_pipeline)
 
     # Fill in the default values.
-    args_list_with_defaults = [dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name))
-                               for arg_name in argspec.args]
-    if argspec.defaults:
-      for arg, default in zip(reversed(args_list_with_defaults), reversed(argspec.defaults)):
-        arg.value = default.value if isinstance(default, dsl.PipelineParam) else default
-
-    self._sanitize_and_inject_artifact(p)
+    if pipeline_meta.inputs:
+      args_list_with_defaults = [dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name))
+                                 for arg_name in argspec.args]
+      if argspec.defaults:
+        for arg, default in zip(reversed(args_list_with_defaults), reversed(argspec.defaults)):
+          arg.value = default.value if isinstance(default, dsl.PipelineParam) else default
+    else:
+      # Or, if args are provided by params_list, fill in pipeline_meta.
+      args_list_with_defaults = params_list
+      pipeline_meta.inputs = [
+        ParameterMeta(
+          name=param.name,
+          description='',
+          param_type=param.param_type,
+          default=param.value) for param in params_list]
 
     op_transformers = [add_pod_env]
-    op_transformers.extend(p.conf.op_transformers)
-    workflow = self._create_pipeline_workflow(args_list_with_defaults, p, op_transformers)
+    op_transformers.extend(dsl_pipeline.conf.op_transformers)
+
+    workflow = self._create_pipeline_workflow(
+      args_list_with_defaults,
+      dsl_pipeline,
+      op_transformers)
 
     import json
     workflow.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/pipeline_spec'] = json.dumps(pipeline_meta.to_dict(), sort_keys=True)
 
     return workflow
 
+  def _compile(self, pipeline_func):
+    """Compile the given pipeline function into workflow."""
+    return self.create_workflow(pipeline_func, [])
+
   def compile(self, pipeline_func, package_path, type_check=True):
     """Compile the given pipeline function into workflow yaml.
 
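A short sketch of the conflict rule added above (the function, image, and param below are illustrative): a pipeline function whose signature already declares params cannot also be given params_list.

import kfp.dsl as dsl
from kfp import compiler

def download_pipeline(url):
  # 'url' becomes a pipeline param extracted from the function signature.
  dsl.ContainerOp(
      name='download',
      image='alpine:3.9',
      command=['sh', '-c'],
      arguments=['wget -O /tmp/data "%s"' % url])

extra_param = dsl.PipelineParam(name='extra')

# Expected to raise ValueError: params come either from the signature or from params_list, not both.
compiler.Compiler().create_workflow(download_pipeline, params_list=[extra_param])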
@@ -179,6 +179,28 @@ class TestCompiler(unittest.TestCase):
       shutil.rmtree(tmpdir)
       # print(tmpdir)
 
+  def test_basic_workflow_without_decorator(self):
+    """Test compiling a workflow and appending pipeline params."""
+    test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
+    sys.path.append(test_data_dir)
+    import basic_no_decorator
+    tmpdir = tempfile.mkdtemp()
+    try:
+      compiled_workflow = compiler.Compiler().create_workflow(
+          basic_no_decorator.save_most_frequent_word,
+          'Save Most Frequent',
+          'Get Most Frequent Word and Save to GCS',
+          [
+            basic_no_decorator.message_param,
+            basic_no_decorator.output_path_param
+          ])
+      with open(os.path.join(test_data_dir, 'basic_no_decorator.yaml'), 'r') as f:
+        golden = yaml.safe_load(f)
+
+      self.assertEqual(golden, compiled_workflow)
+    finally:
+      shutil.rmtree(tmpdir)
+
   def test_composing_workflow(self):
     """Test compiling a simple workflow, and a bigger one composed from the simple one."""
 
@@ -0,0 +1,91 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import kfp.dsl as dsl
import kfp.gcp as gcp


message_param = dsl.PipelineParam(name='message')
output_path_param = dsl.PipelineParam(name='outputpath')

class GetFrequentWordOp(dsl.ContainerOp):
  """A get frequent word class representing a component in ML Pipelines.

  The class provides a nice interface to users by hiding details such as container,
  command, arguments.
  """
  def __init__(self, name, message):
    """Args:
         name: An identifier of the step which needs to be unique within a pipeline.
         message: a dsl.PipelineParam object representing an input message.
    """
    super(GetFrequentWordOp, self).__init__(
        name=name,
        image='python:3.5-jessie',
        command=['sh', '-c'],
        arguments=['python -c "from collections import Counter; '
                   'words = Counter(\'%s\'.split()); print(max(words, key=words.get))" '
                   '| tee /tmp/message.txt' % message],
        file_outputs={'word': '/tmp/message.txt'})


class SaveMessageOp(dsl.ContainerOp):
  """A class representing a component in ML Pipelines.

  It saves a message to a given output_path.
  """
  def __init__(self, name, message, output_path):
    """Args:
         name: An identifier of the step which needs to be unique within a pipeline.
         message: a dsl.PipelineParam object representing the message to be saved.
         output_path: a dsl.PipelineParam object representing the GCS path for output file.
    """
    super(SaveMessageOp, self).__init__(
        name=name,
        image='google/cloud-sdk',
        command=['sh', '-c'],
        arguments=['echo %s | tee /tmp/results.txt | gsutil cp /tmp/results.txt %s'
                   % (message, output_path)])


class ExitHandlerOp(dsl.ContainerOp):
  """A class representing a component in ML Pipelines.
  """
  def __init__(self, name):
    super(ExitHandlerOp, self).__init__(
        name=name,
        image='python:3.5-jessie',
        command=['sh', '-c'],
        arguments=['echo exit!'])

def save_most_frequent_word():
  exit_op = ExitHandlerOp('exiting')
  with dsl.ExitHandler(exit_op):
    counter = GetFrequentWordOp(
        name='get-Frequent',
        message=message_param)
    counter.set_memory_request('200M')

    saver = SaveMessageOp(
        name='save',
        message=counter.output,
        output_path=output_path_param)
    saver.set_cpu_limit('0.5')
    saver.set_gpu_limit('2')
    saver.add_node_selector_constraint(
        'cloud.google.com/gke-accelerator',
        'nvidia-tesla-k80')
    saver.apply(
        gcp.use_tpu(tpu_cores = 8, tpu_resource = 'v2', tf_version = '1.12'))
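The module above is what the new unit test imports as basic_no_decorator; the golden workflow below is the expected result of compiling it with the two module-level params appended, roughly as the test does (a sketch mirroring the test call, assuming the testdata directory is on sys.path):

from kfp import compiler
import basic_no_decorator

workflow = compiler.Compiler().create_workflow(
    basic_no_decorator.save_most_frequent_word,
    'Save Most Frequent',
    'Get Most Frequent Word and Save to GCS',
    [basic_no_decorator.message_param, basic_no_decorator.output_path_param])
# 'workflow' is expected to match the golden basic_no_decorator.yaml that follows.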
@@ -0,0 +1,119 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
  annotations:
    pipelines.kubeflow.org/pipeline_spec: '{"description": "Get Most Frequent Word and Save to GCS", "inputs": [{"name": "message"}, {"name": "outputpath"}], "name": "Save Most Frequent"}'
  generateName: save-most-frequent-
spec:
  arguments:
    parameters:
    - name: message
    - name: outputpath
  entrypoint: save-most-frequent
  serviceAccountName: pipeline-runner
  onExit: exiting
  templates:
  - dag:
      tasks:
      - arguments:
          parameters:
          - name: message
            value: '{{inputs.parameters.message}}'
        name: get-frequent
        template: get-frequent
      - arguments:
          parameters:
          - name: get-frequent-word
            value: '{{tasks.get-frequent.outputs.parameters.get-frequent-word}}'
          - name: outputpath
            value: '{{inputs.parameters.outputpath}}'
        dependencies:
        - get-frequent
        name: save
        template: save
    inputs:
      parameters:
      - name: message
      - name: outputpath
    name: exit-handler-1
  - container:
      args:
      - echo exit!
      command:
      - sh
      - -c
      image: python:3.5-jessie
    name: exiting
  - container:
      args:
      - python -c "from collections import Counter; words = Counter('{{inputs.parameters.message}}'.split());
        print(max(words, key=words.get))" | tee /tmp/message.txt
      command:
      - sh
      - -c
      image: python:3.5-jessie
      resources:
        requests:
          memory: 200M
    inputs:
      parameters:
      - name: message
    name: get-frequent
    outputs:
      parameters:
      - name: get-frequent-word
        valueFrom:
          path: /tmp/message.txt
  - container:
      args:
      - echo {{inputs.parameters.get-frequent-word}} | tee /tmp/results.txt | gsutil
        cp /tmp/results.txt {{inputs.parameters.outputpath}}
      command:
      - sh
      - -c
      image: google/cloud-sdk
      resources:
        limits:
          cloud-tpus.google.com/v2: "8"
          cpu: "0.5"
          nvidia.com/gpu: "2"
    inputs:
      parameters:
      - name: get-frequent-word
      - name: outputpath
    metadata:
      annotations:
        tf-version.cloud-tpus.google.com: "1.12"
    name: save
    nodeSelector:
      cloud.google.com/gke-accelerator: nvidia-tesla-k80
  - dag:
      tasks:
      - arguments:
          parameters:
          - name: message
            value: '{{inputs.parameters.message}}'
          - name: outputpath
            value: '{{inputs.parameters.outputpath}}'
        name: exit-handler-1
        template: exit-handler-1
      - name: exiting
        template: exiting
    inputs:
      parameters:
      - name: message
      - name: outputpath
    name: save-most-frequent