* Add parallelism limits to pipeline in kfp sdk * fix lint error
This commit is contained in:
parent
5ff7a65a0c
commit
9167da1b4e
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import kfp
|
||||
from kfp import dsl
|
||||
|
||||
|
||||
def print_op(msg):
    """Build a container step that echoes a message.

    Args:
        msg: Text handed to ``echo`` as its single argument.

    Returns:
        The ``dsl.ContainerOp`` describing the step.
    """
    echo_command = ['echo', msg]
    step = dsl.ContainerOp(name='Print', image='alpine:3.6', command=echo_command)
    return step
|
||||
|
||||
|
||||
@dsl.pipeline(
    name='Pipeline parallelism',
    description='The pipeline shows how to set the max number of parallel pods in a pipeline.'
)
def pipeline_parallelism():
    """Sample pipeline: two print steps with pipeline-level parallelism set to 1.

    The two steps have no dependency on each other; the parallelism setting
    of 1 is what this sample demonstrates.
    """
    # The returned ops are not needed afterwards, so they are not bound to
    # names. NOTE(review): presumably creating an op inside the pipeline
    # function is what registers it with the pipeline - confirm against dsl.
    print_op('hey, what are you up to?')
    print_op('train my model.')

    dsl.get_pipeline_conf().set_parallelism(1)
|
||||
|
||||
if __name__ == '__main__':
    # Compile the sample pipeline into '<path of this script>.yaml'.
    output_path = __file__ + '.yaml'
    kfp.compiler.Compiler().compile(pipeline_parallelism, output_path)
|
||||
|
|
@ -670,9 +670,13 @@ class Compiler(object):
|
|||
'entrypoint': pipeline_template_name,
|
||||
'templates': templates,
|
||||
'arguments': {'parameters': input_params},
|
||||
'serviceAccountName': 'pipeline-runner'
|
||||
'serviceAccountName': 'pipeline-runner',
|
||||
}
|
||||
}
|
||||
# set parallelism limits at pipeline level
|
||||
if pipeline_conf.parallelism:
|
||||
workflow['spec']['parallelism'] = pipeline_conf.parallelism
|
||||
|
||||
# set ttl after workflow finishes
|
||||
if pipeline_conf.ttl_seconds_after_finished >= 0:
|
||||
workflow['spec']['ttlSecondsAfterFinished'] = pipeline_conf.ttl_seconds_after_finished
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ class PipelineConf():
|
|||
self.ttl_seconds_after_finished = -1
|
||||
self.op_transformers = []
|
||||
self.image_pull_policy = None
|
||||
self.parallelism = None
|
||||
|
||||
def set_image_pull_secrets(self, image_pull_secrets):
|
||||
"""Configures the pipeline level imagepullsecret
|
||||
|
|
@ -83,6 +84,15 @@ class PipelineConf():
|
|||
self.timeout = seconds
|
||||
return self
|
||||
|
||||
def set_parallelism(self, max_num_pods: int):
    """Configures the max number of total parallel pods that can execute at the same time in a workflow.

    Args:
        max_num_pods (int): max number of total parallel pods. Must be >= 1.

    Returns:
        This configuration object, so calls can be chained.

    Raises:
        ValueError: if ``max_num_pods`` is smaller than 1.
    """
    # Reject non-positive limits up front instead of storing a value that
    # cannot express a meaningful pod limit.
    if max_num_pods < 1:
        raise ValueError(
            'Pipeline max_num_pods set to < 1, allowed values are > 0')

    self.parallelism = max_num_pods
    return self
|
||||
|
||||
def set_ttl_seconds_after_finished(self, seconds: int):
|
||||
"""Configures the ttl after the pipeline has finished.
|
||||
|
||||
|
|
@ -96,7 +106,7 @@ class PipelineConf():
|
|||
"""Configures the default image pull policy
|
||||
|
||||
Args:
|
||||
policy: the pull policy, has to be one of: Always, Never, IfNotPresent.
|
||||
policy: the pull policy, has to be one of: Always, Never, IfNotPresent.
|
||||
For more info: https://github.com/kubernetes-client/python/blob/10a7f95435c0b94a6d949ba98375f8cc85a70e5a/kubernetes/docs/V1Container.md
|
||||
"""
|
||||
self.image_pull_policy = policy
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ class TestCompiler(unittest.TestCase):
|
|||
self.maxDiff = None
|
||||
self.assertEqual(golden_output, compiler._op_to_template._op_to_template(op))
|
||||
self.assertEqual(res_output, compiler._op_to_template._op_to_template(res))
|
||||
|
||||
|
||||
kfp.compiler.Compiler()._compile(my_pipeline)
|
||||
|
||||
def _get_yaml_from_zip(self, zip_file):
|
||||
|
|
@ -325,7 +325,7 @@ class TestCompiler(unittest.TestCase):
|
|||
|
||||
with open(os.path.join(test_data_dir, target_yaml), 'r') as f:
|
||||
compiled = yaml.safe_load(f)
|
||||
|
||||
|
||||
for workflow in golden, compiled:
|
||||
del workflow['metadata']
|
||||
|
||||
|
|
@ -640,6 +640,25 @@ implementation:
|
|||
template = workflow_dict['spec']['templates'][0]
|
||||
self.assertEqual(template['metadata']['annotations']['pipelines.kubeflow.org/task_display_name'], 'Custom name')
|
||||
|
||||
def test_set_parallelism(self):
    """Test a pipeline with parallelism limits.

    Compiles a pipeline that sets a parallelism of 1 and checks that the
    value appears under ``spec.parallelism`` of the compiled workflow.
    """
    def some_op():
        return dsl.ContainerOp(
            name='sleep',
            image='busybox',
            # One argv element per token: 'sleep 1' as a single element
            # would name a nonexistent executable.
            command=['sleep', '1'],
        )

    @dsl.pipeline()
    def some_pipeline():
        some_op()
        some_op()
        some_op()
        dsl.get_pipeline_conf().set_parallelism(1)

    workflow_dict = kfp.compiler.Compiler()._compile(some_pipeline)
    self.assertEqual(workflow_dict['spec']['parallelism'], 1)
|
||||
|
||||
def test_set_ttl_seconds_after_finished(self):
|
||||
"""Test a pipeline with ttl after finished."""
|
||||
def some_op():
|
||||
|
|
@ -817,11 +836,11 @@ implementation:
|
|||
name="foo-bar-cm",
|
||||
namespace="default"
|
||||
)
|
||||
)
|
||||
)
|
||||
# delete the config map in k8s
|
||||
dsl.ResourceOp(
|
||||
name="delete-config-map",
|
||||
action="delete",
|
||||
name="delete-config-map",
|
||||
action="delete",
|
||||
k8s_resource=config_map
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue