/* Copyright 2024 The Kubeflow Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package core import ( "context" "errors" "fmt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1" "github.com/kubeflow/trainer/pkg/apply" "github.com/kubeflow/trainer/pkg/constants" "github.com/kubeflow/trainer/pkg/runtime" fwkcore "github.com/kubeflow/trainer/pkg/runtime/framework/core" fwkplugins "github.com/kubeflow/trainer/pkg/runtime/framework/plugins" idxer "github.com/kubeflow/trainer/pkg/runtime/indexer" ) var ( errorNotFoundSpecifiedTrainingRuntime = errors.New("TrainingRuntime specified in TrainJob is not found") ) type TrainingRuntime struct { framework *fwkcore.Framework client client.Client } var TrainingRuntimeGroupKind = schema.GroupKind{ Group: trainer.GroupVersion.Group, Kind: trainer.TrainingRuntimeKind, }.String() var _ runtime.Runtime = (*TrainingRuntime)(nil) var trainingRuntimeFactory *TrainingRuntime func NewTrainingRuntime(ctx context.Context, c client.Client, indexer client.FieldIndexer) (runtime.Runtime, error) { if err := indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobRuntimeRefKey, idxer.IndexTrainJobTrainingRuntime); err != nil { return nil, fmt.Errorf("setting index on TrainingRuntime for TrainJob: %w", err) } if err := indexer.IndexField(ctx, &trainer.TrainJob{}, idxer.TrainJobClusterRuntimeRefKey, idxer.IndexTrainJobClusterTrainingRuntime); err != nil { return nil, fmt.Errorf("setting index on ClusterTrainingRuntime for TrainJob: %w", err) } fwk, err := fwkcore.New(ctx, c, fwkplugins.NewRegistry(), indexer) if err != nil { return nil, err } trainingRuntimeFactory = &TrainingRuntime{ framework: fwk, client: c, } return trainingRuntimeFactory, nil } func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) { var trainingRuntime trainer.TrainingRuntime err := r.client.Get(ctx, client.ObjectKey{Namespace: trainJob.Namespace, Name: trainJob.Spec.RuntimeRef.Name}, &trainingRuntime) if err != nil { return nil, fmt.Errorf("%w: %w", errorNotFoundSpecifiedTrainingRuntime, err) } return r.buildObjects(ctx, trainJob, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) } func (r *TrainingRuntime) buildObjects( ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy, ) ([]any, error) { info, err := r.newRuntimeInfo(trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy) if err != nil { return nil, err } if err = r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil { return nil, err } if err = r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil { return nil, err } if err = r.framework.RunPodNetworkPlugins(info, trainJob); err != nil { return nil, err } return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob) } func (r *TrainingRuntime) newRuntimeInfo( trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy, ) (*runtime.Info, error) { propagationLabels := jobSetTemplateSpec.Labels if propagationLabels == nil && trainJob.Spec.Labels != nil { propagationLabels = make(map[string]string, len(trainJob.Spec.Labels)) } for k, v := range trainJob.Spec.Labels { // The JobSetTemplateSpec labels are overridden by the TrainJob Labels (.spec.labels). propagationLabels[k] = v } propagationAnnotations := jobSetTemplateSpec.Annotations if propagationAnnotations == nil && trainJob.Spec.Annotations != nil { propagationAnnotations = make(map[string]string, len(trainJob.Spec.Annotations)) } for k, v := range trainJob.Spec.Annotations { // The JobSetTemplateSpec annotations are overridden by the TrainJob Annotations (.spec.annotations). propagationAnnotations[k] = v } jobSetSpecApply, err := apply.FromTypedObjWithFields[jobsetv1alpha2ac.JobSetSpecApplyConfiguration](&jobsetv1alpha2.JobSet{ TypeMeta: metav1.TypeMeta{ APIVersion: jobsetv1alpha2.GroupVersion.String(), Kind: "JobSet", }, Spec: jobSetTemplateSpec.Spec, }, "spec") if err != nil { return nil, err } opts := []runtime.InfoOption{ runtime.WithLabels(propagationLabels), runtime.WithAnnotations(propagationAnnotations), runtime.WithMLPolicySource(mlPolicy), runtime.WithPodGroupPolicy(podGroupPolicy), runtime.WithTemplateSpecObjApply(jobSetSpecApply), runtime.WithPodSetSyncer(syncPodSets), } for i, rJob := range jobSetSpecApply.ReplicatedJobs { // TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs. // REF: https://github.com/kubeflow/trainer/issues/2318 count := ptr.Deref(rJob.Template.Spec.Parallelism, 1) var ancestor *string if metadata := rJob.Template.ObjectMetaApplyConfiguration; metadata != nil && metadata.Labels != nil { if labelAncestor, ok := metadata.Labels[constants.LabelTrainJobAncestor]; ok { if labelAncestor == constants.AncestorTrainer && mlPolicy != nil { count = ptr.Deref(mlPolicy.NumNodes, 1) } ancestor = &labelAncestor } } opts = append(opts, runtime.WithPodSet( *rJob.Name, ancestor, count, *jobSetTemplateSpec.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.DeepCopy(), rJob.Template.Spec.Template.Spec), ) } return runtime.NewInfo(opts...), nil } func syncPodSets(info *runtime.Info) { jsSpec, ok := runtime.TemplateSpecApply[jobsetv1alpha2ac.JobSetSpecApplyConfiguration](info) if !ok { return } for psIdx, ps := range info.TemplateSpec.PodSets { if ps.Count != nil { jsSpec.ReplicatedJobs[psIdx].Template.Spec.Parallelism = ps.Count jsSpec.ReplicatedJobs[psIdx].Template.Spec.Completions = ps.Count } apply.UpsertVolumes(&jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Volumes, ps.Volumes...) for containerIdx, container := range ps.Containers { apply.UpsertEnvVar( &jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[containerIdx].Env, container.Env..., ) apply.UpsertPort( &jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[containerIdx].Ports, container.Ports..., ) apply.UpsertVolumeMounts( &jsSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Containers[containerIdx].VolumeMounts, container.VolumeMounts..., ) } } } func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { return r.framework.RunTerminalConditionPlugins(ctx, trainJob) } func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { var builders []runtime.ReconcilerBuilder for _, ex := range r.framework.WatchExtensionPlugins() { builders = append(builders, ex.ReconcilerBuilders()...) } return builders } func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) { trainingRuntime := &trainer.TrainingRuntime{} if err := r.client.Get(ctx, client.ObjectKey{ Namespace: new.Namespace, Name: new.Spec.RuntimeRef.Name, }, trainingRuntime); err != nil { return nil, field.ErrorList{ field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef, fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)), } } info, _ := r.newRuntimeInfo(new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) // ignoring the error here as the runtime configured should be valid return r.framework.RunCustomValidationPlugins(info, old, new) }