feat(sdk.v2): Support Exit handler in v2 compiler. (#5784)
This commit is contained in:
parent
ef33b9e77d
commit
6b87155a33
|
|
@ -647,10 +647,6 @@ class Compiler(object):
|
|||
raise NotImplementedError(
|
||||
'dsl.graph_component is not yet supported in KFP v2 compiler.')
|
||||
|
||||
if isinstance(subgroup, dsl.OpsGroup) and subgroup.type == 'exit_handler':
|
||||
raise NotImplementedError(
|
||||
'dsl.ExitHandler is not yet supported in KFP v2 compiler.')
|
||||
|
||||
if isinstance(subgroup, dsl.ContainerOp):
|
||||
if hasattr(subgroup, 'importer_spec'):
|
||||
importer_task_name = subgroup.task_spec.task_info.name
|
||||
|
|
@ -909,8 +905,60 @@ class Compiler(object):
|
|||
op_name_to_parent_groups,
|
||||
)
|
||||
|
||||
# Exit Handler
|
||||
if pipeline.groups[0].groups:
|
||||
first_group = pipeline.groups[0].groups[0]
|
||||
if first_group.type == 'exit_handler':
|
||||
exit_handler_op = first_group.exit_op
|
||||
|
||||
# Add exit op task spec
|
||||
task_name = exit_handler_op.task_spec.task_info.name
|
||||
exit_handler_op.task_spec.dependent_tasks.extend(
|
||||
pipeline_spec.root.dag.tasks.keys())
|
||||
exit_handler_op.task_spec.trigger_policy.strategy = (
|
||||
pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
|
||||
.ALL_UPSTREAM_TASKS_COMPLETED)
|
||||
pipeline_spec.root.dag.tasks[task_name].CopyFrom(
|
||||
exit_handler_op.task_spec)
|
||||
|
||||
# Add exit op component spec if it does not exist.
|
||||
component_name = exit_handler_op.task_spec.component_ref.name
|
||||
if component_name not in pipeline_spec.components:
|
||||
pipeline_spec.components[component_name].CopyFrom(
|
||||
exit_handler_op.component_spec)
|
||||
|
||||
# Add exit op executor spec if it does not exist.
|
||||
executor_label = exit_handler_op.component_spec.executor_label
|
||||
if executor_label not in deployment_config.executors:
|
||||
deployment_config.executors[executor_label].container.CopyFrom(
|
||||
exit_handler_op.container_spec)
|
||||
pipeline_spec.deployment_spec.update(
|
||||
json_format.MessageToDict(deployment_config))
|
||||
|
||||
return pipeline_spec
|
||||
|
||||
def _validate_exit_handler(self, pipeline):
|
||||
"""Makes sure there is only one global exit handler.
|
||||
|
||||
This is temporary to be compatible with KFP v1.
|
||||
"""
|
||||
|
||||
def _validate_exit_handler_helper(group, exiting_op_names, handler_exists):
|
||||
if group.type == 'exit_handler':
|
||||
if handler_exists or len(exiting_op_names) > 1:
|
||||
raise ValueError(
|
||||
'Only one global exit_handler is allowed and all ops need to be included.'
|
||||
)
|
||||
handler_exists = True
|
||||
|
||||
if group.ops:
|
||||
exiting_op_names.extend([x.name for x in group.ops])
|
||||
|
||||
for g in group.groups:
|
||||
_validate_exit_handler_helper(g, exiting_op_names, handler_exists)
|
||||
|
||||
return _validate_exit_handler_helper(pipeline.groups[0], [], False)
|
||||
|
||||
# TODO: Sanitizing beforehand, so that we don't need to sanitize here.
|
||||
def _sanitize_and_inject_artifact(self, pipeline: dsl.Pipeline) -> None:
|
||||
"""Sanitize operator/param names and inject pipeline artifact location. """
|
||||
|
|
@ -1006,6 +1054,7 @@ class Compiler(object):
|
|||
with dsl.Pipeline(pipeline_name) as dsl_pipeline:
|
||||
pipeline_func(*args_list)
|
||||
|
||||
self._validate_exit_handler(dsl_pipeline)
|
||||
self._sanitize_and_inject_artifact(dsl_pipeline)
|
||||
|
||||
# Fill in the default values.
|
||||
|
|
|
|||
|
|
@ -74,50 +74,6 @@ class CompilerTest(unittest.TestCase):
|
|||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def test_compile_pipeline_with_dsl_exithandler_should_raise_error(self):
|
||||
|
||||
gcs_download_op = components.load_component_from_text("""
|
||||
name: GCS - Download
|
||||
inputs:
|
||||
- {name: url, type: String}
|
||||
outputs:
|
||||
- {name: result, type: String}
|
||||
implementation:
|
||||
container:
|
||||
image: gcr.io/my-project/my-image:tag
|
||||
args:
|
||||
- {inputValue: url}
|
||||
- {outputPath: result}
|
||||
""")
|
||||
|
||||
echo_op = components.load_component_from_text("""
|
||||
name: echo
|
||||
inputs:
|
||||
- {name: msg, type: String}
|
||||
implementation:
|
||||
container:
|
||||
image: gcr.io/my-project/my-image:tag
|
||||
args:
|
||||
- {inputValue: msg}
|
||||
""")
|
||||
|
||||
@dsl.pipeline(name='test-pipeline', pipeline_root='dummy_root')
|
||||
def download_and_print(
|
||||
url: str = 'gs://ml-pipeline/shakespeare/shakespeare1.txt'):
|
||||
"""A sample pipeline showing exit handler."""
|
||||
|
||||
exit_task = echo_op('exit!')
|
||||
|
||||
with dsl.ExitHandler(exit_task):
|
||||
download_task = gcs_download_op(url)
|
||||
echo_task = echo_op(download_task.outputs['result'])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'dsl.ExitHandler is not yet supported in KFP v2 compiler.'):
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=download_and_print, package_path='output.json')
|
||||
|
||||
def test_compile_pipeline_with_dsl_graph_component_should_raise_error(self):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
|
|
|
|||
|
|
@ -127,6 +127,9 @@ class CompilerCliTests(unittest.TestCase):
|
|||
def test_pipeline_with_metrics_outputs(self):
|
||||
self._test_compile_py_to_json('pipeline_with_metrics_outputs')
|
||||
|
||||
def test_pipeline_with_exit_handler(self):
|
||||
self._test_compile_py_to_json('pipeline_with_exit_handler')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2021 The Kubeflow Authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pipeline using ExitHandler."""
|
||||
|
||||
from kfp.v2 import dsl
|
||||
from kfp.v2 import compiler
|
||||
from kfp.v2.dsl import component
|
||||
|
||||
|
||||
@component
|
||||
def print_op(message: str):
|
||||
"""Prints a message."""
|
||||
print(message)
|
||||
|
||||
|
||||
@component
|
||||
def fail_op(message: str):
|
||||
"""Fails."""
|
||||
import sys
|
||||
print(message)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dsl.pipeline(name='pipeline-with-exit-handler', pipeline_root='dummy_root')
|
||||
def my_pipeline():
|
||||
|
||||
exit_task = print_op('Exit handler has worked!')
|
||||
|
||||
with dsl.ExitHandler(exit_task):
|
||||
print_op('Hello World!')
|
||||
fail_op('Task failed.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
compiler.Compiler().compile(
|
||||
pipeline_func=my_pipeline, package_path=__file__.replace('.py', '.json'))
|
||||
|
|
@ -29,6 +29,6 @@ click>=7.1.1,<8
|
|||
|
||||
# kfp.v2
|
||||
absl-py>=0.9,<=0.11
|
||||
kfp-pipeline-spec>=0.1.7,<0.2.0
|
||||
kfp-pipeline-spec>=0.1.8,<0.2.0
|
||||
fire>=0.3.1,<1
|
||||
google-api-python-client>=1.7.8,<2
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ idna==2.10
|
|||
# via requests
|
||||
jsonschema==3.2.0
|
||||
# via -r requirements.in
|
||||
kfp-pipeline-spec==0.1.7
|
||||
kfp-pipeline-spec==0.1.8
|
||||
# via -r requirements.in
|
||||
kfp-server-api==1.3.0
|
||||
# via -r requirements.in
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ REQUIRES = [
|
|||
'Deprecated>=1.2.7,<2',
|
||||
'strip-hints>=0.1.8,<1',
|
||||
'docstring-parser>=0.7.3,<1',
|
||||
'kfp-pipeline-spec>=0.1.7,<0.2.0',
|
||||
'kfp-pipeline-spec>=0.1.8,<0.2.0',
|
||||
'fire>=0.3.1,<1',
|
||||
'protobuf>=3.13.0,<4'
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue