// Source: trainer/pkg/runtime/core/trainingruntime.go
//
// Original listing: 228 lines
// 8.4 KiB
// Go

/*
Copyright 2024 The Kubeflow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package core
import (
"context"
"errors"
"fmt"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"
trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/pkg/apply"
"github.com/kubeflow/trainer/pkg/constants"
"github.com/kubeflow/trainer/pkg/runtime"
fwkcore "github.com/kubeflow/trainer/pkg/runtime/framework/core"
fwkplugins "github.com/kubeflow/trainer/pkg/runtime/framework/plugins"
idxer "github.com/kubeflow/trainer/pkg/runtime/indexer"
)
// errorNotFoundSpecifiedTrainingRuntime is wrapped around the client error
// returned when the TrainingRuntime referenced by a TrainJob cannot be fetched.
var errorNotFoundSpecifiedTrainingRuntime = errors.New("TrainingRuntime specified in TrainJob is not found")
// TrainingRuntime implements runtime.Runtime for namespaced TrainingRuntime
// objects referenced from a TrainJob's .spec.runtimeRef.
type TrainingRuntime struct {
	// framework runs the registered runtime plugins (ML policy, pod group policy,
	// pod network, component builders, custom validation, terminal condition,
	// and watch extensions).
	framework *fwkcore.Framework
	// client is used to look up the TrainingRuntime referenced by a TrainJob.
	client client.Client
}
var (
	// TrainingRuntimeGroupKind is the string form of the TrainingRuntime
	// GroupKind in the trainer API group.
	TrainingRuntimeGroupKind = schema.GroupKind{
		Group: trainer.GroupVersion.Group,
		Kind:  trainer.TrainingRuntimeKind,
	}.String()

	// trainingRuntimeFactory holds the instance most recently created by
	// NewTrainingRuntime.
	trainingRuntimeFactory *TrainingRuntime
)

// Compile-time assertion that *TrainingRuntime satisfies runtime.Runtime.
var _ runtime.Runtime = (*TrainingRuntime)(nil)
// NewTrainingRuntime registers the TrainJob field indexes for TrainingRuntime
// and ClusterTrainingRuntime references, builds the plugin framework from the
// default plugin registry, and returns a TrainingRuntime backed by client c.
// The new instance is also stored in the package-level trainingRuntimeFactory.
func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) {
	err := indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime)
	if err != nil {
		return nil, fmt.Errorf("setting index on TrainingRuntime for TrainJob: %w", err)
	}
	err = indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobClusterRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime)
	if err != nil {
		return nil, fmt.Errorf("setting index on ClusterTrainingRuntime for TrainJob: %w", err)
	}

	registry := fwkplugins.NewRegistry()
	framework, err := fwkcore.New(ctx, c, registry, indexer)
	if err != nil {
		return nil, err
	}

	trainingRuntimeFactory = &TrainingRuntime{client: c, framework: framework}
	return trainingRuntimeFactory, nil
}
// NewObjects fetches the TrainingRuntime named by trainJob.Spec.RuntimeRef from
// the TrainJob's namespace and builds the TrainJob's objects from that runtime's
// JobSet template, ML policy and pod group policy. Any lookup failure is
// returned wrapped with errorNotFoundSpecifiedTrainingRuntime.
func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) {
	key := client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}
	runtimeObj := &trainer.TrainingRuntime{}
	if err := r.client.Get(ctx, key, runtimeObj); err != nil {
		return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedTrainingRuntime, err)
	}
	spec := runtimeObj.Spec
	return r.buildObjects(ctx, trainJob, spec.Template, spec.MLPolicy, spec.PodGroupPolicy)
}
// buildObjects turns a runtime template into the concrete objects for trainJob.
// It derives the runtime.Info, then runs the ML-policy, pod-group-policy and
// pod-network plugins on it in that order, and finally lets the component
// builder plugins render the objects. The first error aborts the build.
func (r *TrainingRuntime) buildObjects(
	ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy,
) ([]any, error) {
	info, err := r.newRuntimeInfo(trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)
	if err != nil {
		return nil, err
	}

	// Order matters: each stage may mutate info before the next one runs.
	stages := []func(*runtime.Info, *trainer.TrainJob) error{
		r.framework.RunEnforceMLPolicyPlugins,
		r.framework.RunEnforcePodGroupPolicyPlugins,
		r.framework.RunPodNetworkPlugins,
	}
	for _, stage := range stages {
		if stageErr := stage(info, trainJob); stageErr != nil {
			return nil, stageErr
		}
	}

	return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
}
// newRuntimeInfo assembles the runtime.Info for trainJob from a runtime's JobSet
// template and policies.
//
// Labels and annotations from the template are merged with the TrainJob's
// .spec.labels / .spec.annotations; on key conflicts the TrainJob value wins.
// The merged maps are always fresh copies, so the caller's template (and the
// runtime object it came from) is never mutated.
//
// Each replicated Job becomes one PodSet. Its count is the Job's parallelism
// (default 1). When the Job's template carries the trainer ancestor label and an
// MLPolicy is present, the count is mlPolicy.NumNodes (default 1) instead.
func (r *TrainingRuntime) newRuntimeInfo(
	trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy,
) (*runtime.Info, error) {
	// mergeMetadata returns base overlaid with overrides as a new map. The
	// result is nil only when both inputs are nil, so absent metadata stays absent.
	mergeMetadata := func(base, overrides map[string]string) map[string]string {
		if base == nil && overrides == nil {
			return nil
		}
		merged := make(map[string]string, len(base)+len(overrides))
		for k, v := range base {
			merged[k] = v
		}
		// The JobSetTemplateSpec metadata is overridden by the TrainJob metadata
		// (.spec.labels / .spec.annotations).
		for k, v := range overrides {
			merged[k] = v
		}
		return merged
	}
	propagationLabels := mergeMetadata(jobSetTemplateSpec.Labels, trainJob.Spec.Labels)
	propagationAnnotations := mergeMetadata(jobSetTemplateSpec.Annotations, trainJob.Spec.Annotations)

	jobSetSpecApply, err := apply.FromTypedObjWithFields[jobsetv1alpha2ac.JobSetSpecApplyConfiguration](&jobsetv1alpha2.JobSet{
		TypeMeta: metav1.TypeMeta{
			APIVersion: jobsetv1alpha2.GroupVersion.String(),
			Kind:       "JobSet",
		},
		Spec: jobSetTemplateSpec.Spec,
	}, "spec")
	if err != nil {
		return nil, err
	}

	opts := []runtime.InfoOption{
		runtime.WithLabels(propagationLabels),
		runtime.WithAnnotations(propagationAnnotations),
		runtime.WithMLPolicySource(mlPolicy),
		runtime.WithPodGroupPolicy(podGroupPolicy),
		runtime.WithTemplateSpecObjApply(jobSetSpecApply),
		runtime.WithPodSetSyncer(syncPodSets),
	}

	for i, rJob := range jobSetSpecApply.ReplicatedJobs {
		// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
		// REF: https://github.com/kubeflow/trainer/issues/2318
		count := ptr.Deref(rJob.Template.Spec.Parallelism, 1)
		var ancestor *string
		if metadata := rJob.Template.ObjectMetaApplyConfiguration; metadata != nil && metadata.Labels != nil {
			if labelAncestor, ok := metadata.Labels[constants.LabelTrainJobAncestor]; ok {
				// The trainer node's PodSet size comes from the MLPolicy rather
				// than the template's parallelism.
				if labelAncestor == constants.AncestorTrainer && mlPolicy != nil {
					count = ptr.Deref(mlPolicy.NumNodes, 1)
				}
				ancestor = &labelAncestor
			}
		}
		opts = append(opts, runtime.WithPodSet(
			*rJob.Name,
			ancestor,
			count,
			// Deep copy so the PodSet does not share slices/maps with the template.
			*jobSetTemplateSpec.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.DeepCopy(),
			rJob.Template.Spec.Template.Spec),
		)
	}

	return runtime.NewInfo(opts...), nil
}
// syncPodSets writes the per-PodSet state held in info.TemplateSpec.PodSets back
// into the JobSet spec apply configuration stored in info. PodSets are matched
// to replicated Jobs by index, and each PodSet's containers are matched to the
// pod template's containers by index. For every PodSet it:
//   - sets the Job's parallelism and completions to the PodSet count, if set;
//   - upserts the PodSet's volumes into the pod template;
//   - upserts each container's env vars, ports and volume mounts.
//
// It does nothing when info carries no JobSet spec apply configuration.
func syncPodSets(info *runtime.Info) {
	jsSpec, ok := runtime.TemplateSpecApply[jobsetv1alpha2ac.JobSetSpecApplyConfiguration](info)
	if !ok {
		return
	}
	for i := range info.TemplateSpec.PodSets {
		podSet := info.TemplateSpec.PodSets[i]
		jobSpec := jsSpec.ReplicatedJobs[i].Template.Spec
		if count := podSet.Count; count != nil {
			jobSpec.Parallelism = count
			jobSpec.Completions = count
		}

		podSpec := jobSpec.Template.Spec
		apply.UpsertVolumes(&podSpec.Volumes, podSet.Volumes...)
		for j := range podSet.Containers {
			src := podSet.Containers[j]
			dst := &podSpec.Containers[j]
			apply.UpsertEnvVar(&dst.Env, src.Env...)
			apply.UpsertPort(&dst.Ports, src.Ports...)
			apply.UpsertVolumeMounts(&dst.VolumeMounts, src.VolumeMounts...)
		}
	}
}
// TerminalCondition returns whatever the framework's terminal-condition plugins
// report for trainJob; it delegates entirely to the framework.
func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
	fwk := r.framework
	condition, err := fwk.RunTerminalConditionPlugins(ctx, trainJob)
	return condition, err
}
// EventHandlerRegistrars collects the reconciler builders contributed by every
// watch-extension plugin, in plugin order. The result is nil when no plugin
// contributes a builder.
func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
	var registrars []runtime.ReconcilerBuilder
	plugins := r.framework.WatchExtensionPlugins()
	for _, plugin := range plugins {
		registrars = append(registrars, plugin.ReconcilerBuilders()...)
	}
	return registrars
}
// ValidateObjects validates a TrainJob (new, and old on updates) against the
// TrainingRuntime it references.
//
// The referenced runtime must already exist in the TrainJob's namespace;
// otherwise a field error on spec.runtimeRef is returned. If runtime info cannot
// be derived from the runtime's template, an internal field error is returned
// rather than handing a nil Info to the plugins. Otherwise the result of the
// framework's custom validation plugins is returned.
func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
	runtimeRefPath := field.NewPath("spec", "runtimeRef")
	trainingRuntime := &trainer.TrainingRuntime{}
	if err := r.client.Get(ctx, client.ObjectKey{
		Namespace: new.Namespace,
		Name:      new.Spec.RuntimeRef.Name,
	}, trainingRuntime); err != nil {
		return nil, field.ErrorList{
			field.Invalid(runtimeRefPath, new.Spec.RuntimeRef,
				fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
		}
	}
	info, err := r.newRuntimeInfo(new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy)
	if err != nil {
		// The validation plugins expect a non-nil Info, so report the failure
		// here instead of passing a nil Info to them.
		return nil, field.ErrorList{field.InternalError(runtimeRefPath, err)}
	}
	return r.framework.RunCustomValidationPlugins(info, old, new)
}