support tolerations for ContainerOps (#1269)
* add tolerations to ContainerOps * add test * add type for tolerations * remove fix * remove print
This commit is contained in:
parent
c4c2d166fe
commit
ce6066136d
|
|
@ -237,6 +237,10 @@ def _op_to_template(op: BaseOp):
|
||||||
if processed_op.node_selector:
|
if processed_op.node_selector:
|
||||||
template['nodeSelector'] = processed_op.node_selector
|
template['nodeSelector'] = processed_op.node_selector
|
||||||
|
|
||||||
|
# tolerations
|
||||||
|
if processed_op.tolerations:
|
||||||
|
template['tolerations'] = processed_op.tolerations
|
||||||
|
|
||||||
# metadata
|
# metadata
|
||||||
if processed_op.pod_annotations or processed_op.pod_labels:
|
if processed_op.pod_annotations or processed_op.pod_labels:
|
||||||
template['metadata'] = {}
|
template['metadata'] = {}
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, TypeVar, Union, Callable, Optional, Sequence
|
from typing import Any, Dict, List, TypeVar, Union, Callable, Optional, Sequence
|
||||||
|
|
||||||
|
from kubernetes.client import V1Toleration
|
||||||
from kubernetes.client.models import (
|
from kubernetes.client.models import (
|
||||||
V1Container, V1EnvVar, V1EnvFromSource, V1SecurityContext, V1Probe,
|
V1Container, V1EnvVar, V1EnvFromSource, V1SecurityContext, V1Probe,
|
||||||
V1ResourceRequirements, V1VolumeDevice, V1VolumeMount, V1ContainerPort,
|
V1ResourceRequirements, V1VolumeDevice, V1VolumeMount, V1ContainerPort,
|
||||||
|
|
@ -644,7 +646,7 @@ class BaseOp(object):
|
||||||
# in the compilation process to generate the DAGs and task io parameters.
|
# in the compilation process to generate the DAGs and task io parameters.
|
||||||
attrs_with_pipelineparams = [
|
attrs_with_pipelineparams = [
|
||||||
'node_selector', 'volumes', 'pod_annotations', 'pod_labels',
|
'node_selector', 'volumes', 'pod_annotations', 'pod_labels',
|
||||||
'num_retries', 'sidecars'
|
'num_retries', 'sidecars', 'tolerations'
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
@ -680,6 +682,7 @@ class BaseOp(object):
|
||||||
# `io.argoproj.workflow.v1alpha1.Template` properties
|
# `io.argoproj.workflow.v1alpha1.Template` properties
|
||||||
self.node_selector = {}
|
self.node_selector = {}
|
||||||
self.volumes = []
|
self.volumes = []
|
||||||
|
self.tolerations = []
|
||||||
self.pod_annotations = {}
|
self.pod_annotations = {}
|
||||||
self.pod_labels = {}
|
self.pod_labels = {}
|
||||||
self.num_retries = 0
|
self.num_retries = 0
|
||||||
|
|
@ -745,6 +748,17 @@ class BaseOp(object):
|
||||||
self.volumes.append(volume)
|
self.volumes.append(volume)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def add_toleration(self, tolerations: V1Toleration):
|
||||||
|
"""Add K8s tolerations
|
||||||
|
|
||||||
|
Args:
|
||||||
|
volume: Kubernetes toleration
|
||||||
|
For detailed spec, check toleration definition
|
||||||
|
https://github.com/kubernetes-client/python/blob/master/kubernetes/client/models/v1_toleration.py
|
||||||
|
"""
|
||||||
|
self.tolerations.append(tolerations)
|
||||||
|
return self
|
||||||
|
|
||||||
def add_node_selector_constraint(self, label_name, value):
|
def add_node_selector_constraint(self, label_name, value):
|
||||||
"""Add a constraint for nodeSelector. Each constraint is a key-value pair label. For the
|
"""Add a constraint for nodeSelector. Each constraint is a key-value pair label. For the
|
||||||
container to be eligible to run on a node, the node must have each of the constraints appeared
|
container to be eligible to run on a node, the node must have each of the constraints appeared
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ import yaml
|
||||||
from kfp.dsl._component import component
|
from kfp.dsl._component import component
|
||||||
from kfp.dsl import ContainerOp, pipeline
|
from kfp.dsl import ContainerOp, pipeline
|
||||||
from kfp.dsl.types import Integer, InconsistentTypeException
|
from kfp.dsl.types import Integer, InconsistentTypeException
|
||||||
|
from kubernetes.client import V1Toleration
|
||||||
|
|
||||||
|
|
||||||
class TestCompiler(unittest.TestCase):
|
class TestCompiler(unittest.TestCase):
|
||||||
|
|
||||||
|
|
@ -458,3 +460,31 @@ class TestCompiler(unittest.TestCase):
|
||||||
task2 = op().after(task1)
|
task2 = op().after(task1)
|
||||||
|
|
||||||
compiler.Compiler()._compile(pipeline)
|
compiler.Compiler()._compile(pipeline)
|
||||||
|
|
||||||
|
def _test_op_to_template_yaml(self, ops, file_base_name):
|
||||||
|
test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
|
||||||
|
target_yaml = os.path.join(test_data_dir, file_base_name + '.yaml')
|
||||||
|
with open(target_yaml, 'r') as f:
|
||||||
|
expected = yaml.safe_load(f)['spec']['templates'][0]
|
||||||
|
|
||||||
|
compiled_template = compiler.Compiler()._op_to_template(ops)
|
||||||
|
|
||||||
|
del compiled_template['name'], expected['name']
|
||||||
|
del compiled_template['outputs']['parameters'][0]['name'], expected['outputs']['parameters'][0]['name']
|
||||||
|
assert compiled_template == expected
|
||||||
|
|
||||||
|
def test_tolerations(self):
|
||||||
|
"""Test a pipeline with a tolerations."""
|
||||||
|
op1 = dsl.ContainerOp(
|
||||||
|
name='download',
|
||||||
|
image='busybox',
|
||||||
|
command=['sh', '-c'],
|
||||||
|
arguments=['sleep 10; wget localhost:5678 -O /tmp/results.txt'],
|
||||||
|
file_outputs={'downloaded': '/tmp/results.txt'}) \
|
||||||
|
.add_toleration(V1Toleration(
|
||||||
|
effect='NoSchedule',
|
||||||
|
key='gpu',
|
||||||
|
operator='Equal',
|
||||||
|
value='run'))
|
||||||
|
|
||||||
|
self._test_op_to_template_yaml(op1, file_base_name='tolerations')
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
# Copyright 2019 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
kind: Workflow
|
||||||
|
metadata:
|
||||||
|
generateName: tolerations-
|
||||||
|
apiVersion: argoproj.io/v1alpha1
|
||||||
|
spec:
|
||||||
|
arguments:
|
||||||
|
parameters: []
|
||||||
|
templates:
|
||||||
|
- outputs:
|
||||||
|
artifacts:
|
||||||
|
- name: mlpipeline-ui-metadata
|
||||||
|
path: "/mlpipeline-ui-metadata.json"
|
||||||
|
optional: true
|
||||||
|
s3:
|
||||||
|
endpoint: minio-service.kubeflow:9000
|
||||||
|
secretKeySecret:
|
||||||
|
name: mlpipeline-minio-artifact
|
||||||
|
key: secretkey
|
||||||
|
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
|
||||||
|
bucket: mlpipeline
|
||||||
|
accessKeySecret:
|
||||||
|
name: mlpipeline-minio-artifact
|
||||||
|
key: accesskey
|
||||||
|
insecure: true
|
||||||
|
- name: mlpipeline-metrics
|
||||||
|
path: "/mlpipeline-metrics.json"
|
||||||
|
optional: true
|
||||||
|
s3:
|
||||||
|
endpoint: minio-service.kubeflow:9000
|
||||||
|
secretKeySecret:
|
||||||
|
name: mlpipeline-minio-artifact
|
||||||
|
key: secretkey
|
||||||
|
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
|
||||||
|
bucket: mlpipeline
|
||||||
|
accessKeySecret:
|
||||||
|
name: mlpipeline-minio-artifact
|
||||||
|
key: accesskey
|
||||||
|
insecure: true
|
||||||
|
parameters:
|
||||||
|
- name: download-downloaded
|
||||||
|
valueFrom:
|
||||||
|
path: "/tmp/results.txt"
|
||||||
|
name: download
|
||||||
|
container:
|
||||||
|
image: busybox
|
||||||
|
args:
|
||||||
|
- sleep 10; wget localhost:5678 -O /tmp/results.txt
|
||||||
|
command:
|
||||||
|
- sh
|
||||||
|
- "-c"
|
||||||
|
tolerations:
|
||||||
|
- effect: NoSchedule
|
||||||
|
key: gpu
|
||||||
|
operator: Equal
|
||||||
|
value: run
|
||||||
|
serviceAccountName: pipeline-runner
|
||||||
Loading…
Reference in New Issue