support tolerations for ContainerOps (#1269)

* add tolerations to ContainerOps

* add test

* add type for tolerations

* remove fix

* remove print
This commit is contained in:
Hamed 2019-05-10 00:37:59 +01:00 committed by Kubernetes Prow Robot
parent c4c2d166fe
commit ce6066136d
4 changed files with 118 additions and 1 deletions

View File

@ -237,6 +237,10 @@ def _op_to_template(op: BaseOp):
if processed_op.node_selector:
template['nodeSelector'] = processed_op.node_selector
# tolerations
if processed_op.tolerations:
template['tolerations'] = processed_op.tolerations
# metadata
if processed_op.pod_annotations or processed_op.pod_labels:
template['metadata'] = {}

View File

@ -15,6 +15,8 @@
import re
import warnings
from typing import Any, Dict, List, TypeVar, Union, Callable, Optional, Sequence
from kubernetes.client import V1Toleration
from kubernetes.client.models import (
V1Container, V1EnvVar, V1EnvFromSource, V1SecurityContext, V1Probe,
V1ResourceRequirements, V1VolumeDevice, V1VolumeMount, V1ContainerPort,
@ -644,7 +646,7 @@ class BaseOp(object):
# in the compilation process to generate the DAGs and task io parameters.
attrs_with_pipelineparams = [
'node_selector', 'volumes', 'pod_annotations', 'pod_labels',
'num_retries', 'sidecars'
'num_retries', 'sidecars', 'tolerations'
]
def __init__(self,
@ -680,6 +682,7 @@ class BaseOp(object):
# `io.argoproj.workflow.v1alpha1.Template` properties
self.node_selector = {}
self.volumes = []
self.tolerations = []
self.pod_annotations = {}
self.pod_labels = {}
self.num_retries = 0
@ -745,6 +748,17 @@ class BaseOp(object):
self.volumes.append(volume)
return self
def add_toleration(self, tolerations: V1Toleration):
    """Add a Kubernetes toleration to this op.

    Args:
      tolerations: A single Kubernetes toleration to append to this op's
        list of tolerations.
        For detailed spec, check toleration definition
        https://github.com/kubernetes-client/python/blob/master/kubernetes/client/models/v1_toleration.py

    Returns:
      This op itself, so that calls can be chained.
    """
    self.tolerations.append(tolerations)
    return self
def add_node_selector_constraint(self, label_name, value):
"""Add a constraint for nodeSelector. Each constraint is a key-value pair label. For the
container to be eligible to run on a node, the node must have each of the constraints appeared

View File

@ -27,6 +27,8 @@ import yaml
from kfp.dsl._component import component
from kfp.dsl import ContainerOp, pipeline
from kfp.dsl.types import Integer, InconsistentTypeException
from kubernetes.client import V1Toleration
class TestCompiler(unittest.TestCase):
@ -458,3 +460,31 @@ class TestCompiler(unittest.TestCase):
task2 = op().after(task1)
compiler.Compiler()._compile(pipeline)
def _test_op_to_template_yaml(self, ops, file_base_name):
    """Compile `ops` to a template and compare it with a golden YAML file.

    Args:
      ops: The op handed to `Compiler()._op_to_template`.
      file_base_name: Base name (without the `.yaml` extension) of the
        expected workflow file inside the `testdata` directory located
        next to this test module.

    Raises:
      AssertionError: If the compiled template differs from the expected one.
    """
    test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
    target_yaml = os.path.join(test_data_dir, file_base_name + '.yaml')
    with open(target_yaml, 'r') as f:
        # Only the first template of the golden workflow is compared.
        expected = yaml.safe_load(f)['spec']['templates'][0]

    compiled_template = compiler.Compiler()._op_to_template(ops)

    # The template name and the name of the first output parameter are
    # removed from both sides before comparing.
    # NOTE(review): presumably these names are generated and not stable
    # across runs -- confirm against the compiler.
    del compiled_template['name'], expected['name']
    del compiled_template['outputs']['parameters'][0]['name'], expected['outputs']['parameters'][0]['name']

    # assertEqual (rather than a bare `assert`) is not stripped under
    # `python -O` and reports a diff of the two dicts on failure.
    self.assertEqual(compiled_template, expected)
def test_tolerations(self):
    """Compare an op carrying one toleration against the `tolerations` golden file."""
    download_op = dsl.ContainerOp(
        name='download',
        image='busybox',
        command=['sh', '-c'],
        arguments=['sleep 10; wget localhost:5678 -O /tmp/results.txt'],
        file_outputs={'downloaded': '/tmp/results.txt'})
    gpu_toleration = V1Toleration(
        effect='NoSchedule',
        key='gpu',
        operator='Equal',
        value='run')
    # Pass on whatever add_toleration returns, exactly as the chained call did.
    op_with_toleration = download_op.add_toleration(gpu_toleration)
    self._test_op_to_template_yaml(op_with_toleration, file_base_name='tolerations')

View File

@ -0,0 +1,69 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
kind: Workflow
metadata:
generateName: tolerations-
apiVersion: argoproj.io/v1alpha1
spec:
arguments:
parameters: []
templates:
- outputs:
artifacts:
- name: mlpipeline-ui-metadata
path: "/mlpipeline-ui-metadata.json"
optional: true
s3:
endpoint: minio-service.kubeflow:9000
secretKeySecret:
name: mlpipeline-minio-artifact
key: secretkey
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
bucket: mlpipeline
accessKeySecret:
name: mlpipeline-minio-artifact
key: accesskey
insecure: true
- name: mlpipeline-metrics
path: "/mlpipeline-metrics.json"
optional: true
s3:
endpoint: minio-service.kubeflow:9000
secretKeySecret:
name: mlpipeline-minio-artifact
key: secretkey
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
bucket: mlpipeline
accessKeySecret:
name: mlpipeline-minio-artifact
key: accesskey
insecure: true
parameters:
- name: download-downloaded
valueFrom:
path: "/tmp/results.txt"
name: download
container:
image: busybox
args:
- sleep 10; wget localhost:5678 -O /tmp/results.txt
command:
- sh
- "-c"
tolerations:
- effect: NoSchedule
key: gpu
operator: Equal
value: run
serviceAccountName: pipeline-runner