allow specifying gpus explicitly (#16)

* allow specifying gpu resources explicitly; also no longer allocate gpus to the launcher

* small fixes

* address review comments
Rong Ou 2018-06-14 09:49:28 -07:00 committed by k8s-ci-robot
parent 0e3119bae6
commit a3487b2208
6 changed files with 176 additions and 69 deletions

View File

@@ -0,0 +1,16 @@
# This file shows how to run multi-node training benchmarks using an MPIJob,
# specifying GPUs explicitly per replica.
apiVersion: kubeflow.org/v1alpha1
kind: MPIJob
metadata:
  name: tensorflow-benchmarks-16-custom
spec:
  replicas: 4
  template:
    spec:
      containers:
      - image: mpioperator/tensorflow-benchmarks:latest
        name: tensorflow-benchmarks
        resources:
          limits:
            nvidia.com/gpu: 4
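For reference, the total GPU footprint in this mode is simply replicas times the per-replica limit, and the launcher pod no longer requests any GPUs. A trivial standalone sketch of that arithmetic (illustrative helper only, not operator code):

package main

import "fmt"

// totalGPUs mirrors the arithmetic implied by the example above: each of the
// `replicas` workers requests the same nvidia.com/gpu limit. Illustrative only.
func totalGPUs(replicas, gpusPerReplica int) int {
	return replicas * gpusPerReplica
}

func main() {
	// 4 replicas x 4 GPUs each = 16 GPUs in total; the launcher gets none.
	fmt.Println(totalGPUs(4, 4)) // 16
}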

View File

@@ -1,4 +1,15 @@
-# This file shows how to run multi-node training benchmarks using an MPIJob.
+# This file shows how to run multi-node training benchmarks using an MPIJob,
+# letting the operator decide on how best to allocate GPUs.
+#
+# In this mode, the operator assumes all nodes have the same number of GPUs.
+# If `gpus` is bigger than the number of GPUs per node, then only whole nodes
+# can be allocated.
+#
+# For example, if each node has 8 GPUs, the valid `gpus` values are:
+# 1, 2, 4, 8, 16, 24, 32, ... or any multiple of 8.
+#
+# If you need more flexibility in allocating GPUs, you can use the alternative
+# mode to specify `replicas` and GPU resource limit explicitly.
apiVersion: kubeflow.org/v1alpha1
kind: MPIJob
metadata:

View File

@@ -39,9 +39,16 @@ type MPIJobList struct {
type MPIJobSpec struct {
	// Specifies the desired number of GPUs the MPIJob should run on.
+	// Mutually exclusive with the `Replicas` field.
	// +optional
	GPUs *int32 `json:"gpus,omitempty"`
+	// Specifies the desired number of replicas the MPIJob should run on.
+	// The `PodSpec` should specify the number of GPUs.
+	// Mutually exclusive with the `GPUs` field.
+	// +optional
+	Replicas *int32 `json:"replicas,omitempty"`
	// Describes the pod that will be created when executing an MPIJob.
	Template corev1.PodTemplateSpec `json:"template,omitempty"`
}
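To illustrate the two mutually exclusive modes side by side, an MPIJobSpec can now be filled in either with `GPUs` alone or with `Replicas` plus a per-container GPU limit. A minimal sketch under those assumptions (values and image are arbitrary; `int32Ptr` is a small local helper, as in the tests further down):

package example

import (
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"

	kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v1alpha1"
)

// int32Ptr is a helper for this sketch only.
func int32Ptr(i int32) *int32 { return &i }

// exampleSpecs shows the two mutually exclusive ways to fill in MPIJobSpec.
func exampleSpecs() (byGPUs, byReplicas kubeflow.MPIJobSpec) {
	// Mode 1: give only the total GPU count and let the operator decide.
	// (A pod template is still required in practice; omitted here for brevity.)
	byGPUs = kubeflow.MPIJobSpec{GPUs: int32Ptr(16)}

	// Mode 2: fix the replica count and put the GPU limit on the container.
	byReplicas = kubeflow.MPIJobSpec{
		Replicas: int32Ptr(4),
		Template: corev1.PodTemplateSpec{
			Spec: corev1.PodSpec{
				Containers: []corev1.Container{{
					Name:  "tensorflow-benchmarks",
					Image: "mpioperator/tensorflow-benchmarks:latest",
					Resources: corev1.ResourceRequirements{
						Limits: corev1.ResourceList{
							"nvidia.com/gpu": resource.MustParse("4"),
						},
					},
				}},
			},
		},
	}
	return byGPUs, byReplicas
}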

View File

@@ -95,6 +95,15 @@ func (in *MPIJobSpec) DeepCopyInto(out *MPIJobSpec) {
			**out = **in
		}
	}
+	if in.Replicas != nil {
+		in, out := &in.Replicas, &out.Replicas
+		if *in == nil {
+			*out = nil
+		} else {
+			*out = new(int32)
+			**out = **in
+		}
+	}
	in.Template.DeepCopyInto(&out.Template)
	return
}

View File

@@ -398,11 +398,10 @@ func (c *MPIJobController) syncHandler(key string) error {
	// We're done if the launcher either succeeded or failed.
	done := launcher != nil && (launcher.Status.Succeeded == 1 || launcher.Status.Failed == 1)

-	totalGPUs := getTotalGPUs(mpiJob)
-	workerReplicas := c.getWorkerReplicas(totalGPUs, done)
-	gpusPerWorker := totalGPUs
-	if totalGPUs > c.gpusPerNode {
-		gpusPerWorker = c.gpusPerNode
+	workerReplicas, gpusPerWorker, err := allocateGPUs(mpiJob, c.gpusPerNode, done)
+	if err != nil {
+		runtime.HandleError(err)
+		return nil
	}

	if !done {
@@ -427,7 +426,7 @@ func (c *MPIJobController) syncHandler(key string) error {
		}
	}

-	worker, err := c.getOrCreateWorkerStatefulSet(mpiJob, workerReplicas)
+	worker, err := c.getOrCreateWorkerStatefulSet(mpiJob, workerReplicas, gpusPerWorker)
	if err != nil {
		return err
	}
@@ -435,11 +434,7 @@ func (c *MPIJobController) syncHandler(key string) error {
	// If the worker is ready, start the launcher.
	workerReady := workerReplicas == 0 || int(worker.Status.ReadyReplicas) == workerReplicas
	if workerReady && launcher == nil {
-		launcherGPUs := totalGPUs
-		if launcherGPUs > c.gpusPerNode {
-			launcherGPUs = c.gpusPerNode
-		}
-		launcher, err = c.kubeClient.BatchV1().Jobs(namespace).Create(newLauncher(mpiJob, launcherGPUs, c.kubectlDeliveryImage))
+		launcher, err = c.kubeClient.BatchV1().Jobs(namespace).Create(newLauncher(mpiJob, c.kubectlDeliveryImage))
		if err != nil {
			return err
		}
@@ -480,13 +475,36 @@ func (c *MPIJobController) getLauncherJob(mpiJob *kubeflow.MPIJob) (*batchv1.Job
	return launcher, nil
}

-// getTotalGPUs gets the total number of desired GPUs. Defaults to 1 if not specified.
-func getTotalGPUs(mpiJob *kubeflow.MPIJob) int {
-	totalGPUs := 1
+// allocateGPUs allocates the worker replicas and GPUs per worker.
+func allocateGPUs(mpiJob *kubeflow.MPIJob, gpusPerNode int, done bool) (workerReplicas int, gpusPerWorker int, err error) {
+	workerReplicas = 0
+	gpusPerWorker = 0
+	err = nil
	if mpiJob.Spec.GPUs != nil {
-		totalGPUs = int(*mpiJob.Spec.GPUs)
+		totalGPUs := int(*mpiJob.Spec.GPUs)
+		if totalGPUs < gpusPerNode {
+			workerReplicas = 1
+			gpusPerWorker = totalGPUs
+		} else if totalGPUs % gpusPerNode == 0 {
+			workerReplicas = totalGPUs / gpusPerNode
+			gpusPerWorker = gpusPerNode
+		} else {
+			err = fmt.Errorf("specified #GPUs is not a multiple of GPUs per node (%d)", gpusPerNode)
+		}
+	} else if mpiJob.Spec.Replicas != nil {
+		workerReplicas = int(*mpiJob.Spec.Replicas)
+		container := mpiJob.Spec.Template.Spec.Containers[0]
+		if container.Resources.Limits != nil {
+			if val, ok := container.Resources.Limits[gpuResourceName]; ok {
+				gpus, _ := val.AsInt64()
+				gpusPerWorker = int(gpus)
+			}
+		}
	}
-	return totalGPUs
+	if done {
+		workerReplicas = 0
+	}
+	return workerReplicas, gpusPerWorker, err
}

// getWorkerReplicas gets the desired number of worker replicas.
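To make the new allocation branches concrete, here are a few sample outcomes assuming 8 GPUs per node. This is a standalone sketch that mirrors the arithmetic of allocateGPUs above rather than calling the controller itself:

package main

import "fmt"

// allocate mirrors the `gpus` branch of allocateGPUs, for illustration only.
func allocate(totalGPUs, gpusPerNode int) (workerReplicas, gpusPerWorker int, err error) {
	switch {
	case totalGPUs < gpusPerNode:
		return 1, totalGPUs, nil
	case totalGPUs%gpusPerNode == 0:
		return totalGPUs / gpusPerNode, gpusPerNode, nil
	default:
		return 0, 0, fmt.Errorf("specified #GPUs is not a multiple of GPUs per node (%d)", gpusPerNode)
	}
}

func main() {
	for _, g := range []int{4, 8, 16, 12} {
		w, p, err := allocate(g, 8)
		// gpus=4  -> 1 worker x 4 GPUs
		// gpus=8  -> 1 worker x 8 GPUs
		// gpus=16 -> 2 workers x 8 GPUs
		// gpus=12 -> error (not a multiple of 8)
		fmt.Printf("gpus=%-2d -> workers=%d gpusPerWorker=%d err=%v\n", g, w, p, err)
	}
}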
@@ -603,11 +621,11 @@ func (c *MPIJobController) getLauncherRoleBinding(mpiJob *kubeflow.MPIJob) (*rba
// getOrCreateWorkerStatefulSet gets the worker StatefulSet controlled by this
// MPIJob, or creates one if it doesn't exist.
-func (c *MPIJobController) getOrCreateWorkerStatefulSet(mpiJob *kubeflow.MPIJob, workerReplicas int) (*appsv1.StatefulSet, error) {
+func (c *MPIJobController) getOrCreateWorkerStatefulSet(mpiJob *kubeflow.MPIJob, workerReplicas int, gpusPerWorker int) (*appsv1.StatefulSet, error) {
	worker, err := c.statefulSetLister.StatefulSets(mpiJob.Namespace).Get(mpiJob.Name + workerSuffix)
	// If the StatefulSet doesn't exist, we'll create it.
	if errors.IsNotFound(err) && workerReplicas > 0 {
-		worker, err = c.kubeClient.AppsV1().StatefulSets(mpiJob.Namespace).Create(newWorker(mpiJob, int32(workerReplicas), c.gpusPerNode))
+		worker, err = c.kubeClient.AppsV1().StatefulSets(mpiJob.Namespace).Create(newWorker(mpiJob, int32(workerReplicas), gpusPerWorker))
	}
	// If an error occurs during Get/Create, we'll requeue the item so we
	// can attempt processing again later. This could have been caused by a
@@ -626,7 +644,7 @@ func (c *MPIJobController) getOrCreateWorkerStatefulSet(mpiJob *kubeflow.MPIJob,
	// If the worker is out of date, update the worker.
	if worker != nil && int(*worker.Spec.Replicas) != workerReplicas {
-		worker, err = c.kubeClient.AppsV1().StatefulSets(mpiJob.Namespace).Update(newWorker(mpiJob, int32(workerReplicas), c.gpusPerNode))
+		worker, err = c.kubeClient.AppsV1().StatefulSets(mpiJob.Namespace).Update(newWorker(mpiJob, int32(workerReplicas), gpusPerWorker))
		// If an error occurs during Update, we'll requeue the item so we can
		// attempt processing again later. This could have been caused by a
		// temporary network failure, or any other transient reason.
@@ -727,10 +745,14 @@ shift
%s/kubectl exec ${POD_NAME} -- /bin/sh -c "$*"
`, kubectlMountPath)

+	// If no GPU is specified, default to 1 slot.
+	slots := 1
+	if gpusPerWorker > 0 {
+		slots = gpusPerWorker
+	}
	var buffer bytes.Buffer
-	buffer.WriteString(fmt.Sprintf("localhost slots=%d max_slots=%d\n", gpusPerWorker, gpusPerWorker))
	for i := 0; i < workerReplicas; i++ {
-		buffer.WriteString(fmt.Sprintf("%s%s-%d slots=%d max_slots=%d\n", mpiJob.Name, workerSuffix, i, gpusPerWorker, gpusPerWorker))
+		buffer.WriteString(fmt.Sprintf("%s%s-%d slots=%d max_slots=%d\n", mpiJob.Name, workerSuffix, i, slots, slots))
	}

	return &corev1.ConfigMap{
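With this change the generated hostfile lists only the workers; the launcher's localhost line is gone. For the 4-replica, 4-GPU example job earlier, the entries would look roughly like the output below, assuming the worker suffix is "-worker" (the workerSuffix constant itself is not shown in this diff):

package main

import (
	"bytes"
	"fmt"
)

func main() {
	// Reproduces the hostfile loop above for the 4-replica, 4-GPU example.
	// "-worker" is an assumed suffix; the real value comes from workerSuffix.
	name, workerSuffix, workerReplicas, slots := "tensorflow-benchmarks-16-custom", "-worker", 4, 4

	var buffer bytes.Buffer
	for i := 0; i < workerReplicas; i++ {
		buffer.WriteString(fmt.Sprintf("%s%s-%d slots=%d max_slots=%d\n", name, workerSuffix, i, slots, slots))
	}
	fmt.Print(buffer.String())
	// tensorflow-benchmarks-16-custom-worker-0 slots=4 max_slots=4
	// ...
	// tensorflow-benchmarks-16-custom-worker-3 slots=4 max_slots=4
}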
@@ -903,7 +925,7 @@ func newWorker(mpiJob *kubeflow.MPIJob, desiredReplicas int32, gpus int) *appsv1
// newLauncher creates a new launcher Job for an MPIJob resource. It also sets
// the appropriate OwnerReferences on the resource so handleObject can discover
// the MPIJob resource that 'owns' it.
-func newLauncher(mpiJob *kubeflow.MPIJob, gpus int, kubectlDeliveryImage string) *batchv1.Job {
+func newLauncher(mpiJob *kubeflow.MPIJob, kubectlDeliveryImage string) *batchv1.Job {
	launcherName := mpiJob.Name + launcherSuffix
	labels := map[string]string{
		"app": launcherName,
@@ -938,10 +960,9 @@ func newLauncher(mpiJob *kubeflow.MPIJob, gpus int, kubectlDeliveryImage string)
			Name: "OMPI_MCA_orte_default_hostfile",
			Value: fmt.Sprintf("%s/%s", configMountPath, hostfileName),
		})
-	if container.Resources.Limits == nil {
-		container.Resources.Limits = make(corev1.ResourceList)
+	if container.Resources.Limits != nil {
+		delete(container.Resources.Limits, gpuResourceName)
	}
-	container.Resources.Limits[gpuResourceName] = *resource.NewQuantity(int64(gpus), resource.DecimalExponent)
	container.VolumeMounts = append(container.VolumeMounts,
		corev1.VolumeMount{
			Name: kubectlVolumeName,

View File

@@ -36,6 +36,7 @@ import (
	kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v1alpha1"
	"github.com/kubeflow/mpi-operator/pkg/client/clientset/versioned/fake"
	informers "github.com/kubeflow/mpi-operator/pkg/client/informers/externalversions"
+	"k8s.io/apimachinery/pkg/api/resource"
)

var (
@@ -98,6 +99,34 @@ func newMPIJob(name string, gpus *int32) *kubeflow.MPIJob {
	}
}

+func newMPIJobWithCustomResources(name string, replicas *int32, gpusPerReplica int64) *kubeflow.MPIJob {
+	return &kubeflow.MPIJob{
+		TypeMeta: metav1.TypeMeta{APIVersion: kubeflow.SchemeGroupVersion.String()},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      name,
+			Namespace: metav1.NamespaceDefault,
+		},
+		Spec: kubeflow.MPIJobSpec{
+			Replicas: replicas,
+			Template: corev1.PodTemplateSpec{
+				Spec: corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:  "foo",
+							Image: "bar",
+							Resources: corev1.ResourceRequirements{
+								Limits: corev1.ResourceList{
+									"nvidia.com/gpu": *resource.NewQuantity(gpusPerReplica, resource.DecimalExponent),
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+}
+
func (f *fixture) newController() (*MPIJobController, informers.SharedInformerFactory, kubeinformers.SharedInformerFactory) {
	f.client = fake.NewSimpleClientset(f.objects...)
	f.kubeClient = k8sfake.NewSimpleClientset(f.kubeObjects...)
@@ -415,7 +444,7 @@ func TestLauncherNotControlledByUs(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	launcher := newLauncher(mpiJob, 64, "kubectl-delivery")
+	launcher := newLauncher(mpiJob, "kubectl-delivery")
	launcher.OwnerReferences = nil
	f.setUpLauncher(launcher)

@@ -428,7 +457,7 @@ func TestLauncherSucceeded(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	launcher := newLauncher(mpiJob, 64, "kubectl-delivery")
+	launcher := newLauncher(mpiJob, "kubectl-delivery")
	launcher.Status.Succeeded = 1
	f.setUpLauncher(launcher)

@@ -445,7 +474,7 @@ func TestLauncherFailed(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	launcher := newLauncher(mpiJob, 64, "kubectl-delivery")
+	launcher := newLauncher(mpiJob, "kubectl-delivery")
	launcher.Status.Failed = 1
	f.setUpLauncher(launcher)
@@ -462,19 +491,47 @@ func TestLauncherDoesNotExist(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	expConfigMap := newConfigMap(mpiJob, 7, 8)
+	expConfigMap := newConfigMap(mpiJob, 8, 8)
	f.expectCreateConfigMapAction(expConfigMap)

	expServiceAccount := newLauncherServiceAccount(mpiJob)
	f.expectCreateServiceAccountAction(expServiceAccount)

-	expRole := newLauncherRole(mpiJob, 7)
+	expRole := newLauncherRole(mpiJob, 8)
	f.expectCreateRoleAction(expRole)

	expRoleBinding := newLauncherRoleBinding(mpiJob)
	f.expectCreateRoleBindingAction(expRoleBinding)

-	expWorker := newWorker(mpiJob, 7, 8)
+	expWorker := newWorker(mpiJob, 8, 8)
+	f.expectCreateStatefulSetAction(expWorker)
+
+	mpiJobCopy := mpiJob.DeepCopy()
+	mpiJobCopy.Status.WorkerReplicas = 0
+	f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+	f.run(getKey(mpiJob, t))
+}
+
+func TestLauncherDoesNotExistWithCustomResources(t *testing.T) {
+	f := newFixture(t)
+
+	mpiJob := newMPIJobWithCustomResources("test", int32Ptr(4), 4)
+	f.setUpMPIJob(mpiJob)
+
+	expConfigMap := newConfigMap(mpiJob, 4, 4)
+	f.expectCreateConfigMapAction(expConfigMap)
+
+	expServiceAccount := newLauncherServiceAccount(mpiJob)
+	f.expectCreateServiceAccountAction(expServiceAccount)
+
+	expRole := newLauncherRole(mpiJob, 4)
+	f.expectCreateRoleAction(expRole)
+
+	expRoleBinding := newLauncherRoleBinding(mpiJob)
+	f.expectCreateRoleBindingAction(expRoleBinding)
+
+	expWorker := newWorker(mpiJob, 4, 4)
	f.expectCreateStatefulSetAction(expWorker)

	mpiJobCopy := mpiJob.DeepCopy()
@@ -490,7 +547,7 @@ func TestConfigMapNotControlledByUs(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	configMap := newConfigMap(mpiJob, 7, 8)
+	configMap := newConfigMap(mpiJob, 8, 8)
	configMap.OwnerReferences = nil
	f.setUpConfigMap(configMap)

@@ -503,7 +560,7 @@ func TestServiceAccountNotControlledByUs(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	f.setUpConfigMap(newConfigMap(mpiJob, 7, 8))
+	f.setUpConfigMap(newConfigMap(mpiJob, 8, 8))

	serviceAccount := newLauncherServiceAccount(mpiJob)
	serviceAccount.OwnerReferences = nil
@@ -518,10 +575,10 @@ func TestRoleNotControlledByUs(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	f.setUpConfigMap(newConfigMap(mpiJob, 7, 8))
+	f.setUpConfigMap(newConfigMap(mpiJob, 8, 8))
	f.setUpServiceAccount(newLauncherServiceAccount(mpiJob))

-	role := newLauncherRole(mpiJob, 7)
+	role := newLauncherRole(mpiJob, 8)
	role.OwnerReferences = nil
	f.setUpRole(role)

@@ -534,9 +591,9 @@ func TestRoleBindingNotControlledByUs(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	f.setUpConfigMap(newConfigMap(mpiJob, 7, 8))
+	f.setUpConfigMap(newConfigMap(mpiJob, 8, 8))
	f.setUpServiceAccount(newLauncherServiceAccount(mpiJob))
-	f.setUpRole(newLauncherRole(mpiJob, 7))
+	f.setUpRole(newLauncherRole(mpiJob, 8))

	roleBinding := newLauncherRoleBinding(mpiJob)
	roleBinding.OwnerReferences = nil
@@ -551,11 +608,11 @@ func TestShutdownWorker(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	launcher := newLauncher(mpiJob, 64, "kubectl-delivery")
+	launcher := newLauncher(mpiJob, "kubectl-delivery")
	launcher.Status.Succeeded = 1
	f.setUpLauncher(launcher)

-	worker := newWorker(mpiJob, 7, 8)
+	worker := newWorker(mpiJob, 8, 8)
	f.setUpWorker(worker)

	expWorker := newWorker(mpiJob, 0, 8)
@@ -575,46 +632,32 @@ func TestWorkerNotControlledByUs(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(64))
	f.setUpMPIJob(mpiJob)

-	f.setUpConfigMap(newConfigMap(mpiJob, 7, 8))
+	f.setUpConfigMap(newConfigMap(mpiJob, 8, 8))
-	f.setUpRbac(mpiJob, 7)
+	f.setUpRbac(mpiJob, 8)

-	worker := newWorker(mpiJob, 7, 8)
+	worker := newWorker(mpiJob, 8, 8)
	worker.OwnerReferences = nil
	f.setUpWorker(worker)

	f.runExpectError(getKey(mpiJob, t))
}

-func TestWorkerNotNeeded(t *testing.T) {
-	f := newFixture(t)
-
-	mpiJob := newMPIJob("test", int32Ptr(8))
-	f.setUpMPIJob(mpiJob)
-
-	f.setUpConfigMap(newConfigMap(mpiJob, 0, 8))
-	f.setUpRbac(mpiJob, 0)
-
-	expLauncher := newLauncher(mpiJob, 8, "kubectl-delivery")
-	f.expectCreateJobAction(expLauncher)
-
-	f.expectUpdateMPIJobStatusAction(mpiJob)
-
-	f.run(getKey(mpiJob, t))
-}
-
func TestLauncherActive(t *testing.T) {
	f := newFixture(t)

	mpiJob := newMPIJob("test", int32Ptr(8))
	f.setUpMPIJob(mpiJob)

-	f.setUpConfigMap(newConfigMap(mpiJob, 0, 8))
+	f.setUpConfigMap(newConfigMap(mpiJob, 1, 8))
-	f.setUpRbac(mpiJob, 0)
+	f.setUpRbac(mpiJob, 1)

-	launcher := newLauncher(mpiJob, 64, "kubectl-delivery")
+	launcher := newLauncher(mpiJob, "kubectl-delivery")
	launcher.Status.Active = 1
	f.setUpLauncher(launcher)

+	worker := newWorker(mpiJob, 1, 8)
+	f.setUpWorker(worker)
+
	mpiJobCopy := mpiJob.DeepCopy()
	mpiJobCopy.Status.LauncherStatus = kubeflow.LauncherActive
	f.expectUpdateMPIJobStatusAction(mpiJobCopy)
@@ -628,18 +671,18 @@ func TestWorkerReady(t *testing.T) {
	mpiJob := newMPIJob("test", int32Ptr(16))
	f.setUpMPIJob(mpiJob)

-	f.setUpConfigMap(newConfigMap(mpiJob, 1, 8))
+	f.setUpConfigMap(newConfigMap(mpiJob, 2, 8))
-	f.setUpRbac(mpiJob, 1)
+	f.setUpRbac(mpiJob, 2)

-	worker := newWorker(mpiJob, 1, 8)
+	worker := newWorker(mpiJob, 2, 8)
-	worker.Status.ReadyReplicas = 1
+	worker.Status.ReadyReplicas = 2
	f.setUpWorker(worker)

-	expLauncher := newLauncher(mpiJob, 8, "kubectl-delivery")
+	expLauncher := newLauncher(mpiJob, "kubectl-delivery")
	f.expectCreateJobAction(expLauncher)

	mpiJobCopy := mpiJob.DeepCopy()
-	mpiJobCopy.Status.WorkerReplicas = 1
+	mpiJobCopy.Status.WorkerReplicas = 2
	f.expectUpdateMPIJobStatusAction(mpiJobCopy)

	f.run(getKey(mpiJob, t))