SDK/Compiler - Invoke the op_transformers as early as possible (#1464)
* Add reproducible test case * Invoke the op_transformers as early as possible
This commit is contained in:
parent
f8b06387b8
commit
381083a7c3
|
|
@ -453,7 +453,7 @@ class Compiler(object):
|
|||
|
||||
def _create_templates(self, pipeline, op_transformers=None, op_to_templates_handler=None):
|
||||
"""Create all groups and ops templates in the pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
pipeline: Pipeline context object to get all the pipeline data from.
|
||||
op_transformers: A list of functions that are applied to all ContainerOp instances that are being processed.
|
||||
|
|
@ -463,6 +463,13 @@ class Compiler(object):
|
|||
op_to_templates_handler = op_to_templates_handler or (lambda op : [_op_to_template(op)])
|
||||
new_root_group = pipeline.groups[0]
|
||||
|
||||
# Call the transformation functions before determining the inputs/outputs, otherwise
|
||||
# the user would not be able to use pipeline parameters in the container definition
|
||||
# (for example as pod labels) - the generated template is invalid.
|
||||
for op in pipeline.ops.values():
|
||||
for transformer in op_transformers or []:
|
||||
transformer(op)
|
||||
|
||||
# Generate core data structures to prepare for argo yaml generation
|
||||
# op_groups: op name -> list of ancestor groups including the current op
|
||||
# opsgroups: a dictionary of ospgroup.name -> opsgroup
|
||||
|
|
@ -486,8 +493,6 @@ class Compiler(object):
|
|||
templates.append(template)
|
||||
|
||||
for op in pipeline.ops.values():
|
||||
for transformer in op_transformers or []:
|
||||
op = transformer(op) or op
|
||||
templates.extend(op_to_templates_handler(op))
|
||||
return templates
|
||||
|
||||
|
|
|
|||
|
|
@ -366,6 +366,10 @@ class TestCompiler(unittest.TestCase):
|
|||
"""Test pipeline param_substitutions."""
|
||||
self._test_py_compile_yaml('param_substitutions')
|
||||
|
||||
def test_py_param_op_transform(self):
|
||||
"""Test pipeline param_op_transform."""
|
||||
self._test_py_compile_yaml('param_op_transform')
|
||||
|
||||
def test_type_checking_with_consistent_types(self):
|
||||
"""Test type check pipeline parameters against component metadata."""
|
||||
@component
|
||||
|
|
@ -471,7 +475,7 @@ class TestCompiler(unittest.TestCase):
|
|||
def pipeline():
|
||||
task1 = op()
|
||||
task2 = op().after(task1)
|
||||
|
||||
|
||||
compiler.Compiler()._compile(pipeline)
|
||||
|
||||
def _test_op_to_template_yaml(self, ops, file_base_name):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
from typing import Callable
|
||||
|
||||
import kfp.dsl as dsl
|
||||
|
||||
def add_common_labels(param):
|
||||
|
||||
def _add_common_labels(op: dsl.ContainerOp) -> dsl.ContainerOp:
|
||||
return op.add_pod_label('param', param)
|
||||
|
||||
return _add_common_labels
|
||||
|
||||
@dsl.pipeline(
|
||||
name="Parameters in Op transformation functions",
|
||||
description="Test that parameters used in Op transformation functions as pod labels "
|
||||
"would be correcly identified and set as arguments in he generated yaml"
|
||||
)
|
||||
def param_substitutions(param = dsl.PipelineParam(name='param')):
|
||||
dsl.get_pipeline_conf().op_transformers.append(add_common_labels(param))
|
||||
|
||||
op = dsl.ContainerOp(
|
||||
name="cop",
|
||||
image="image",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import kfp.compiler as compiler
|
||||
compiler.Compiler().compile(param_substitutions, __file__ + '.yaml')
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
apiVersion: argoproj.io/v1alpha1
|
||||
kind: Workflow
|
||||
metadata:
|
||||
generateName: parameters-in-op-transformation-functions-
|
||||
spec:
|
||||
arguments:
|
||||
parameters:
|
||||
- name: param
|
||||
entrypoint: parameters-in-op-transformation-functions
|
||||
serviceAccountName: pipeline-runner
|
||||
templates:
|
||||
- container:
|
||||
image: image
|
||||
inputs:
|
||||
parameters:
|
||||
- name: param
|
||||
metadata:
|
||||
labels:
|
||||
param: '{{inputs.parameters.param}}'
|
||||
name: cop
|
||||
outputs:
|
||||
artifacts:
|
||||
- name: mlpipeline-ui-metadata
|
||||
optional: true
|
||||
path: /mlpipeline-ui-metadata.json
|
||||
- name: mlpipeline-metrics
|
||||
optional: true
|
||||
path: /mlpipeline-metrics.json
|
||||
- dag:
|
||||
tasks:
|
||||
- arguments:
|
||||
parameters:
|
||||
- name: param
|
||||
value: '{{inputs.parameters.param}}'
|
||||
name: cop
|
||||
template: cop
|
||||
inputs:
|
||||
parameters:
|
||||
- name: param
|
||||
name: parameters-in-op-transformation-functions
|
||||
Loading…
Reference in New Issue