Support recursions in a function (#1014)
* add a While in the ops group * deepcopy the while conditions when entering and exiting * add while condition resolution in the compiler * define graph component decorator * remove while loop related codes * fixes * remove while loop related code * fix bugs * generate a unique ops group name and being able to retrieve by name * resolve the opsgroups inputs and dependencies based on the pipelineparam in the condition * add a recursive ops_groups * fix bugs of the recursive opsgroup template name * resolve the recursive template name and arguments * add validity checks * add more comments * add usage comment in graph_component * add unit test for the graph opsgraph * refactor the opsgroup * add unit test for the graph_component decorator * exposing graph_component decorator * add recursive compiler unit tests * fix the bug of opsgroup name adjust the graph_component usage example fix index bugs use with statement in the graph_component instead of directly calling the enter/exit functions * add a todo to combine the graph_component and component decorators
This commit is contained in:
parent
1c4f9eb431
commit
8c09090985
|
|
@ -23,6 +23,7 @@ from .. import dsl
|
|||
from ._k8s_helper import K8sHelper
|
||||
from ..dsl._pipeline_param import _match_serialized_pipelineparam
|
||||
from ..dsl._metadata import TypeMeta
|
||||
from ..dsl._ops_group import OpsGroup
|
||||
|
||||
class Compiler(object):
|
||||
"""DSL Compiler.
|
||||
|
|
@ -65,6 +66,11 @@ class Compiler(object):
|
|||
def _get_op_groups_helper(current_groups, ops_to_groups):
|
||||
root_group = current_groups[-1]
|
||||
for g in root_group.groups:
|
||||
# Add recursive opsgroup in the ops_to_groups
|
||||
# such that the i/o dependency can be propagated to the ancester opsgroups
|
||||
if g.recursive_ref:
|
||||
ops_to_groups[g.name] = [x.name for x in current_groups] + [g.name]
|
||||
continue
|
||||
current_groups.append(g)
|
||||
_get_op_groups_helper(current_groups, ops_to_groups)
|
||||
del current_groups[-1]
|
||||
|
|
@ -82,7 +88,10 @@ class Compiler(object):
|
|||
def _get_groups_helper(group):
|
||||
groups = [group]
|
||||
for g in group.groups:
|
||||
groups += _get_groups_helper(g)
|
||||
# Skip the recursive opsgroup because no templates
|
||||
# need to be generated for the recursive opsgroups.
|
||||
if not g.recursive_ref:
|
||||
groups += _get_groups_helper(g)
|
||||
return groups
|
||||
|
||||
return _get_groups_helper(root_group)
|
||||
|
|
@ -145,6 +154,38 @@ class Compiler(object):
|
|||
if not op.is_exit_handler:
|
||||
for g in op_groups[op.name]:
|
||||
inputs[g].add((full_name, None))
|
||||
|
||||
# Generate the input/output for recursive opsgroups
|
||||
# It propagates the recursive opsgroups IO to their ancester opsgroups
|
||||
def _get_inputs_outputs_recursive_opsgroup(group):
|
||||
#TODO: refactor the following codes with the above
|
||||
if group.recursive_ref:
|
||||
for param in group.inputs + list(condition_params[group.name]):
|
||||
if param.value:
|
||||
continue
|
||||
full_name = self._pipelineparam_full_name(param)
|
||||
if param.op_name:
|
||||
upstream_op = pipeline.ops[param.op_name]
|
||||
upstream_groups, downstream_groups = self._get_uncommon_ancestors(
|
||||
op_groups, upstream_op, group)
|
||||
for i, g in enumerate(downstream_groups):
|
||||
if i == 0:
|
||||
inputs[g].add((full_name, upstream_groups[0]))
|
||||
else:
|
||||
inputs[g].add((full_name, None))
|
||||
for i, g in enumerate(upstream_groups):
|
||||
if i == len(upstream_groups) - 1:
|
||||
outputs[g].add((full_name, None))
|
||||
else:
|
||||
outputs[g].add((full_name, upstream_groups[i+1]))
|
||||
else:
|
||||
if not op.is_exit_handler:
|
||||
for g in op_groups[op.name]:
|
||||
inputs[g].add((full_name, None))
|
||||
for subgroup in group.groups:
|
||||
_get_inputs_outputs_recursive_opsgroup(subgroup)
|
||||
|
||||
_get_inputs_outputs_recursive_opsgroup(root_group)
|
||||
return inputs, outputs
|
||||
|
||||
def _get_condition_params_for_ops(self, root_group):
|
||||
|
|
@ -164,8 +205,13 @@ class Compiler(object):
|
|||
for param in new_current_conditions_params:
|
||||
conditions[op.name].add(param)
|
||||
for g in group.groups:
|
||||
_get_condition_params_for_ops_helper(g, new_current_conditions_params)
|
||||
|
||||
# If the subgroup is a recursive opsgroup, propagate the pipelineparams
|
||||
# in the condition expression, similar to the ops.
|
||||
if g.recursive_ref:
|
||||
for param in new_current_conditions_params:
|
||||
conditions[g.name].add(param)
|
||||
else:
|
||||
_get_condition_params_for_ops_helper(g, new_current_conditions_params)
|
||||
_get_condition_params_for_ops_helper(root_group, [])
|
||||
return conditions
|
||||
|
||||
|
|
@ -179,6 +225,8 @@ class Compiler(object):
|
|||
then G3 is dependent on G2. Basically dependency only exists in the first uncommon
|
||||
ancesters in their ancesters chain. Only sibling groups/ops can have dependencies.
|
||||
"""
|
||||
#TODO: move the condition_params out because both the _get_inputs_outputs
|
||||
# and _get_dependencies depend on it.
|
||||
condition_params = self._get_condition_params_for_ops(root_group)
|
||||
dependencies = defaultdict(set)
|
||||
for op in pipeline.ops.values():
|
||||
|
|
@ -193,6 +241,29 @@ class Compiler(object):
|
|||
upstream_groups, downstream_groups = self._get_uncommon_ancestors(
|
||||
op_groups, upstream_op, op)
|
||||
dependencies[downstream_groups[0]].add(upstream_groups[0])
|
||||
|
||||
# Generate dependencies based on the recursive opsgroups
|
||||
#TODO: refactor the following codes with the above
|
||||
def _get_dependency_opsgroup(group, dependencies):
|
||||
if group.recursive_ref:
|
||||
unstream_op_names = set()
|
||||
for param in group.inputs + list(condition_params[group.name]):
|
||||
if param.op_name:
|
||||
unstream_op_names.add(param.op_name)
|
||||
unstream_op_names |= set(group.dependencies)
|
||||
|
||||
for op_name in unstream_op_names:
|
||||
upstream_op = pipeline.ops[op_name]
|
||||
upstream_groups, downstream_groups = self._get_uncommon_ancestors(
|
||||
op_groups, upstream_op, group)
|
||||
dependencies[downstream_groups[0]].add(upstream_groups[0])
|
||||
|
||||
|
||||
for subgroup in group.groups:
|
||||
_get_dependency_opsgroup(subgroup, dependencies)
|
||||
|
||||
_get_dependency_opsgroup(root_group, dependencies)
|
||||
|
||||
return dependencies
|
||||
|
||||
def _resolve_value_or_reference(self, value_or_reference, potential_references):
|
||||
|
|
@ -364,11 +435,18 @@ class Compiler(object):
|
|||
# Generate tasks section.
|
||||
tasks = []
|
||||
for sub_group in group.groups + group.ops:
|
||||
task = {
|
||||
'name': sub_group.name,
|
||||
'template': sub_group.name,
|
||||
}
|
||||
|
||||
is_recursive_subgroup = (isinstance(sub_group, OpsGroup) and sub_group.recursive_ref)
|
||||
# Special handling for recursive subgroup: use the existing opsgroup name
|
||||
if is_recursive_subgroup:
|
||||
task = {
|
||||
'name': sub_group.recursive_ref.name,
|
||||
'template': sub_group.recursive_ref.name,
|
||||
}
|
||||
else:
|
||||
task = {
|
||||
'name': sub_group.name,
|
||||
'template': sub_group.name,
|
||||
}
|
||||
if isinstance(sub_group, dsl.OpsGroup) and sub_group.type == 'condition':
|
||||
subgroup_inputs = inputs.get(sub_group.name, [])
|
||||
condition = sub_group.condition
|
||||
|
|
@ -394,10 +472,22 @@ class Compiler(object):
|
|||
})
|
||||
else:
|
||||
# The value comes from its parent.
|
||||
arguments.append({
|
||||
'name': param_name,
|
||||
'value': '{{inputs.parameters.%s}}' % param_name
|
||||
})
|
||||
# Special handling for recursive subgroup: argument name comes from the existing opsgroup
|
||||
if is_recursive_subgroup:
|
||||
for index, input in enumerate(sub_group.inputs):
|
||||
if param_name == input.name:
|
||||
break
|
||||
referenced_input = sub_group.recursive_ref.inputs[index]
|
||||
full_name = self._pipelineparam_full_name(referenced_input)
|
||||
arguments.append({
|
||||
'name': full_name,
|
||||
'value': '{{inputs.parameters.%s}}' % param_name
|
||||
})
|
||||
else:
|
||||
arguments.append({
|
||||
'name': param_name,
|
||||
'value': '{{inputs.parameters.%s}}' % param_name
|
||||
})
|
||||
arguments.sort(key=lambda x: x['name'])
|
||||
task['arguments'] = {'parameters': arguments}
|
||||
tasks.append(task)
|
||||
|
|
@ -410,6 +500,15 @@ class Compiler(object):
|
|||
|
||||
new_root_group = pipeline.groups[0]
|
||||
|
||||
# Generate core data structures to prepare for argo yaml generation
|
||||
# op_groups: op name -> list of ancestor groups including the current op
|
||||
# inputs, outputs: group/op names -> list of tuples (param_name, producing_op_name)
|
||||
# dependencies: group/op name -> list of dependent groups/ops.
|
||||
# groups: opsgroups
|
||||
# Special Handling for the recursive opsgroup
|
||||
# op_groups also contains the recursive opsgroups
|
||||
# condition_params from _get_condition_params_for_ops also contains the recursive opsgroups
|
||||
# groups does not include the recursive opsgroups
|
||||
op_groups = self._get_groups_for_ops(new_root_group)
|
||||
inputs, outputs = self._get_inputs_outputs(pipeline, new_root_group, op_groups)
|
||||
dependencies = self._get_dependencies(pipeline, new_root_group, op_groups)
|
||||
|
|
|
|||
|
|
@ -17,5 +17,4 @@ from ._pipeline_param import PipelineParam
|
|||
from ._pipeline import Pipeline, pipeline, get_pipeline_conf
|
||||
from ._container_op import ContainerOp
|
||||
from ._ops_group import OpsGroup, ExitHandler, Condition
|
||||
from ._component import python_component, component
|
||||
#TODO: expose the component decorator when ready
|
||||
from ._component import python_component, graph_component, component
|
||||
|
|
@ -15,6 +15,7 @@
|
|||
from ._metadata import ComponentMeta, ParameterMeta, TypeMeta, _annotation_to_typemeta
|
||||
from ._pipeline_param import PipelineParam
|
||||
from .types import check_types, InconsistentTypeException
|
||||
from ._ops_group import Graph
|
||||
import kfp
|
||||
|
||||
def python_component(name, description=None, base_image=None, target_component_file: str = None):
|
||||
|
|
@ -54,7 +55,7 @@ def python_component(name, description=None, base_image=None, target_component_f
|
|||
return _python_component
|
||||
|
||||
def component(func):
|
||||
"""Decorator for component functions that use ContainerOp.
|
||||
"""Decorator for component functions that returns a ContainerOp.
|
||||
This is useful to enable type checking in the DSL compiler
|
||||
|
||||
Usage:
|
||||
|
|
@ -118,3 +119,42 @@ def component(func):
|
|||
return container_op
|
||||
|
||||
return _component
|
||||
|
||||
#TODO: combine the component and graph_component decorators into one
|
||||
def graph_component(func):
|
||||
"""Decorator for graph component functions.
|
||||
This decorator returns an ops_group.
|
||||
|
||||
Usage:
|
||||
```python
|
||||
import kfp.dsl as dsl
|
||||
@dsl.graph_component
|
||||
def flip_component(flip_result):
|
||||
print_flip = PrintOp(flip_result)
|
||||
flipA = FlipCoinOp().after(print_flip)
|
||||
with dsl.Condition(flipA.output == 'heads'):
|
||||
flip_component(flipA.output)
|
||||
return {'flip_result': flipA.output}
|
||||
"""
|
||||
from functools import wraps
|
||||
@wraps(func)
|
||||
def _graph_component(*args, **kargs):
|
||||
graph_ops_group = Graph(func.__name__)
|
||||
graph_ops_group.inputs = list(args) + list(kargs.values())
|
||||
for input in graph_ops_group.inputs:
|
||||
if not isinstance(input, PipelineParam):
|
||||
raise ValueError('arguments to ' + func.__name__ + ' should be PipelineParams.')
|
||||
|
||||
# Entering the Graph Context
|
||||
with graph_ops_group:
|
||||
# Call the function
|
||||
if not graph_ops_group.recursive_ref:
|
||||
graph_ops_group.outputs = func(*args, **kargs)
|
||||
if not isinstance(graph_ops_group.outputs, dict):
|
||||
raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.')
|
||||
for output in graph_ops_group.outputs:
|
||||
if not (isinstance(output, str) and isinstance(graph_ops_group.outputs[output], PipelineParam)):
|
||||
raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.')
|
||||
|
||||
return graph_ops_group
|
||||
return _graph_component
|
||||
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
from . import _container_op
|
||||
from . import _pipeline
|
||||
|
||||
from ._pipeline_param import ConditionOperator
|
||||
|
||||
class OpsGroup(object):
|
||||
"""Represents a logical group of ops and group of OpsGroups.
|
||||
|
|
@ -28,21 +28,51 @@ class OpsGroup(object):
|
|||
def __init__(self, group_type: str, name: str=None):
|
||||
"""Create a new instance of OpsGroup.
|
||||
Args:
|
||||
group_type: one of 'pipeline', 'exit_handler', 'condition', and 'loop'.
|
||||
group_type (str): one of 'pipeline', 'exit_handler', 'condition', and 'graph'.
|
||||
name (str): name of the opsgroup
|
||||
"""
|
||||
#TODO: declare the group_type to be strongly typed
|
||||
self.type = group_type
|
||||
self.ops = list()
|
||||
self.groups = list()
|
||||
self.name = name
|
||||
# recursive_ref points to the opsgroups with the same name if exists.
|
||||
self.recursive_ref = None
|
||||
|
||||
@staticmethod
|
||||
def _get_opsgroup_pipeline(group_type, name):
|
||||
"""retrieves the opsgroup when the pipeline already contains it.
|
||||
the opsgroup might be already in the pipeline in case of recursive calls.
|
||||
Args:
|
||||
group_type (str): one of 'pipeline', 'exit_handler', 'condition', and 'graph'.
|
||||
name (str): the name before conversion. """
|
||||
if not _pipeline.Pipeline.get_default_pipeline():
|
||||
raise ValueError('Default pipeline not defined.')
|
||||
if name is None:
|
||||
return None
|
||||
name_pattern = '^' + (group_type + '-' + name + '-').replace('_', '-') + '[\d]+$'
|
||||
for ops_group in _pipeline.Pipeline.get_default_pipeline().groups:
|
||||
import re
|
||||
if ops_group.type == group_type and re.match(name_pattern ,ops_group.name):
|
||||
return ops_group
|
||||
return None
|
||||
|
||||
def _make_name_unique(self):
|
||||
"""Generate a unique opsgroup name in the pipeline"""
|
||||
if not _pipeline.Pipeline.get_default_pipeline():
|
||||
raise ValueError('Default pipeline not defined.')
|
||||
|
||||
self.name = (self.type + '-' + ('' if self.name is None else self.name + '-') +
|
||||
str(_pipeline.Pipeline.get_default_pipeline().get_next_group_id()))
|
||||
self.name = self.name.replace('_', '-')
|
||||
|
||||
def __enter__(self):
|
||||
if not _pipeline.Pipeline.get_default_pipeline():
|
||||
raise ValueError('Default pipeline not defined.')
|
||||
|
||||
if not self.name:
|
||||
self.name = (self.type + '-' +
|
||||
str(_pipeline.Pipeline.get_default_pipeline().get_next_group_id()))
|
||||
self.name = self.name.replace('_', '-')
|
||||
self.recursive_ref = self._get_opsgroup_pipeline(self.type, self.name)
|
||||
if not self.recursive_ref:
|
||||
self._make_name_unique()
|
||||
|
||||
_pipeline.Pipeline.get_default_pipeline().push_ops_group(self)
|
||||
return self
|
||||
|
|
@ -50,7 +80,6 @@ class OpsGroup(object):
|
|||
def __exit__(self, *args):
|
||||
_pipeline.Pipeline.get_default_pipeline().pop_ops_group()
|
||||
|
||||
|
||||
class ExitHandler(OpsGroup):
|
||||
"""Represents an exit handler that is invoked upon exiting a group of ops.
|
||||
|
||||
|
|
@ -92,10 +121,25 @@ class Condition(OpsGroup):
|
|||
def __init__(self, condition):
|
||||
"""Create a new instance of ExitHandler.
|
||||
Args:
|
||||
exit_op: an operator invoked at exiting a group of ops.
|
||||
condition (ConditionOperator): the condition.
|
||||
|
||||
Raises:
|
||||
ValueError is the exit_op is invalid.
|
||||
"""
|
||||
super(Condition, self).__init__('condition')
|
||||
self.condition = condition
|
||||
|
||||
class Graph(OpsGroup):
|
||||
"""Graph DAG with inputs, recursive_inputs, and outputs.
|
||||
This is not used directly by the users but auto generated when the graph_component decoration exists
|
||||
"""
|
||||
def __init__(self, name):
|
||||
super(Graph, self).__init__(group_type='graph', name=name)
|
||||
self.inputs = []
|
||||
self.outputs = {}
|
||||
self.dependencies = []
|
||||
|
||||
def after(self, dependency):
|
||||
"""Specify explicit dependency on another op."""
|
||||
self.dependencies.append(dependency)
|
||||
return self
|
||||
|
|
@ -86,6 +86,7 @@ def get_pipeline_conf():
|
|||
"""
|
||||
return Pipeline.get_default_pipeline().conf
|
||||
|
||||
#TODO: Pipeline is in fact an opsgroup, refactor the code.
|
||||
class Pipeline():
|
||||
"""A pipeline contains a list of operators.
|
||||
|
||||
|
|
|
|||
|
|
@ -241,6 +241,10 @@ class TestCompiler(unittest.TestCase):
|
|||
"""Test pipeline imagepullsecret."""
|
||||
self._test_py_compile('imagepullsecret')
|
||||
|
||||
def test_py_recursive(self):
|
||||
"""Test pipeline recursive."""
|
||||
self._test_py_compile('recursive')
|
||||
|
||||
def test_type_checking_with_consistent_types(self):
|
||||
"""Test type check pipeline parameters against component metadata."""
|
||||
@component
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import kfp.dsl as dsl
|
||||
from kfp.dsl import graph_component
|
||||
|
||||
class FlipCoinOp(dsl.ContainerOp):
|
||||
"""Flip a coin and output heads or tails randomly."""
|
||||
|
||||
def __init__(self):
|
||||
super(FlipCoinOp, self).__init__(
|
||||
name='Flip',
|
||||
image='python:alpine3.6',
|
||||
command=['sh', '-c'],
|
||||
arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 '
|
||||
'else \'tails\'; print(result)" | tee /tmp/output'],
|
||||
file_outputs={'output': '/tmp/output'})
|
||||
|
||||
class PrintOp(dsl.ContainerOp):
|
||||
"""Print a message."""
|
||||
|
||||
def __init__(self, msg):
|
||||
super(PrintOp, self).__init__(
|
||||
name='Print',
|
||||
image='alpine:3.6',
|
||||
command=['echo', msg],
|
||||
)
|
||||
|
||||
@graph_component
|
||||
def flip_component(flip_result):
|
||||
print_flip = PrintOp(flip_result)
|
||||
flipA = FlipCoinOp().after(print_flip)
|
||||
with dsl.Condition(flipA.output == 'heads'):
|
||||
flip_component(flipA.output)
|
||||
return {'flip_result': flipA.output}
|
||||
|
||||
@dsl.pipeline(
|
||||
name='pipeline flip coin',
|
||||
description='shows how to use graph_component.'
|
||||
)
|
||||
def recursive():
|
||||
flipA = FlipCoinOp()
|
||||
flip_loop = flip_component(flipA.output)
|
||||
PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
import kfp.compiler as compiler
|
||||
compiler.Compiler().compile(recursive, __file__ + '.tar.gz')
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Workflow
|
||||
metadata:
|
||||
generateName: pipeline-flip-coin-
|
||||
spec:
|
||||
arguments:
|
||||
parameters: []
|
||||
entrypoint: pipeline-flip-coin
|
||||
serviceAccountName: pipeline-runner
|
||||
templates:
|
||||
- dag:
|
||||
tasks:
|
||||
- arguments:
|
||||
parameters:
|
||||
- name: flip-output
|
||||
value: '{{inputs.parameters.flip-2-output}}'
|
||||
name: graph-flip-component-1
|
||||
template: graph-flip-component-1
|
||||
inputs:
|
||||
parameters:
|
||||
- name: flip-2-output
|
||||
name: condition-2
|
||||
- container:
|
||||
args:
|
||||
- python -c "import random; result = 'heads' if random.randint(0,1) == 0 else
|
||||
'tails'; print(result)" | tee /tmp/output
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
image: python:alpine3.6
|
||||
name: flip
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: mlpipeline-ui-metadata
|
||||
path: /mlpipeline-ui-metadata.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
- name: mlpipeline-metrics
|
||||
path: /mlpipeline-metrics.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
parameters:
|
||||
- name: flip-output
|
||||
valueFrom:
|
||||
path: /tmp/output
|
||||
- container:
|
||||
args:
|
||||
- python -c "import random; result = 'heads' if random.randint(0,1) == 0 else
|
||||
'tails'; print(result)" | tee /tmp/output
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
image: python:alpine3.6
|
||||
name: flip-2
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: mlpipeline-ui-metadata
|
||||
path: /mlpipeline-ui-metadata.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
- name: mlpipeline-metrics
|
||||
path: /mlpipeline-metrics.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
parameters:
|
||||
- name: flip-2-output
|
||||
valueFrom:
|
||||
path: /tmp/output
|
||||
- dag:
|
||||
tasks:
|
||||
- arguments:
|
||||
parameters:
|
||||
- name: flip-2-output
|
||||
value: '{{tasks.flip-2.outputs.parameters.flip-2-output}}'
|
||||
dependencies:
|
||||
- flip-2
|
||||
name: condition-2
|
||||
template: condition-2
|
||||
when: '{{tasks.flip-2.outputs.parameters.flip-2-output}} == heads'
|
||||
- dependencies:
|
||||
- print
|
||||
name: flip-2
|
||||
template: flip-2
|
||||
- arguments:
|
||||
parameters:
|
||||
- name: flip-output
|
||||
value: '{{inputs.parameters.flip-output}}'
|
||||
name: print
|
||||
template: print
|
||||
inputs:
|
||||
parameters:
|
||||
- name: flip-output
|
||||
name: graph-flip-component-1
|
||||
outputs:
|
||||
parameters:
|
||||
- name: flip-2-output
|
||||
valueFrom:
|
||||
parameter: '{{tasks.flip-2.outputs.parameters.flip-2-output}}'
|
||||
- dag:
|
||||
tasks:
|
||||
- name: flip
|
||||
template: flip
|
||||
- arguments:
|
||||
parameters:
|
||||
- name: flip-output
|
||||
value: '{{tasks.flip.outputs.parameters.flip-output}}'
|
||||
dependencies:
|
||||
- flip
|
||||
name: graph-flip-component-1
|
||||
template: graph-flip-component-1
|
||||
- arguments:
|
||||
parameters:
|
||||
- name: flip-2-output
|
||||
value: '{{tasks.graph-flip-component-1.outputs.parameters.flip-2-output}}'
|
||||
dependencies:
|
||||
- graph-flip-component-1
|
||||
name: print-2
|
||||
template: print-2
|
||||
name: pipeline-flip-coin
|
||||
- container:
|
||||
command:
|
||||
- echo
|
||||
- '{{inputs.parameters.flip-output}}'
|
||||
image: alpine:3.6
|
||||
inputs:
|
||||
parameters:
|
||||
- name: flip-output
|
||||
name: print
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: mlpipeline-ui-metadata
|
||||
path: /mlpipeline-ui-metadata.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
- name: mlpipeline-metrics
|
||||
path: /mlpipeline-metrics.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
- container:
|
||||
command:
|
||||
- echo
|
||||
- cool, it is over. {{inputs.parameters.flip-2-output}}
|
||||
image: alpine:3.6
|
||||
inputs:
|
||||
parameters:
|
||||
- name: flip-2-output
|
||||
name: print-2
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: mlpipeline-ui-metadata
|
||||
path: /mlpipeline-ui-metadata.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
- name: mlpipeline-metrics
|
||||
path: /mlpipeline-metrics.json
|
||||
s3:
|
||||
accessKeySecret:
|
||||
key: accesskey
|
||||
name: mlpipeline-minio-artifact
|
||||
bucket: mlpipeline
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
insecure: true
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
|
||||
secretKeySecret:
|
||||
key: secretkey
|
||||
name: mlpipeline-minio-artifact
|
||||
|
|
@ -13,10 +13,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import kfp
|
||||
from kfp.dsl import component
|
||||
import kfp.dsl as dsl
|
||||
from kfp.dsl import component, graph_component
|
||||
from kfp.dsl._metadata import ComponentMeta, ParameterMeta, TypeMeta
|
||||
from kfp.dsl.types import Integer, GCSPath, InconsistentTypeException
|
||||
from kfp.dsl import ContainerOp, Pipeline
|
||||
from kfp.dsl import ContainerOp, Pipeline, PipelineParam
|
||||
import unittest
|
||||
|
||||
class TestPythonComponent(unittest.TestCase):
|
||||
|
|
@ -422,4 +423,30 @@ class TestPythonComponent(unittest.TestCase):
|
|||
a = a_op(field_l=12)
|
||||
with self.assertRaises(InconsistentTypeException):
|
||||
b = b_op(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
|
||||
b = b_op(field_x=a.outputs['field_n'].ignore_type(), field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
|
||||
b = b_op(field_x=a.outputs['field_n'].ignore_type(), field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
|
||||
|
||||
class TestGraphComponent(unittest.TestCase):
|
||||
|
||||
def test_graphcomponent_basic(self):
|
||||
"""Test graph_component decorator metadata."""
|
||||
@graph_component
|
||||
def flip_component(flip_result):
|
||||
with dsl.Condition(flip_result == 'heads'):
|
||||
flip_component(flip_result)
|
||||
return {'flip_result': flip_result}
|
||||
|
||||
with Pipeline('pipeline') as p:
|
||||
param = PipelineParam(name='param')
|
||||
flip_component(param)
|
||||
self.assertEqual(1, len(p.groups))
|
||||
self.assertEqual(1, len(p.groups[0].groups)) # pipeline
|
||||
self.assertEqual(1, len(p.groups[0].groups[0].groups)) # flip_component
|
||||
self.assertEqual(1, len(p.groups[0].groups[0].groups[0].groups)) # condition
|
||||
self.assertEqual(0, len(p.groups[0].groups[0].groups[0].groups[0].groups)) # recursive flip_component
|
||||
recursive_group = p.groups[0].groups[0].groups[0].groups[0]
|
||||
self.assertTrue(recursive_group.recursive_ref is not None)
|
||||
self.assertEqual(1, len(recursive_group.inputs))
|
||||
self.assertEqual('param', recursive_group.inputs[0].name)
|
||||
original_group = p.groups[0].groups[0]
|
||||
self.assertTrue('flip_result' in original_group.outputs)
|
||||
self.assertEqual('param', original_group.outputs['flip_result'])
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import kfp.dsl as dsl
|
||||
from kfp.dsl import Pipeline, PipelineParam, ContainerOp, ExitHandler, OpsGroup
|
||||
import unittest
|
||||
|
||||
|
|
@ -47,6 +47,40 @@ class TestOpsGroup(unittest.TestCase):
|
|||
self.assertFalse(loop_group.groups)
|
||||
self.assertCountEqual([x.name for x in loop_group.ops], ['op4'])
|
||||
|
||||
def test_basic_recursive_opsgroups(self):
|
||||
"""Test recursive opsgroups."""
|
||||
with Pipeline('somename') as p:
|
||||
self.assertEqual(1, len(p.groups))
|
||||
|
||||
# When a graph opsgraph is called.
|
||||
graph_ops_group_one = dsl._ops_group.Graph('hello')
|
||||
graph_ops_group_one.__enter__()
|
||||
self.assertFalse(graph_ops_group_one.recursive_ref)
|
||||
self.assertEqual('graph-hello-1', graph_ops_group_one.name)
|
||||
|
||||
# Another graph opsgraph is called with the same name
|
||||
# when the previous graph opsgraphs is not finished.
|
||||
graph_ops_group_two = dsl._ops_group.Graph('hello')
|
||||
graph_ops_group_two.__enter__()
|
||||
self.assertTrue(graph_ops_group_two.recursive_ref)
|
||||
self.assertEqual(graph_ops_group_one, graph_ops_group_two.recursive_ref)
|
||||
|
||||
def test_recursive_opsgroups_with_prefix_names(self):
|
||||
"""Test recursive opsgroups."""
|
||||
with Pipeline('somename') as p:
|
||||
self.assertEqual(1, len(p.groups))
|
||||
|
||||
# When a graph opsgraph is called.
|
||||
graph_ops_group_one = dsl._ops_group.Graph('foo_bar')
|
||||
graph_ops_group_one.__enter__()
|
||||
self.assertFalse(graph_ops_group_one.recursive_ref)
|
||||
self.assertEqual('graph-foo-bar-1', graph_ops_group_one.name)
|
||||
|
||||
# Another graph opsgraph is called with the name as the prefix of the ops_group_one
|
||||
# when the previous graph opsgraphs is not finished.
|
||||
graph_ops_group_two = dsl._ops_group.Graph('foo')
|
||||
graph_ops_group_two.__enter__()
|
||||
self.assertFalse(graph_ops_group_two.recursive_ref)
|
||||
|
||||
class TestExitHandler(unittest.TestCase):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue