Expose an API for appending params/names/descriptions programmatically. (#2082)

* Refactor: expose a public API to append pipeline params without interacting with the dsl.Pipeline object.

* Add unit test and fixes.

* Fix docstring.

* Fix test

* Fix test

* Fix two nits

* Refactor
Jiaxiao Zheng authored 2019-09-10 17:58:47 -07:00, committed by Kubernetes Prow Robot
parent a4fa1edb42
commit 497d016e85

4 changed files with 292 additions and 27 deletions
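
Before the diff, a quick orientation: the change adds Compiler.create_workflow(), which takes a pipeline function plus an optional name, description, and params_list, so pipeline params no longer have to come from a decorated function signature. A minimal sketch of the new calling pattern (the echo pipeline and the names msg/echo_pipeline are illustrative, not part of the commit):

    import kfp.dsl as dsl
    from kfp import compiler

    # The new path: a pipeline function with no signature params; the param is
    # declared up front and appended via params_list.
    msg = dsl.PipelineParam(name='message')

    def echo_pipeline():
      dsl.ContainerOp(name='echo', image='alpine', command=['echo', msg])

    # Returns the workflow spec as a dict.
    workflow = compiler.Compiler().create_workflow(
        echo_pipeline, 'Echo', 'Echo a message', [msg])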


@@ -17,7 +17,7 @@ from collections import defaultdict
 import inspect
 import tarfile
 import zipfile
-from typing import Set, List, Text, Dict
+from typing import Any, Callable, Set, List, Text, Dict
 
 import yaml
 
 from kfp.dsl import _container_op, _for_loop
@@ -27,7 +27,7 @@ from ._k8s_helper import K8sHelper
 from ._op_to_template import _op_to_template
 from ._default_transformers import add_pod_env
-from ..dsl._metadata import _extract_pipeline_metadata
+from ..dsl._metadata import ParameterMeta, _extract_pipeline_metadata
 from ..dsl._ops_group import OpsGroup
@@ -714,51 +714,84 @@ class Compiler(object):
       sanitized_ops[sanitized_name] = op
     pipeline.ops = sanitized_ops
 
-  def _compile(self, pipeline_func):
-    """Compile the given pipeline function into workflow."""
-
-    # Step 1: extract param value, name and description from pipeline_func signature and decoration.
+  def create_workflow(self,
+                      pipeline_func: Callable,
+                      pipeline_name: Text=None,
+                      pipeline_description: Text=None,
+                      params_list: List[dsl.PipelineParam]=()) -> Dict[Text, Any]:
+    """Create workflow spec from pipeline function and specified pipeline
+    params/metadata. Currently, the pipeline params are either specified in
+    the signature of the pipeline function or by passing a list of
+    dsl.PipelineParam. Conflict will cause ValueError.
+
+    :param pipeline_func: pipeline function where ContainerOps are invoked.
+    :param pipeline_name:
+    :param pipeline_description:
+    :param params_list: list of pipeline params to append to the pipeline.
+    :return: workflow dict.
+    """
     argspec = inspect.getfullargspec(pipeline_func)
 
     # Create the arg list with no default values and call pipeline function.
     # Assign type information to the PipelineParam
     pipeline_meta = _extract_pipeline_metadata(pipeline_func)
+    pipeline_meta.name = pipeline_name or pipeline_meta.name
+    pipeline_meta.description = pipeline_description or pipeline_meta.description
     pipeline_name = K8sHelper.sanitize_k8s_name(pipeline_meta.name)
 
-    args_list = []
-    for arg_name in argspec.args:
-      arg_type = None
-      for input in pipeline_meta.inputs:
-        if arg_name == input.name:
-          arg_type = input.param_type
-          break
-      args_list.append(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name), param_type=arg_type))
+    # Currently only allow specifying pipeline params at one place.
+    if params_list and pipeline_meta.inputs:
+      raise ValueError('Either specify pipeline params in the pipeline function, or in "params_list", but not both.')
+
+    args_list = []
+    if pipeline_meta.inputs:
+      input_types = {
+          input.name : input.param_type for input in pipeline_meta.inputs }
+      for arg_name in argspec.args:
+        arg_type = input_types.get(arg_name, None)
+        args_list.append(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name), param_type=arg_type))
 
-    # Step 2: Inflate pipeline obj with ContainerOps.
-    with dsl.Pipeline(pipeline_name) as p:
+    with dsl.Pipeline(pipeline_name) as dsl_pipeline:
       pipeline_func(*args_list)
 
-    # Step 3: post process. Fill in the default value for pipeline params.
-    # Remove when argo supports local exit handler.
-    self._validate_exit_handler(p)
+    self._validate_exit_handler(dsl_pipeline)
+    self._sanitize_and_inject_artifact(dsl_pipeline)
 
     # Fill in the default values.
-    args_list_with_defaults = [dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name))
-                               for arg_name in argspec.args]
-    if argspec.defaults:
-      for arg, default in zip(reversed(args_list_with_defaults), reversed(argspec.defaults)):
-        arg.value = default.value if isinstance(default, dsl.PipelineParam) else default
-
-    self._sanitize_and_inject_artifact(p)
+    if pipeline_meta.inputs:
+      args_list_with_defaults = [dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name))
+                                 for arg_name in argspec.args]
+      if argspec.defaults:
+        for arg, default in zip(reversed(args_list_with_defaults), reversed(argspec.defaults)):
+          arg.value = default.value if isinstance(default, dsl.PipelineParam) else default
+    else:
+      # Or, if args are provided by params_list, fill in pipeline_meta.
+      args_list_with_defaults = params_list
+      pipeline_meta.inputs = [
+          ParameterMeta(
+              name=param.name,
+              description='',
+              param_type=param.param_type,
+              default=param.value) for param in params_list]
 
     op_transformers = [add_pod_env]
-    op_transformers.extend(p.conf.op_transformers)
-    workflow = self._create_pipeline_workflow(args_list_with_defaults, p, op_transformers)
+    op_transformers.extend(dsl_pipeline.conf.op_transformers)
+
+    workflow = self._create_pipeline_workflow(
+        args_list_with_defaults,
+        dsl_pipeline,
+        op_transformers)
 
     import json
     workflow.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/pipeline_spec'] = json.dumps(pipeline_meta.to_dict(), sort_keys=True)
+
     return workflow
 
+  def _compile(self, pipeline_func):
+    """Compile the given pipeline function into workflow."""
+    return self.create_workflow(pipeline_func, [])
+
   def compile(self, pipeline_func, package_path, type_check=True):
     """Compile the given pipeline function into workflow yaml.


@@ -179,6 +179,28 @@ class TestCompiler(unittest.TestCase):
       shutil.rmtree(tmpdir)
       # print(tmpdir)
 
+  def test_basic_workflow_without_decorator(self):
+    """Test compiling a workflow and appending pipeline params."""
+    test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
+    sys.path.append(test_data_dir)
+    import basic_no_decorator
+    tmpdir = tempfile.mkdtemp()
+    try:
+      compiled_workflow = compiler.Compiler().create_workflow(
+          basic_no_decorator.save_most_frequent_word,
+          'Save Most Frequent',
+          'Get Most Frequent Word and Save to GCS',
+          [
+              basic_no_decorator.message_param,
+              basic_no_decorator.output_path_param
+          ])
+      with open(os.path.join(test_data_dir, 'basic_no_decorator.yaml'), 'r') as f:
+        golden = yaml.safe_load(f)
+      self.assertEqual(golden, compiled_workflow)
+    finally:
+      shutil.rmtree(tmpdir)
+
   def test_composing_workflow(self):
     """Test compiling a simple workflow, and a bigger one composed from the simple one."""


@@ -0,0 +1,91 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import kfp.dsl as dsl
import kfp.gcp as gcp

message_param = dsl.PipelineParam(name='message')
output_path_param = dsl.PipelineParam(name='outputpath')


class GetFrequentWordOp(dsl.ContainerOp):
  """A get frequent word class representing a component in ML Pipelines.

  The class provides a nice interface to users by hiding details such as
  container, command, arguments.
  """

  def __init__(self, name, message):
    """Args:
         name: An identifier of the step which needs to be unique within a pipeline.
         message: a dsl.PipelineParam object representing an input message.
    """
    super(GetFrequentWordOp, self).__init__(
        name=name,
        image='python:3.5-jessie',
        command=['sh', '-c'],
        arguments=['python -c "from collections import Counter; '
                   'words = Counter(\'%s\'.split()); print(max(words, key=words.get))" '
                   '| tee /tmp/message.txt' % message],
        file_outputs={'word': '/tmp/message.txt'})


class SaveMessageOp(dsl.ContainerOp):
  """A class representing a component in ML Pipelines.

  It saves a message to a given output_path.
  """

  def __init__(self, name, message, output_path):
    """Args:
         name: An identifier of the step which needs to be unique within a pipeline.
         message: a dsl.PipelineParam object representing the message to be saved.
         output_path: a dsl.PipelineParam object representing the GCS path for output file.
    """
    super(SaveMessageOp, self).__init__(
        name=name,
        image='google/cloud-sdk',
        command=['sh', '-c'],
        arguments=['echo %s | tee /tmp/results.txt | gsutil cp /tmp/results.txt %s'
                   % (message, output_path)])


class ExitHandlerOp(dsl.ContainerOp):
  """A class representing a component in ML Pipelines."""

  def __init__(self, name):
    super(ExitHandlerOp, self).__init__(
        name=name,
        image='python:3.5-jessie',
        command=['sh', '-c'],
        arguments=['echo exit!'])


def save_most_frequent_word():
  exit_op = ExitHandlerOp('exiting')
  with dsl.ExitHandler(exit_op):
    counter = GetFrequentWordOp(
        name='get-Frequent',
        message=message_param)
    counter.set_memory_request('200M')

    saver = SaveMessageOp(
        name='save',
        message=counter.output,
        output_path=output_path_param)
    saver.set_cpu_limit('0.5')
    saver.set_gpu_limit('2')
    saver.add_node_selector_constraint(
        'cloud.google.com/gke-accelerator',
        'nvidia-tesla-k80')
    saver.apply(
        gcp.use_tpu(tpu_cores=8, tpu_resource='v2', tf_version='1.12'))
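
This module is meant to be driven through the new API exactly as the unit test above does. A sketch that compiles it and writes the result out for comparison with the golden file (the output filename is illustrative):

    import yaml
    from kfp import compiler
    import basic_no_decorator

    # Same call the test makes: name, description, and params supplied
    # programmatically instead of via the @dsl.pipeline decorator.
    workflow = compiler.Compiler().create_workflow(
        basic_no_decorator.save_most_frequent_word,
        'Save Most Frequent',
        'Get Most Frequent Word and Save to GCS',
        [basic_no_decorator.message_param, basic_no_decorator.output_path_param])

    with open('compiled_workflow.yaml', 'w') as f:
      yaml.safe_dump(workflow, f, default_flow_style=False)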


@@ -0,0 +1,119 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
annotations:
pipelines.kubeflow.org/pipeline_spec: '{"description": "Get Most Frequent Word and Save to GCS", "inputs": [{"name": "message"}, {"name": "outputpath"}], "name": "Save Most Frequent"}'
generateName: save-most-frequent-
spec:
arguments:
parameters:
- name: message
- name: outputpath
entrypoint: save-most-frequent
serviceAccountName: pipeline-runner
onExit: exiting
templates:
- dag:
tasks:
- arguments:
parameters:
- name: message
value: '{{inputs.parameters.message}}'
name: get-frequent
template: get-frequent
- arguments:
parameters:
- name: get-frequent-word
value: '{{tasks.get-frequent.outputs.parameters.get-frequent-word}}'
- name: outputpath
value: '{{inputs.parameters.outputpath}}'
dependencies:
- get-frequent
name: save
template: save
inputs:
parameters:
- name: message
- name: outputpath
name: exit-handler-1
- container:
args:
- echo exit!
command:
- sh
- -c
image: python:3.5-jessie
name: exiting
- container:
args:
- python -c "from collections import Counter; words = Counter('{{inputs.parameters.message}}'.split());
print(max(words, key=words.get))" | tee /tmp/message.txt
command:
- sh
- -c
image: python:3.5-jessie
resources:
requests:
memory: 200M
inputs:
parameters:
- name: message
name: get-frequent
outputs:
parameters:
- name: get-frequent-word
valueFrom:
path: /tmp/message.txt
- container:
args:
- echo {{inputs.parameters.get-frequent-word}} | tee /tmp/results.txt | gsutil
cp /tmp/results.txt {{inputs.parameters.outputpath}}
command:
- sh
- -c
image: google/cloud-sdk
resources:
limits:
cloud-tpus.google.com/v2: "8"
cpu: "0.5"
nvidia.com/gpu: "2"
inputs:
parameters:
- name: get-frequent-word
- name: outputpath
metadata:
annotations:
tf-version.cloud-tpus.google.com: "1.12"
name: save
nodeSelector:
cloud.google.com/gke-accelerator: nvidia-tesla-k80
- dag:
tasks:
- arguments:
parameters:
- name: message
value: '{{inputs.parameters.message}}'
- name: outputpath
value: '{{inputs.parameters.outputpath}}'
name: exit-handler-1
template: exit-handler-1
- name: exiting
template: exiting
inputs:
parameters:
- name: message
- name: outputpath
name: save-most-frequent
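
Reading the golden file back shows where the appended params land: both as workflow-level Argo arguments and as inputs inside the pipeline_spec annotation. A small verification sketch (assumes basic_no_decorator.yaml is in the working directory):

    import json
    import yaml

    with open('basic_no_decorator.yaml') as f:
        wf = yaml.safe_load(f)

    # Appended params surface as top-level Argo workflow parameters...
    print([p['name'] for p in wf['spec']['arguments']['parameters']])
    # ['message', 'outputpath']

    # ...and as inputs in the KFP pipeline_spec annotation written by create_workflow.
    spec = json.loads(
        wf['metadata']['annotations']['pipelines.kubeflow.org/pipeline_spec'])
    print([i['name'] for i in spec['inputs']])
    # ['message', 'outputpath']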