Expose an API for appending params/names/descriptions in a programmable way. (#2082)
* Refactor. Expose a public API to append pipeline param without interacting with dsl.Pipeline obj. * Add unit test and fix. * Fix docstring. * Fix test * Fix test * Fix two nit problems * Refactor
This commit is contained in:
parent
a4fa1edb42
commit
497d016e85
|
|
@ -17,7 +17,7 @@ from collections import defaultdict
|
||||||
import inspect
|
import inspect
|
||||||
import tarfile
|
import tarfile
|
||||||
import zipfile
|
import zipfile
|
||||||
from typing import Set, List, Text, Dict
|
from typing import Any, Callable, Set, List, Text, Dict
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from kfp.dsl import _container_op, _for_loop
|
from kfp.dsl import _container_op, _for_loop
|
||||||
|
|
@ -27,7 +27,7 @@ from ._k8s_helper import K8sHelper
|
||||||
from ._op_to_template import _op_to_template
|
from ._op_to_template import _op_to_template
|
||||||
from ._default_transformers import add_pod_env
|
from ._default_transformers import add_pod_env
|
||||||
|
|
||||||
from ..dsl._metadata import _extract_pipeline_metadata
|
from ..dsl._metadata import ParameterMeta, _extract_pipeline_metadata
|
||||||
from ..dsl._ops_group import OpsGroup
|
from ..dsl._ops_group import OpsGroup
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -714,51 +714,84 @@ class Compiler(object):
|
||||||
sanitized_ops[sanitized_name] = op
|
sanitized_ops[sanitized_name] = op
|
||||||
pipeline.ops = sanitized_ops
|
pipeline.ops = sanitized_ops
|
||||||
|
|
||||||
def _compile(self, pipeline_func):
|
def create_workflow(self,
|
||||||
"""Compile the given pipeline function into workflow."""
|
pipeline_func: Callable,
|
||||||
# Step 1: extract param value, name and description from pipeline_func signature and decoration.
|
pipeline_name: Text=None,
|
||||||
|
pipeline_description: Text=None,
|
||||||
|
params_list: List[dsl.PipelineParam]=()) -> Dict[Text, Any]:
|
||||||
|
""" Create workflow spec from pipeline function and specified pipeline
|
||||||
|
params/metadata. Currently, the pipeline params are either specified in
|
||||||
|
the signature of the pipeline function or by passing a list of
|
||||||
|
dsl.PipelineParam. Conflict will cause ValueError.
|
||||||
|
|
||||||
|
:param pipeline_func: pipeline function where ContainerOps are invoked.
|
||||||
|
:param pipeline_name:
|
||||||
|
:param pipeline_description:
|
||||||
|
:param params_list: list of pipeline params to append to the pipeline.
|
||||||
|
:return: workflow dict.
|
||||||
|
"""
|
||||||
argspec = inspect.getfullargspec(pipeline_func)
|
argspec = inspect.getfullargspec(pipeline_func)
|
||||||
|
|
||||||
# Create the arg list with no default values and call pipeline function.
|
# Create the arg list with no default values and call pipeline function.
|
||||||
# Assign type information to the PipelineParam
|
# Assign type information to the PipelineParam
|
||||||
pipeline_meta = _extract_pipeline_metadata(pipeline_func)
|
pipeline_meta = _extract_pipeline_metadata(pipeline_func)
|
||||||
|
pipeline_meta.name = pipeline_name or pipeline_meta.name
|
||||||
|
pipeline_meta.description = pipeline_description or pipeline_meta.description
|
||||||
pipeline_name = K8sHelper.sanitize_k8s_name(pipeline_meta.name)
|
pipeline_name = K8sHelper.sanitize_k8s_name(pipeline_meta.name)
|
||||||
|
|
||||||
args_list = []
|
# Currently only allow specifying pipeline params at one place.
|
||||||
for arg_name in argspec.args:
|
if params_list and pipeline_meta.inputs:
|
||||||
arg_type = None
|
raise ValueError('Either specify pipeline params in the pipeline function, or in "params_list", but not both.')
|
||||||
for input in pipeline_meta.inputs:
|
|
||||||
if arg_name == input.name:
|
|
||||||
arg_type = input.param_type
|
|
||||||
break
|
|
||||||
args_list.append(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name), param_type=arg_type))
|
|
||||||
|
|
||||||
# Step 2: Inflate pipeline obj with ContainerOps.
|
args_list = []
|
||||||
with dsl.Pipeline(pipeline_name) as p:
|
if pipeline_meta.inputs:
|
||||||
|
input_types = {
|
||||||
|
input.name : input.param_type for input in pipeline_meta.inputs }
|
||||||
|
|
||||||
|
for arg_name in argspec.args:
|
||||||
|
arg_type = input_types.get(arg_name, None)
|
||||||
|
args_list.append(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name), param_type=arg_type))
|
||||||
|
|
||||||
|
with dsl.Pipeline(pipeline_name) as dsl_pipeline:
|
||||||
pipeline_func(*args_list)
|
pipeline_func(*args_list)
|
||||||
|
|
||||||
# Step 3: post process. Fill in the default value for pipeline params.
|
self._validate_exit_handler(dsl_pipeline)
|
||||||
# Remove when argo supports local exit handler.
|
self._sanitize_and_inject_artifact(dsl_pipeline)
|
||||||
self._validate_exit_handler(p)
|
|
||||||
|
|
||||||
# Fill in the default values.
|
# Fill in the default values.
|
||||||
args_list_with_defaults = [dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name))
|
if pipeline_meta.inputs:
|
||||||
for arg_name in argspec.args]
|
args_list_with_defaults = [dsl.PipelineParam(K8sHelper.sanitize_k8s_name(arg_name))
|
||||||
if argspec.defaults:
|
for arg_name in argspec.args]
|
||||||
for arg, default in zip(reversed(args_list_with_defaults), reversed(argspec.defaults)):
|
if argspec.defaults:
|
||||||
arg.value = default.value if isinstance(default, dsl.PipelineParam) else default
|
for arg, default in zip(reversed(args_list_with_defaults), reversed(argspec.defaults)):
|
||||||
|
arg.value = default.value if isinstance(default, dsl.PipelineParam) else default
|
||||||
self._sanitize_and_inject_artifact(p)
|
else:
|
||||||
|
# Or, if args are provided by params_list, fill in pipeline_meta.
|
||||||
|
args_list_with_defaults = params_list
|
||||||
|
pipeline_meta.inputs = [
|
||||||
|
ParameterMeta(
|
||||||
|
name=param.name,
|
||||||
|
description='',
|
||||||
|
param_type=param.param_type,
|
||||||
|
default=param.value) for param in params_list]
|
||||||
|
|
||||||
op_transformers = [add_pod_env]
|
op_transformers = [add_pod_env]
|
||||||
op_transformers.extend(p.conf.op_transformers)
|
op_transformers.extend(dsl_pipeline.conf.op_transformers)
|
||||||
workflow = self._create_pipeline_workflow(args_list_with_defaults, p, op_transformers)
|
|
||||||
|
workflow = self._create_pipeline_workflow(
|
||||||
|
args_list_with_defaults,
|
||||||
|
dsl_pipeline,
|
||||||
|
op_transformers)
|
||||||
|
|
||||||
import json
|
import json
|
||||||
workflow.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/pipeline_spec'] = json.dumps(pipeline_meta.to_dict(), sort_keys=True)
|
workflow.setdefault('metadata', {}).setdefault('annotations', {})['pipelines.kubeflow.org/pipeline_spec'] = json.dumps(pipeline_meta.to_dict(), sort_keys=True)
|
||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
def _compile(self, pipeline_func):
|
||||||
|
"""Compile the given pipeline function into workflow."""
|
||||||
|
return self.create_workflow(pipeline_func, [])
|
||||||
|
|
||||||
def compile(self, pipeline_func, package_path, type_check=True):
|
def compile(self, pipeline_func, package_path, type_check=True):
|
||||||
"""Compile the given pipeline function into workflow yaml.
|
"""Compile the given pipeline function into workflow yaml.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -179,6 +179,28 @@ class TestCompiler(unittest.TestCase):
|
||||||
shutil.rmtree(tmpdir)
|
shutil.rmtree(tmpdir)
|
||||||
# print(tmpdir)
|
# print(tmpdir)
|
||||||
|
|
||||||
|
def test_basic_workflow_without_decorator(self):
|
||||||
|
"""Test compiling a workflow and appending pipeline params."""
|
||||||
|
test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
|
||||||
|
sys.path.append(test_data_dir)
|
||||||
|
import basic_no_decorator
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
compiled_workflow = compiler.Compiler().create_workflow(
|
||||||
|
basic_no_decorator.save_most_frequent_word,
|
||||||
|
'Save Most Frequent',
|
||||||
|
'Get Most Frequent Word and Save to GCS',
|
||||||
|
[
|
||||||
|
basic_no_decorator.message_param,
|
||||||
|
basic_no_decorator.output_path_param
|
||||||
|
])
|
||||||
|
with open(os.path.join(test_data_dir, 'basic_no_decorator.yaml'), 'r') as f:
|
||||||
|
golden = yaml.safe_load(f)
|
||||||
|
|
||||||
|
self.assertEqual(golden, compiled_workflow)
|
||||||
|
finally:
|
||||||
|
shutil.rmtree(tmpdir)
|
||||||
|
|
||||||
def test_composing_workflow(self):
|
def test_composing_workflow(self):
|
||||||
"""Test compiling a simple workflow, and a bigger one composed from the simple one."""
|
"""Test compiling a simple workflow, and a bigger one composed from the simple one."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright 2019 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import kfp.dsl as dsl
|
||||||
|
import kfp.gcp as gcp
|
||||||
|
|
||||||
|
|
||||||
|
message_param = dsl.PipelineParam(name='message')
|
||||||
|
output_path_param = dsl.PipelineParam(name='outputpath')
|
||||||
|
|
||||||
|
class GetFrequentWordOp(dsl.ContainerOp):
|
||||||
|
"""A get frequent word class representing a component in ML Pipelines.
|
||||||
|
|
||||||
|
The class provides a nice interface to users by hiding details such as container,
|
||||||
|
command, arguments.
|
||||||
|
"""
|
||||||
|
def __init__(self, name, message):
|
||||||
|
"""Args:
|
||||||
|
name: An identifier of the step which needs to be unique within a pipeline.
|
||||||
|
message: a dsl.PipelineParam object representing an input message.
|
||||||
|
"""
|
||||||
|
super(GetFrequentWordOp, self).__init__(
|
||||||
|
name=name,
|
||||||
|
image='python:3.5-jessie',
|
||||||
|
command=['sh', '-c'],
|
||||||
|
arguments=['python -c "from collections import Counter; '
|
||||||
|
'words = Counter(\'%s\'.split()); print(max(words, key=words.get))" '
|
||||||
|
'| tee /tmp/message.txt' % message],
|
||||||
|
file_outputs={'word': '/tmp/message.txt'})
|
||||||
|
|
||||||
|
|
||||||
|
class SaveMessageOp(dsl.ContainerOp):
|
||||||
|
"""A class representing a component in ML Pipelines.
|
||||||
|
|
||||||
|
It saves a message to a given output_path.
|
||||||
|
"""
|
||||||
|
def __init__(self, name, message, output_path):
|
||||||
|
"""Args:
|
||||||
|
name: An identifier of the step which needs to be unique within a pipeline.
|
||||||
|
message: a dsl.PipelineParam object representing the message to be saved.
|
||||||
|
output_path: a dsl.PipelineParam object representing the GCS path for output file.
|
||||||
|
"""
|
||||||
|
super(SaveMessageOp, self).__init__(
|
||||||
|
name=name,
|
||||||
|
image='google/cloud-sdk',
|
||||||
|
command=['sh', '-c'],
|
||||||
|
arguments=['echo %s | tee /tmp/results.txt | gsutil cp /tmp/results.txt %s'
|
||||||
|
% (message, output_path)])
|
||||||
|
|
||||||
|
|
||||||
|
class ExitHandlerOp(dsl.ContainerOp):
|
||||||
|
"""A class representing a component in ML Pipelines.
|
||||||
|
"""
|
||||||
|
def __init__(self, name):
|
||||||
|
super(ExitHandlerOp, self).__init__(
|
||||||
|
name=name,
|
||||||
|
image='python:3.5-jessie',
|
||||||
|
command=['sh', '-c'],
|
||||||
|
arguments=['echo exit!'])
|
||||||
|
|
||||||
|
def save_most_frequent_word():
|
||||||
|
exit_op = ExitHandlerOp('exiting')
|
||||||
|
with dsl.ExitHandler(exit_op):
|
||||||
|
counter = GetFrequentWordOp(
|
||||||
|
name='get-Frequent',
|
||||||
|
message=message_param)
|
||||||
|
counter.set_memory_request('200M')
|
||||||
|
|
||||||
|
saver = SaveMessageOp(
|
||||||
|
name='save',
|
||||||
|
message=counter.output,
|
||||||
|
output_path=output_path_param)
|
||||||
|
saver.set_cpu_limit('0.5')
|
||||||
|
saver.set_gpu_limit('2')
|
||||||
|
saver.add_node_selector_constraint(
|
||||||
|
'cloud.google.com/gke-accelerator',
|
||||||
|
'nvidia-tesla-k80')
|
||||||
|
saver.apply(
|
||||||
|
gcp.use_tpu(tpu_cores = 8, tpu_resource = 'v2', tf_version = '1.12'))
|
||||||
|
|
@ -0,0 +1,119 @@
|
||||||
|
# Copyright 2018 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
kind: Workflow
|
||||||
|
metadata:
|
||||||
|
annotations:
|
||||||
|
pipelines.kubeflow.org/pipeline_spec: '{"description": "Get Most Frequent Word and Save to GCS", "inputs": [{"name": "message"}, {"name": "outputpath"}], "name": "Save Most Frequent"}'
|
||||||
|
generateName: save-most-frequent-
|
||||||
|
spec:
|
||||||
|
arguments:
|
||||||
|
parameters:
|
||||||
|
- name: message
|
||||||
|
- name: outputpath
|
||||||
|
entrypoint: save-most-frequent
|
||||||
|
serviceAccountName: pipeline-runner
|
||||||
|
onExit: exiting
|
||||||
|
templates:
|
||||||
|
- dag:
|
||||||
|
tasks:
|
||||||
|
- arguments:
|
||||||
|
parameters:
|
||||||
|
- name: message
|
||||||
|
value: '{{inputs.parameters.message}}'
|
||||||
|
name: get-frequent
|
||||||
|
template: get-frequent
|
||||||
|
- arguments:
|
||||||
|
parameters:
|
||||||
|
- name: get-frequent-word
|
||||||
|
value: '{{tasks.get-frequent.outputs.parameters.get-frequent-word}}'
|
||||||
|
- name: outputpath
|
||||||
|
value: '{{inputs.parameters.outputpath}}'
|
||||||
|
dependencies:
|
||||||
|
- get-frequent
|
||||||
|
name: save
|
||||||
|
template: save
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: message
|
||||||
|
- name: outputpath
|
||||||
|
name: exit-handler-1
|
||||||
|
- container:
|
||||||
|
args:
|
||||||
|
- echo exit!
|
||||||
|
command:
|
||||||
|
- sh
|
||||||
|
- -c
|
||||||
|
image: python:3.5-jessie
|
||||||
|
name: exiting
|
||||||
|
- container:
|
||||||
|
args:
|
||||||
|
- python -c "from collections import Counter; words = Counter('{{inputs.parameters.message}}'.split());
|
||||||
|
print(max(words, key=words.get))" | tee /tmp/message.txt
|
||||||
|
command:
|
||||||
|
- sh
|
||||||
|
- -c
|
||||||
|
image: python:3.5-jessie
|
||||||
|
resources:
|
||||||
|
requests:
|
||||||
|
memory: 200M
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: message
|
||||||
|
name: get-frequent
|
||||||
|
outputs:
|
||||||
|
parameters:
|
||||||
|
- name: get-frequent-word
|
||||||
|
valueFrom:
|
||||||
|
path: /tmp/message.txt
|
||||||
|
- container:
|
||||||
|
args:
|
||||||
|
- echo {{inputs.parameters.get-frequent-word}} | tee /tmp/results.txt | gsutil
|
||||||
|
cp /tmp/results.txt {{inputs.parameters.outputpath}}
|
||||||
|
command:
|
||||||
|
- sh
|
||||||
|
- -c
|
||||||
|
image: google/cloud-sdk
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cloud-tpus.google.com/v2: "8"
|
||||||
|
cpu: "0.5"
|
||||||
|
nvidia.com/gpu: "2"
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: get-frequent-word
|
||||||
|
- name: outputpath
|
||||||
|
metadata:
|
||||||
|
annotations:
|
||||||
|
tf-version.cloud-tpus.google.com: "1.12"
|
||||||
|
name: save
|
||||||
|
nodeSelector:
|
||||||
|
cloud.google.com/gke-accelerator: nvidia-tesla-k80
|
||||||
|
- dag:
|
||||||
|
tasks:
|
||||||
|
- arguments:
|
||||||
|
parameters:
|
||||||
|
- name: message
|
||||||
|
value: '{{inputs.parameters.message}}'
|
||||||
|
- name: outputpath
|
||||||
|
value: '{{inputs.parameters.outputpath}}'
|
||||||
|
name: exit-handler-1
|
||||||
|
template: exit-handler-1
|
||||||
|
- name: exiting
|
||||||
|
template: exiting
|
||||||
|
inputs:
|
||||||
|
parameters:
|
||||||
|
- name: message
|
||||||
|
- name: outputpath
|
||||||
|
name: save-most-frequent
|
||||||
Loading…
Reference in New Issue