feat(sdk.v2): Support Exit handler in v2 compiler. (#5784)

This commit is contained in:
Chen Sun 2021-06-03 12:58:36 -07:00 committed by GitHub
parent ef33b9e77d
commit 6b87155a33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 294 additions and 51 deletions

View File

@ -647,10 +647,6 @@ class Compiler(object):
raise NotImplementedError(
'dsl.graph_component is not yet supported in KFP v2 compiler.')
if isinstance(subgroup, dsl.OpsGroup) and subgroup.type == 'exit_handler':
raise NotImplementedError(
'dsl.ExitHandler is not yet supported in KFP v2 compiler.')
if isinstance(subgroup, dsl.ContainerOp):
if hasattr(subgroup, 'importer_spec'):
importer_task_name = subgroup.task_spec.task_info.name
@ -909,8 +905,60 @@ class Compiler(object):
op_name_to_parent_groups,
)
# Exit Handler
if pipeline.groups[0].groups:
first_group = pipeline.groups[0].groups[0]
if first_group.type == 'exit_handler':
exit_handler_op = first_group.exit_op
# Add exit op task spec
task_name = exit_handler_op.task_spec.task_info.name
exit_handler_op.task_spec.dependent_tasks.extend(
pipeline_spec.root.dag.tasks.keys())
exit_handler_op.task_spec.trigger_policy.strategy = (
pipeline_spec_pb2.PipelineTaskSpec.TriggerPolicy.TriggerStrategy
.ALL_UPSTREAM_TASKS_COMPLETED)
pipeline_spec.root.dag.tasks[task_name].CopyFrom(
exit_handler_op.task_spec)
# Add exit op component spec if it does not exist.
component_name = exit_handler_op.task_spec.component_ref.name
if component_name not in pipeline_spec.components:
pipeline_spec.components[component_name].CopyFrom(
exit_handler_op.component_spec)
# Add exit op executor spec if it does not exist.
executor_label = exit_handler_op.component_spec.executor_label
if executor_label not in deployment_config.executors:
deployment_config.executors[executor_label].container.CopyFrom(
exit_handler_op.container_spec)
pipeline_spec.deployment_spec.update(
json_format.MessageToDict(deployment_config))
return pipeline_spec
def _validate_exit_handler(self, pipeline):
"""Makes sure there is only one global exit handler.
This is temporary to be compatible with KFP v1.
"""
def _validate_exit_handler_helper(group, exiting_op_names, handler_exists):
if group.type == 'exit_handler':
if handler_exists or len(exiting_op_names) > 1:
raise ValueError(
'Only one global exit_handler is allowed and all ops need to be included.'
)
handler_exists = True
if group.ops:
exiting_op_names.extend([x.name for x in group.ops])
for g in group.groups:
_validate_exit_handler_helper(g, exiting_op_names, handler_exists)
return _validate_exit_handler_helper(pipeline.groups[0], [], False)
# TODO: Sanitizing beforehand, so that we don't need to sanitize here.
def _sanitize_and_inject_artifact(self, pipeline: dsl.Pipeline) -> None:
"""Sanitize operator/param names and inject pipeline artifact location. """
@ -1006,6 +1054,7 @@ class Compiler(object):
with dsl.Pipeline(pipeline_name) as dsl_pipeline:
pipeline_func(*args_list)
self._validate_exit_handler(dsl_pipeline)
self._sanitize_and_inject_artifact(dsl_pipeline)
# Fill in the default values.

View File

@ -74,50 +74,6 @@ class CompilerTest(unittest.TestCase):
finally:
shutil.rmtree(tmpdir)
def test_compile_pipeline_with_dsl_exithandler_should_raise_error(self):
gcs_download_op = components.load_component_from_text("""
name: GCS - Download
inputs:
- {name: url, type: String}
outputs:
- {name: result, type: String}
implementation:
container:
image: gcr.io/my-project/my-image:tag
args:
- {inputValue: url}
- {outputPath: result}
""")
echo_op = components.load_component_from_text("""
name: echo
inputs:
- {name: msg, type: String}
implementation:
container:
image: gcr.io/my-project/my-image:tag
args:
- {inputValue: msg}
""")
@dsl.pipeline(name='test-pipeline', pipeline_root='dummy_root')
def download_and_print(
url: str = 'gs://ml-pipeline/shakespeare/shakespeare1.txt'):
"""A sample pipeline showing exit handler."""
exit_task = echo_op('exit!')
with dsl.ExitHandler(exit_task):
download_task = gcs_download_op(url)
echo_task = echo_op(download_task.outputs['result'])
with self.assertRaisesRegex(
NotImplementedError,
'dsl.ExitHandler is not yet supported in KFP v2 compiler.'):
compiler.Compiler().compile(
pipeline_func=download_and_print, package_path='output.json')
def test_compile_pipeline_with_dsl_graph_component_should_raise_error(self):
with self.assertRaisesRegex(

View File

@ -127,6 +127,9 @@ class CompilerCliTests(unittest.TestCase):
def test_pipeline_with_metrics_outputs(self):
self._test_compile_py_to_json('pipeline_with_metrics_outputs')
def test_pipeline_with_exit_handler(self):
self._test_compile_py_to_json('pipeline_with_exit_handler')
if __name__ == '__main__':
unittest.main()

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,47 @@
# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pipeline using ExitHandler."""
from kfp.v2 import dsl
from kfp.v2 import compiler
from kfp.v2.dsl import component
@component
def print_op(message: str):
"""Prints a message."""
print(message)
@component
def fail_op(message: str):
"""Fails."""
import sys
print(message)
sys.exit(1)
@dsl.pipeline(name='pipeline-with-exit-handler', pipeline_root='dummy_root')
def my_pipeline():
exit_task = print_op('Exit handler has worked!')
with dsl.ExitHandler(exit_task):
print_op('Hello World!')
fail_op('Task failed.')
if __name__ == '__main__':
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=__file__.replace('.py', '.json'))

View File

@ -29,6 +29,6 @@ click>=7.1.1,<8
# kfp.v2
absl-py>=0.9,<=0.11
kfp-pipeline-spec>=0.1.7,<0.2.0
kfp-pipeline-spec>=0.1.8,<0.2.0
fire>=0.3.1,<1
google-api-python-client>=1.7.8,<2

View File

@ -64,7 +64,7 @@ idna==2.10
# via requests
jsonschema==3.2.0
# via -r requirements.in
kfp-pipeline-spec==0.1.7
kfp-pipeline-spec==0.1.8
# via -r requirements.in
kfp-server-api==1.3.0
# via -r requirements.in

View File

@ -46,7 +46,7 @@ REQUIRES = [
'Deprecated>=1.2.7,<2',
'strip-hints>=0.1.8,<1',
'docstring-parser>=0.7.3,<1',
'kfp-pipeline-spec>=0.1.7,<0.2.0',
'kfp-pipeline-spec>=0.1.8,<0.2.0',
'fire>=0.3.1,<1',
'protobuf>=3.13.0,<4'
]