support tolerations for ContainerOps (#1269)

* add tolerations to ContainerOps

* add test

* add type for tolerations

* remove fix

* remove print
This commit is contained in:
Hamed 2019-05-10 00:37:59 +01:00 committed by Kubernetes Prow Robot
parent c4c2d166fe
commit ce6066136d
4 changed files with 118 additions and 1 deletions

View File

@ -237,6 +237,10 @@ def _op_to_template(op: BaseOp):
if processed_op.node_selector: if processed_op.node_selector:
template['nodeSelector'] = processed_op.node_selector template['nodeSelector'] = processed_op.node_selector
# tolerations
if processed_op.tolerations:
template['tolerations'] = processed_op.tolerations
# metadata # metadata
if processed_op.pod_annotations or processed_op.pod_labels: if processed_op.pod_annotations or processed_op.pod_labels:
template['metadata'] = {} template['metadata'] = {}

View File

@ -15,6 +15,8 @@
import re import re
import warnings import warnings
from typing import Any, Dict, List, TypeVar, Union, Callable, Optional, Sequence from typing import Any, Dict, List, TypeVar, Union, Callable, Optional, Sequence
from kubernetes.client import V1Toleration
from kubernetes.client.models import ( from kubernetes.client.models import (
V1Container, V1EnvVar, V1EnvFromSource, V1SecurityContext, V1Probe, V1Container, V1EnvVar, V1EnvFromSource, V1SecurityContext, V1Probe,
V1ResourceRequirements, V1VolumeDevice, V1VolumeMount, V1ContainerPort, V1ResourceRequirements, V1VolumeDevice, V1VolumeMount, V1ContainerPort,
@ -644,7 +646,7 @@ class BaseOp(object):
# in the compilation process to generate the DAGs and task io parameters. # in the compilation process to generate the DAGs and task io parameters.
attrs_with_pipelineparams = [ attrs_with_pipelineparams = [
'node_selector', 'volumes', 'pod_annotations', 'pod_labels', 'node_selector', 'volumes', 'pod_annotations', 'pod_labels',
'num_retries', 'sidecars' 'num_retries', 'sidecars', 'tolerations'
] ]
def __init__(self, def __init__(self,
@ -680,6 +682,7 @@ class BaseOp(object):
# `io.argoproj.workflow.v1alpha1.Template` properties # `io.argoproj.workflow.v1alpha1.Template` properties
self.node_selector = {} self.node_selector = {}
self.volumes = [] self.volumes = []
self.tolerations = []
self.pod_annotations = {} self.pod_annotations = {}
self.pod_labels = {} self.pod_labels = {}
self.num_retries = 0 self.num_retries = 0
@ -745,6 +748,17 @@ class BaseOp(object):
self.volumes.append(volume) self.volumes.append(volume)
return self return self
def add_toleration(self, tolerations: V1Toleration):
"""Add K8s tolerations
Args:
volume: Kubernetes toleration
For detailed spec, check toleration definition
https://github.com/kubernetes-client/python/blob/master/kubernetes/client/models/v1_toleration.py
"""
self.tolerations.append(tolerations)
return self
def add_node_selector_constraint(self, label_name, value): def add_node_selector_constraint(self, label_name, value):
"""Add a constraint for nodeSelector. Each constraint is a key-value pair label. For the """Add a constraint for nodeSelector. Each constraint is a key-value pair label. For the
container to be eligible to run on a node, the node must have each of the constraints appeared container to be eligible to run on a node, the node must have each of the constraints appeared

View File

@ -27,6 +27,8 @@ import yaml
from kfp.dsl._component import component from kfp.dsl._component import component
from kfp.dsl import ContainerOp, pipeline from kfp.dsl import ContainerOp, pipeline
from kfp.dsl.types import Integer, InconsistentTypeException from kfp.dsl.types import Integer, InconsistentTypeException
from kubernetes.client import V1Toleration
class TestCompiler(unittest.TestCase): class TestCompiler(unittest.TestCase):
@ -458,3 +460,31 @@ class TestCompiler(unittest.TestCase):
task2 = op().after(task1) task2 = op().after(task1)
compiler.Compiler()._compile(pipeline) compiler.Compiler()._compile(pipeline)
def _test_op_to_template_yaml(self, ops, file_base_name):
test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
target_yaml = os.path.join(test_data_dir, file_base_name + '.yaml')
with open(target_yaml, 'r') as f:
expected = yaml.safe_load(f)['spec']['templates'][0]
compiled_template = compiler.Compiler()._op_to_template(ops)
del compiled_template['name'], expected['name']
del compiled_template['outputs']['parameters'][0]['name'], expected['outputs']['parameters'][0]['name']
assert compiled_template == expected
def test_tolerations(self):
"""Test a pipeline with a tolerations."""
op1 = dsl.ContainerOp(
name='download',
image='busybox',
command=['sh', '-c'],
arguments=['sleep 10; wget localhost:5678 -O /tmp/results.txt'],
file_outputs={'downloaded': '/tmp/results.txt'}) \
.add_toleration(V1Toleration(
effect='NoSchedule',
key='gpu',
operator='Equal',
value='run'))
self._test_op_to_template_yaml(op1, file_base_name='tolerations')

View File

@ -0,0 +1,69 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
kind: Workflow
metadata:
generateName: tolerations-
apiVersion: argoproj.io/v1alpha1
spec:
arguments:
parameters: []
templates:
- outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: "/mlpipeline-ui-metadata.json"
optional: true
s3:
endpoint: minio-service.kubeflow:9000
secretKeySecret:
name: mlpipeline-minio-artifact
key: secretkey
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
bucket: mlpipeline
accessKeySecret:
name: mlpipeline-minio-artifact
key: accesskey
insecure: true
- name: mlpipeline-metrics
path: "/mlpipeline-metrics.json"
optional: true
s3:
endpoint: minio-service.kubeflow:9000
secretKeySecret:
name: mlpipeline-minio-artifact
key: secretkey
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
bucket: mlpipeline
accessKeySecret:
name: mlpipeline-minio-artifact
key: accesskey
insecure: true
parameters:
- name: download-downloaded
valueFrom:
path: "/tmp/results.txt"
name: download
container:
image: busybox
args:
- sleep 10; wget localhost:5678 -O /tmp/results.txt
command:
- sh
- "-c"
tolerations:
- effect: NoSchedule
key: gpu
operator: Equal
value: run
serviceAccountName: pipeline-runner