support tolerations for ContainerOps (#1269)
* add tolerations to ContainerOps * add test * add type for tolerations * remove fix * remove print
This commit is contained in:
parent
c4c2d166fe
commit
ce6066136d
|
|
@ -237,6 +237,10 @@ def _op_to_template(op: BaseOp):
|
|||
if processed_op.node_selector:
|
||||
template['nodeSelector'] = processed_op.node_selector
|
||||
|
||||
# tolerations
|
||||
if processed_op.tolerations:
|
||||
template['tolerations'] = processed_op.tolerations
|
||||
|
||||
# metadata
|
||||
if processed_op.pod_annotations or processed_op.pod_labels:
|
||||
template['metadata'] = {}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@
|
|||
import re
|
||||
import warnings
|
||||
from typing import Any, Dict, List, TypeVar, Union, Callable, Optional, Sequence
|
||||
|
||||
from kubernetes.client import V1Toleration
|
||||
from kubernetes.client.models import (
|
||||
V1Container, V1EnvVar, V1EnvFromSource, V1SecurityContext, V1Probe,
|
||||
V1ResourceRequirements, V1VolumeDevice, V1VolumeMount, V1ContainerPort,
|
||||
|
|
@ -644,7 +646,7 @@ class BaseOp(object):
|
|||
# in the compilation process to generate the DAGs and task io parameters.
|
||||
attrs_with_pipelineparams = [
|
||||
'node_selector', 'volumes', 'pod_annotations', 'pod_labels',
|
||||
'num_retries', 'sidecars'
|
||||
'num_retries', 'sidecars', 'tolerations'
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -680,6 +682,7 @@ class BaseOp(object):
|
|||
# `io.argoproj.workflow.v1alpha1.Template` properties
|
||||
self.node_selector = {}
|
||||
self.volumes = []
|
||||
self.tolerations = []
|
||||
self.pod_annotations = {}
|
||||
self.pod_labels = {}
|
||||
self.num_retries = 0
|
||||
|
|
@ -745,6 +748,17 @@ class BaseOp(object):
|
|||
self.volumes.append(volume)
|
||||
return self
|
||||
|
||||
def add_toleration(self, tolerations: V1Toleration):
|
||||
"""Add K8s tolerations
|
||||
|
||||
Args:
|
||||
volume: Kubernetes toleration
|
||||
For detailed spec, check toleration definition
|
||||
https://github.com/kubernetes-client/python/blob/master/kubernetes/client/models/v1_toleration.py
|
||||
"""
|
||||
self.tolerations.append(tolerations)
|
||||
return self
|
||||
|
||||
def add_node_selector_constraint(self, label_name, value):
|
||||
"""Add a constraint for nodeSelector. Each constraint is a key-value pair label. For the
|
||||
container to be eligible to run on a node, the node must have each of the constraints appeared
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ import yaml
|
|||
from kfp.dsl._component import component
|
||||
from kfp.dsl import ContainerOp, pipeline
|
||||
from kfp.dsl.types import Integer, InconsistentTypeException
|
||||
from kubernetes.client import V1Toleration
|
||||
|
||||
|
||||
class TestCompiler(unittest.TestCase):
|
||||
|
||||
|
|
@ -458,3 +460,31 @@ class TestCompiler(unittest.TestCase):
|
|||
task2 = op().after(task1)
|
||||
|
||||
compiler.Compiler()._compile(pipeline)
|
||||
|
||||
def _test_op_to_template_yaml(self, ops, file_base_name):
|
||||
test_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
|
||||
target_yaml = os.path.join(test_data_dir, file_base_name + '.yaml')
|
||||
with open(target_yaml, 'r') as f:
|
||||
expected = yaml.safe_load(f)['spec']['templates'][0]
|
||||
|
||||
compiled_template = compiler.Compiler()._op_to_template(ops)
|
||||
|
||||
del compiled_template['name'], expected['name']
|
||||
del compiled_template['outputs']['parameters'][0]['name'], expected['outputs']['parameters'][0]['name']
|
||||
assert compiled_template == expected
|
||||
|
||||
def test_tolerations(self):
|
||||
"""Test a pipeline with a tolerations."""
|
||||
op1 = dsl.ContainerOp(
|
||||
name='download',
|
||||
image='busybox',
|
||||
command=['sh', '-c'],
|
||||
arguments=['sleep 10; wget localhost:5678 -O /tmp/results.txt'],
|
||||
file_outputs={'downloaded': '/tmp/results.txt'}) \
|
||||
.add_toleration(V1Toleration(
|
||||
effect='NoSchedule',
|
||||
key='gpu',
|
||||
operator='Equal',
|
||||
value='run'))
|
||||
|
||||
self._test_op_to_template_yaml(op1, file_base_name='tolerations')
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
kind: Workflow
|
||||
metadata:
|
||||
generateName: tolerations-
|
||||
apiVersion: argoproj.io/v1alpha1
|
||||
spec:
|
||||
arguments:
|
||||
parameters: []
|
||||
templates:
|
||||
- outputs:
|
||||
artifacts:
|
||||
- name: mlpipeline-ui-metadata
|
||||
path: "/mlpipeline-ui-metadata.json"
|
||||
optional: true
|
||||
s3:
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
secretKeySecret:
|
||||
name: mlpipeline-minio-artifact
|
||||
key: secretkey
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz
|
||||
bucket: mlpipeline
|
||||
accessKeySecret:
|
||||
name: mlpipeline-minio-artifact
|
||||
key: accesskey
|
||||
insecure: true
|
||||
- name: mlpipeline-metrics
|
||||
path: "/mlpipeline-metrics.json"
|
||||
optional: true
|
||||
s3:
|
||||
endpoint: minio-service.kubeflow:9000
|
||||
secretKeySecret:
|
||||
name: mlpipeline-minio-artifact
|
||||
key: secretkey
|
||||
key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz
|
||||
bucket: mlpipeline
|
||||
accessKeySecret:
|
||||
name: mlpipeline-minio-artifact
|
||||
key: accesskey
|
||||
insecure: true
|
||||
parameters:
|
||||
- name: download-downloaded
|
||||
valueFrom:
|
||||
path: "/tmp/results.txt"
|
||||
name: download
|
||||
container:
|
||||
image: busybox
|
||||
args:
|
||||
- sleep 10; wget localhost:5678 -O /tmp/results.txt
|
||||
command:
|
||||
- sh
|
||||
- "-c"
|
||||
tolerations:
|
||||
- effect: NoSchedule
|
||||
key: gpu
|
||||
operator: Equal
|
||||
value: run
|
||||
serviceAccountName: pipeline-runner
|
||||
Loading…
Reference in New Issue