Support recursions in a function (#1014)

* add a While in the ops group

* deepcopy the while conditions when entering and exiting

* add while condition resolution in the compiler

* define graph component decorator

* remove while loop related codes

* fixes

* remove while loop related code

* fix bugs

* generate a unique ops group name and being able to retrieve by name

* resolve the opsgroups inputs and dependencies based on the pipelineparam in the condition

* add a recursive ops_groups

* fix bugs of the recursive opsgroup template name

* resolve the recursive template name and arguments

* add validity checks

* add more comments

* add usage comment in graph_component

* add unit test for the graph opsgraph

* refactor the opsgroup

* add unit test for the graph_component decorator

* exposing graph_component decorator

* add recursive compiler unit tests

* fix the bug of opsgroup name
adjust the graph_component usage example
fix index bugs
use with statement in the graph_component instead of directly calling
the enter/exit functions

* add a todo to combine the graph_component and component decorators
This commit is contained in:
Ning 2019-03-26 14:17:18 -07:00 committed by Kubernetes Prow Robot
parent 1c4f9eb431
commit 8c09090985
10 changed files with 564 additions and 27 deletions

View File

@ -23,6 +23,7 @@ from .. import dsl
from ._k8s_helper import K8sHelper
from ..dsl._pipeline_param import _match_serialized_pipelineparam
from ..dsl._metadata import TypeMeta
from ..dsl._ops_group import OpsGroup
class Compiler(object):
"""DSL Compiler.
@ -65,6 +66,11 @@ class Compiler(object):
def _get_op_groups_helper(current_groups, ops_to_groups):
root_group = current_groups[-1]
for g in root_group.groups:
# Add recursive opsgroup in the ops_to_groups
# such that the i/o dependency can be propagated to the ancester opsgroups
if g.recursive_ref:
ops_to_groups[g.name] = [x.name for x in current_groups] + [g.name]
continue
current_groups.append(g)
_get_op_groups_helper(current_groups, ops_to_groups)
del current_groups[-1]
@ -82,7 +88,10 @@ class Compiler(object):
def _get_groups_helper(group):
groups = [group]
for g in group.groups:
groups += _get_groups_helper(g)
# Skip the recursive opsgroup because no templates
# need to be generated for the recursive opsgroups.
if not g.recursive_ref:
groups += _get_groups_helper(g)
return groups
return _get_groups_helper(root_group)
@ -145,6 +154,38 @@ class Compiler(object):
if not op.is_exit_handler:
for g in op_groups[op.name]:
inputs[g].add((full_name, None))
# Generate the input/output for recursive opsgroups
# It propagates the recursive opsgroups IO to their ancester opsgroups
def _get_inputs_outputs_recursive_opsgroup(group):
#TODO: refactor the following codes with the above
if group.recursive_ref:
for param in group.inputs + list(condition_params[group.name]):
if param.value:
continue
full_name = self._pipelineparam_full_name(param)
if param.op_name:
upstream_op = pipeline.ops[param.op_name]
upstream_groups, downstream_groups = self._get_uncommon_ancestors(
op_groups, upstream_op, group)
for i, g in enumerate(downstream_groups):
if i == 0:
inputs[g].add((full_name, upstream_groups[0]))
else:
inputs[g].add((full_name, None))
for i, g in enumerate(upstream_groups):
if i == len(upstream_groups) - 1:
outputs[g].add((full_name, None))
else:
outputs[g].add((full_name, upstream_groups[i+1]))
else:
if not op.is_exit_handler:
for g in op_groups[op.name]:
inputs[g].add((full_name, None))
for subgroup in group.groups:
_get_inputs_outputs_recursive_opsgroup(subgroup)
_get_inputs_outputs_recursive_opsgroup(root_group)
return inputs, outputs
def _get_condition_params_for_ops(self, root_group):
@ -164,8 +205,13 @@ class Compiler(object):
for param in new_current_conditions_params:
conditions[op.name].add(param)
for g in group.groups:
_get_condition_params_for_ops_helper(g, new_current_conditions_params)
# If the subgroup is a recursive opsgroup, propagate the pipelineparams
# in the condition expression, similar to the ops.
if g.recursive_ref:
for param in new_current_conditions_params:
conditions[g.name].add(param)
else:
_get_condition_params_for_ops_helper(g, new_current_conditions_params)
_get_condition_params_for_ops_helper(root_group, [])
return conditions
@ -179,6 +225,8 @@ class Compiler(object):
then G3 is dependent on G2. Basically dependency only exists in the first uncommon
ancesters in their ancesters chain. Only sibling groups/ops can have dependencies.
"""
#TODO: move the condition_params out because both the _get_inputs_outputs
# and _get_dependencies depend on it.
condition_params = self._get_condition_params_for_ops(root_group)
dependencies = defaultdict(set)
for op in pipeline.ops.values():
@ -193,6 +241,29 @@ class Compiler(object):
upstream_groups, downstream_groups = self._get_uncommon_ancestors(
op_groups, upstream_op, op)
dependencies[downstream_groups[0]].add(upstream_groups[0])
# Generate dependencies based on the recursive opsgroups
#TODO: refactor the following codes with the above
def _get_dependency_opsgroup(group, dependencies):
if group.recursive_ref:
unstream_op_names = set()
for param in group.inputs + list(condition_params[group.name]):
if param.op_name:
unstream_op_names.add(param.op_name)
unstream_op_names |= set(group.dependencies)
for op_name in unstream_op_names:
upstream_op = pipeline.ops[op_name]
upstream_groups, downstream_groups = self._get_uncommon_ancestors(
op_groups, upstream_op, group)
dependencies[downstream_groups[0]].add(upstream_groups[0])
for subgroup in group.groups:
_get_dependency_opsgroup(subgroup, dependencies)
_get_dependency_opsgroup(root_group, dependencies)
return dependencies
def _resolve_value_or_reference(self, value_or_reference, potential_references):
@ -364,11 +435,18 @@ class Compiler(object):
# Generate tasks section.
tasks = []
for sub_group in group.groups + group.ops:
task = {
'name': sub_group.name,
'template': sub_group.name,
}
is_recursive_subgroup = (isinstance(sub_group, OpsGroup) and sub_group.recursive_ref)
# Special handling for recursive subgroup: use the existing opsgroup name
if is_recursive_subgroup:
task = {
'name': sub_group.recursive_ref.name,
'template': sub_group.recursive_ref.name,
}
else:
task = {
'name': sub_group.name,
'template': sub_group.name,
}
if isinstance(sub_group, dsl.OpsGroup) and sub_group.type == 'condition':
subgroup_inputs = inputs.get(sub_group.name, [])
condition = sub_group.condition
@ -394,10 +472,22 @@ class Compiler(object):
})
else:
# The value comes from its parent.
arguments.append({
'name': param_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
# Special handling for recursive subgroup: argument name comes from the existing opsgroup
if is_recursive_subgroup:
for index, input in enumerate(sub_group.inputs):
if param_name == input.name:
break
referenced_input = sub_group.recursive_ref.inputs[index]
full_name = self._pipelineparam_full_name(referenced_input)
arguments.append({
'name': full_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
else:
arguments.append({
'name': param_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}
tasks.append(task)
@ -410,6 +500,15 @@ class Compiler(object):
new_root_group = pipeline.groups[0]
# Generate core data structures to prepare for argo yaml generation
# op_groups: op name -> list of ancestor groups including the current op
# inputs, outputs: group/op names -> list of tuples (param_name, producing_op_name)
# dependencies: group/op name -> list of dependent groups/ops.
# groups: opsgroups
# Special Handling for the recursive opsgroup
# op_groups also contains the recursive opsgroups
# condition_params from _get_condition_params_for_ops also contains the recursive opsgroups
# groups does not include the recursive opsgroups
op_groups = self._get_groups_for_ops(new_root_group)
inputs, outputs = self._get_inputs_outputs(pipeline, new_root_group, op_groups)
dependencies = self._get_dependencies(pipeline, new_root_group, op_groups)

View File

@ -17,5 +17,4 @@ from ._pipeline_param import PipelineParam
from ._pipeline import Pipeline, pipeline, get_pipeline_conf
from ._container_op import ContainerOp
from ._ops_group import OpsGroup, ExitHandler, Condition
from ._component import python_component, component
#TODO: expose the component decorator when ready
from ._component import python_component, graph_component, component

View File

@ -15,6 +15,7 @@
from ._metadata import ComponentMeta, ParameterMeta, TypeMeta, _annotation_to_typemeta
from ._pipeline_param import PipelineParam
from .types import check_types, InconsistentTypeException
from ._ops_group import Graph
import kfp
def python_component(name, description=None, base_image=None, target_component_file: str = None):
@ -54,7 +55,7 @@ def python_component(name, description=None, base_image=None, target_component_f
return _python_component
def component(func):
"""Decorator for component functions that use ContainerOp.
"""Decorator for component functions that returns a ContainerOp.
This is useful to enable type checking in the DSL compiler
Usage:
@ -118,3 +119,42 @@ def component(func):
return container_op
return _component
#TODO: combine the component and graph_component decorators into one
def graph_component(func):
"""Decorator for graph component functions.
This decorator returns an ops_group.
Usage:
```python
import kfp.dsl as dsl
@dsl.graph_component
def flip_component(flip_result):
print_flip = PrintOp(flip_result)
flipA = FlipCoinOp().after(print_flip)
with dsl.Condition(flipA.output == 'heads'):
flip_component(flipA.output)
return {'flip_result': flipA.output}
"""
from functools import wraps
@wraps(func)
def _graph_component(*args, **kargs):
graph_ops_group = Graph(func.__name__)
graph_ops_group.inputs = list(args) + list(kargs.values())
for input in graph_ops_group.inputs:
if not isinstance(input, PipelineParam):
raise ValueError('arguments to ' + func.__name__ + ' should be PipelineParams.')
# Entering the Graph Context
with graph_ops_group:
# Call the function
if not graph_ops_group.recursive_ref:
graph_ops_group.outputs = func(*args, **kargs)
if not isinstance(graph_ops_group.outputs, dict):
raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.')
for output in graph_ops_group.outputs:
if not (isinstance(output, str) and isinstance(graph_ops_group.outputs[output], PipelineParam)):
raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.')
return graph_ops_group
return _graph_component

View File

@ -15,7 +15,7 @@
from . import _container_op
from . import _pipeline
from ._pipeline_param import ConditionOperator
class OpsGroup(object):
"""Represents a logical group of ops and group of OpsGroups.
@ -28,21 +28,51 @@ class OpsGroup(object):
def __init__(self, group_type: str, name: str=None):
"""Create a new instance of OpsGroup.
Args:
group_type: one of 'pipeline', 'exit_handler', 'condition', and 'loop'.
group_type (str): one of 'pipeline', 'exit_handler', 'condition', and 'graph'.
name (str): name of the opsgroup
"""
#TODO: declare the group_type to be strongly typed
self.type = group_type
self.ops = list()
self.groups = list()
self.name = name
# recursive_ref points to the opsgroups with the same name if exists.
self.recursive_ref = None
@staticmethod
def _get_opsgroup_pipeline(group_type, name):
"""retrieves the opsgroup when the pipeline already contains it.
the opsgroup might be already in the pipeline in case of recursive calls.
Args:
group_type (str): one of 'pipeline', 'exit_handler', 'condition', and 'graph'.
name (str): the name before conversion. """
if not _pipeline.Pipeline.get_default_pipeline():
raise ValueError('Default pipeline not defined.')
if name is None:
return None
name_pattern = '^' + (group_type + '-' + name + '-').replace('_', '-') + '[\d]+$'
for ops_group in _pipeline.Pipeline.get_default_pipeline().groups:
import re
if ops_group.type == group_type and re.match(name_pattern ,ops_group.name):
return ops_group
return None
def _make_name_unique(self):
"""Generate a unique opsgroup name in the pipeline"""
if not _pipeline.Pipeline.get_default_pipeline():
raise ValueError('Default pipeline not defined.')
self.name = (self.type + '-' + ('' if self.name is None else self.name + '-') +
str(_pipeline.Pipeline.get_default_pipeline().get_next_group_id()))
self.name = self.name.replace('_', '-')
def __enter__(self):
if not _pipeline.Pipeline.get_default_pipeline():
raise ValueError('Default pipeline not defined.')
if not self.name:
self.name = (self.type + '-' +
str(_pipeline.Pipeline.get_default_pipeline().get_next_group_id()))
self.name = self.name.replace('_', '-')
self.recursive_ref = self._get_opsgroup_pipeline(self.type, self.name)
if not self.recursive_ref:
self._make_name_unique()
_pipeline.Pipeline.get_default_pipeline().push_ops_group(self)
return self
@ -50,7 +80,6 @@ class OpsGroup(object):
def __exit__(self, *args):
_pipeline.Pipeline.get_default_pipeline().pop_ops_group()
class ExitHandler(OpsGroup):
"""Represents an exit handler that is invoked upon exiting a group of ops.
@ -92,10 +121,25 @@ class Condition(OpsGroup):
def __init__(self, condition):
"""Create a new instance of ExitHandler.
Args:
exit_op: an operator invoked at exiting a group of ops.
condition (ConditionOperator): the condition.
Raises:
ValueError is the exit_op is invalid.
"""
super(Condition, self).__init__('condition')
self.condition = condition
class Graph(OpsGroup):
"""Graph DAG with inputs, recursive_inputs, and outputs.
This is not used directly by the users but auto generated when the graph_component decoration exists
"""
def __init__(self, name):
super(Graph, self).__init__(group_type='graph', name=name)
self.inputs = []
self.outputs = {}
self.dependencies = []
def after(self, dependency):
"""Specify explicit dependency on another op."""
self.dependencies.append(dependency)
return self

View File

@ -86,6 +86,7 @@ def get_pipeline_conf():
"""
return Pipeline.get_default_pipeline().conf
#TODO: Pipeline is in fact an opsgroup, refactor the code.
class Pipeline():
"""A pipeline contains a list of operators.

View File

@ -241,6 +241,10 @@ class TestCompiler(unittest.TestCase):
"""Test pipeline imagepullsecret."""
self._test_py_compile('imagepullsecret')
def test_py_recursive(self):
"""Test pipeline recursive."""
self._test_py_compile('recursive')
def test_type_checking_with_consistent_types(self):
"""Test type check pipeline parameters against component metadata."""
@component

View File

@ -0,0 +1,60 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import kfp.dsl as dsl
from kfp.dsl import graph_component
class FlipCoinOp(dsl.ContainerOp):
"""Flip a coin and output heads or tails randomly."""
def __init__(self):
super(FlipCoinOp, self).__init__(
name='Flip',
image='python:alpine3.6',
command=['sh', '-c'],
arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 '
'else \'tails\'; print(result)" | tee /tmp/output'],
file_outputs={'output': '/tmp/output'})
class PrintOp(dsl.ContainerOp):
"""Print a message."""
def __init__(self, msg):
super(PrintOp, self).__init__(
name='Print',
image='alpine:3.6',
command=['echo', msg],
)
@graph_component
def flip_component(flip_result):
print_flip = PrintOp(flip_result)
flipA = FlipCoinOp().after(print_flip)
with dsl.Condition(flipA.output == 'heads'):
flip_component(flipA.output)
return {'flip_result': flipA.output}
@dsl.pipeline(
name='pipeline flip coin',
description='shows how to use graph_component.'
)
def recursive():
flipA = FlipCoinOp()
flip_loop = flip_component(flipA.output)
PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result'])
if __name__ == '__main__':
import kfp.compiler as compiler
compiler.Compiler().compile(recursive, __file__ + '.tar.gz')

View File

@ -0,0 +1,229 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: pipeline-flip-coin-
spec:
arguments:
parameters: []
entrypoint: pipeline-flip-coin
serviceAccountName: pipeline-runner
templates:
- dag:
tasks:
- arguments:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-2-output}}'
name: graph-flip-component-1
template: graph-flip-component-1
inputs:
parameters:
- name: flip-2-output
name: condition-2
- container:
args:
- python -c "import random; result = 'heads' if random.randint(0,1) == 0 else
'tails'; print(result)" | tee /tmp/output
command:
- sh
- -c
image: python:alpine3.6
name: flip
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
parameters:
- name: flip-output
valueFrom:
path: /tmp/output
- container:
args:
- python -c "import random; result = 'heads' if random.randint(0,1) == 0 else
'tails'; print(result)" | tee /tmp/output
command:
- sh
- -c
image: python:alpine3.6
name: flip-2
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
parameters:
- name: flip-2-output
valueFrom:
path: /tmp/output
- dag:
tasks:
- arguments:
parameters:
- name: flip-2-output
value: '{{tasks.flip-2.outputs.parameters.flip-2-output}}'
dependencies:
- flip-2
name: condition-2
template: condition-2
when: '{{tasks.flip-2.outputs.parameters.flip-2-output}} == heads'
- dependencies:
- print
name: flip-2
template: flip-2
- arguments:
parameters:
- name: flip-output
value: '{{inputs.parameters.flip-output}}'
name: print
template: print
inputs:
parameters:
- name: flip-output
name: graph-flip-component-1
outputs:
parameters:
- name: flip-2-output
valueFrom:
parameter: '{{tasks.flip-2.outputs.parameters.flip-2-output}}'
- dag:
tasks:
- name: flip
template: flip
- arguments:
parameters:
- name: flip-output
value: '{{tasks.flip.outputs.parameters.flip-output}}'
dependencies:
- flip
name: graph-flip-component-1
template: graph-flip-component-1
- arguments:
parameters:
- name: flip-2-output
value: '{{tasks.graph-flip-component-1.outputs.parameters.flip-2-output}}'
dependencies:
- graph-flip-component-1
name: print-2
template: print-2
name: pipeline-flip-coin
- container:
command:
- echo
- '{{inputs.parameters.flip-output}}'
image: alpine:3.6
inputs:
parameters:
- name: flip-output
name: print
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
- container:
command:
- echo
- cool, it is over. {{inputs.parameters.flip-2-output}}
image: alpine:3.6
inputs:
parameters:
- name: flip-2-output
name: print-2
outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: /mlpipeline-ui-metadata.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact
- name: mlpipeline-metrics
path: /mlpipeline-metrics.json
s3:
accessKeySecret:
key: accesskey
name: mlpipeline-minio-artifact
bucket: mlpipeline
endpoint: minio-service.kubeflow:9000
insecure: true
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
secretKeySecret:
key: secretkey
name: mlpipeline-minio-artifact

View File

@ -13,10 +13,11 @@
# limitations under the License.
import kfp
from kfp.dsl import component
import kfp.dsl as dsl
from kfp.dsl import component, graph_component
from kfp.dsl._metadata import ComponentMeta, ParameterMeta, TypeMeta
from kfp.dsl.types import Integer, GCSPath, InconsistentTypeException
from kfp.dsl import ContainerOp, Pipeline
from kfp.dsl import ContainerOp, Pipeline, PipelineParam
import unittest
class TestPythonComponent(unittest.TestCase):
@ -422,4 +423,30 @@ class TestPythonComponent(unittest.TestCase):
a = a_op(field_l=12)
with self.assertRaises(InconsistentTypeException):
b = b_op(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
b = b_op(field_x=a.outputs['field_n'].ignore_type(), field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
b = b_op(field_x=a.outputs['field_n'].ignore_type(), field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
class TestGraphComponent(unittest.TestCase):
def test_graphcomponent_basic(self):
"""Test graph_component decorator metadata."""
@graph_component
def flip_component(flip_result):
with dsl.Condition(flip_result == 'heads'):
flip_component(flip_result)
return {'flip_result': flip_result}
with Pipeline('pipeline') as p:
param = PipelineParam(name='param')
flip_component(param)
self.assertEqual(1, len(p.groups))
self.assertEqual(1, len(p.groups[0].groups)) # pipeline
self.assertEqual(1, len(p.groups[0].groups[0].groups)) # flip_component
self.assertEqual(1, len(p.groups[0].groups[0].groups[0].groups)) # condition
self.assertEqual(0, len(p.groups[0].groups[0].groups[0].groups[0].groups)) # recursive flip_component
recursive_group = p.groups[0].groups[0].groups[0].groups[0]
self.assertTrue(recursive_group.recursive_ref is not None)
self.assertEqual(1, len(recursive_group.inputs))
self.assertEqual('param', recursive_group.inputs[0].name)
original_group = p.groups[0].groups[0]
self.assertTrue('flip_result' in original_group.outputs)
self.assertEqual('param', original_group.outputs['flip_result'])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import kfp.dsl as dsl
from kfp.dsl import Pipeline, PipelineParam, ContainerOp, ExitHandler, OpsGroup
import unittest
@ -47,6 +47,40 @@ class TestOpsGroup(unittest.TestCase):
self.assertFalse(loop_group.groups)
self.assertCountEqual([x.name for x in loop_group.ops], ['op4'])
def test_basic_recursive_opsgroups(self):
"""Test recursive opsgroups."""
with Pipeline('somename') as p:
self.assertEqual(1, len(p.groups))
# When a graph opsgraph is called.
graph_ops_group_one = dsl._ops_group.Graph('hello')
graph_ops_group_one.__enter__()
self.assertFalse(graph_ops_group_one.recursive_ref)
self.assertEqual('graph-hello-1', graph_ops_group_one.name)
# Another graph opsgraph is called with the same name
# when the previous graph opsgraphs is not finished.
graph_ops_group_two = dsl._ops_group.Graph('hello')
graph_ops_group_two.__enter__()
self.assertTrue(graph_ops_group_two.recursive_ref)
self.assertEqual(graph_ops_group_one, graph_ops_group_two.recursive_ref)
def test_recursive_opsgroups_with_prefix_names(self):
"""Test recursive opsgroups."""
with Pipeline('somename') as p:
self.assertEqual(1, len(p.groups))
# When a graph opsgraph is called.
graph_ops_group_one = dsl._ops_group.Graph('foo_bar')
graph_ops_group_one.__enter__()
self.assertFalse(graph_ops_group_one.recursive_ref)
self.assertEqual('graph-foo-bar-1', graph_ops_group_one.name)
# Another graph opsgraph is called with the name as the prefix of the ops_group_one
# when the previous graph opsgraphs is not finished.
graph_ops_group_two = dsl._ops_group.Graph('foo')
graph_ops_group_two.__enter__()
self.assertFalse(graph_ops_group_two.recursive_ref)
class TestExitHandler(unittest.TestCase):