package scheduling

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/docker/model-distribution/distribution"
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/memory"
	"github.com/docker/model-runner/pkg/inference/models"
	"github.com/docker/model-runner/pkg/logging"
	"github.com/docker/model-runner/pkg/metrics"
	"github.com/mattn/go-shellwords"
	"golang.org/x/sync/errgroup"
)

// Scheduler is used to coordinate inference scheduling across multiple backends
// and models.
type Scheduler struct {
	// log is the associated logger.
	log logging.Logger
	// backends are the supported inference backends.
	backends map[string]inference.Backend
	// defaultBackend is the default inference backend. It may be nil.
	defaultBackend inference.Backend
	// modelManager is the shared model manager.
	modelManager *models.Manager
	// installer is the backend installer.
	installer *installer
	// loader is the backend loader.
	loader *loader
	// router is the HTTP request router.
	router *http.ServeMux
	// tracker is the metrics tracker.
	tracker *metrics.Tracker
	// openAIRecorder is used to record OpenAI API inference requests and responses.
	openAIRecorder *metrics.OpenAIRecorder
	// lock is used to synchronize access to the scheduler's router.
	lock sync.RWMutex
}

// NewScheduler creates a new inference scheduler.
func NewScheduler(
	log logging.Logger,
	backends map[string]inference.Backend,
	defaultBackend inference.Backend,
	modelManager *models.Manager,
	httpClient *http.Client,
	allowedOrigins []string,
	tracker *metrics.Tracker,
	sysMemInfo memory.SystemMemoryInfo,
) *Scheduler {
	openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)

	// Create the scheduler.
	s := &Scheduler{
		log:            log,
		backends:       backends,
		defaultBackend: defaultBackend,
		modelManager:   modelManager,
		installer:      newInstaller(log, backends, httpClient),
		loader:         newLoader(log, backends, modelManager, openAIRecorder, sysMemInfo),
		router:         http.NewServeMux(),
		tracker:        tracker,
		openAIRecorder: openAIRecorder,
	}

	// Register routes.
	s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "not found", http.StatusNotFound)
	})
	for route, handler := range s.routeHandlers(allowedOrigins) {
		s.router.HandleFunc(route, handler)
	}

	// Scheduler successfully initialized.
	return s
}

// RebuildRoutes replaces the scheduler's router and re-registers all routes
// using the provided allowed CORS origins.
func (s *Scheduler) RebuildRoutes(allowedOrigins []string) {
	s.lock.Lock()
	defer s.lock.Unlock()

	// Clear existing routes and re-register them.
	s.router = http.NewServeMux()

	// Register routes.
	s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "not found", http.StatusNotFound)
	})
	for route, handler := range s.routeHandlers(allowedOrigins) {
		s.router.HandleFunc(route, handler)
	}
}

// routeHandlers returns the scheduler's route patterns mapped to their handlers.
func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
	openAIRoutes := []string{
		"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
		"POST " + inference.InferencePrefix + "/{backend}/v1/completions",
		"POST " + inference.InferencePrefix + "/{backend}/v1/embeddings",
		"POST " + inference.InferencePrefix + "/v1/chat/completions",
		"POST " + inference.InferencePrefix + "/v1/completions",
		"POST " + inference.InferencePrefix + "/v1/embeddings",
	}
	m := make(map[string]http.HandlerFunc)
	for _, route := range openAIRoutes {
		m[route] = inference.CorsMiddleware(allowedOrigins, http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
		// Register OPTIONS for CORS preflight.
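		// Note: the CORS middleware wrapping the handler is expected to answer
		// the OPTIONS preflight itself before the inference handler runs; the
		// preflight pattern below simply reuses each POST route's path.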
		optionsRoute := "OPTIONS " + route[strings.Index(route, " "):]
		m[optionsRoute] = inference.CorsMiddleware(allowedOrigins, http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
	}
	m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
	m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
	m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
	m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
	m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
	m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsByModelHandler()
	return m
}

// GetRoutes returns the route patterns registered by the scheduler.
func (s *Scheduler) GetRoutes() []string {
	routeHandlers := s.routeHandlers(nil)
	routes := make([]string, 0, len(routeHandlers))
	for route := range routeHandlers {
		routes = append(routes, route)
	}
	return routes
}

// Run is the scheduler's main run loop. By the time it returns, all inference
// backends will have been unloaded from memory.
func (s *Scheduler) Run(ctx context.Context) error {
	// Create an error group to track worker Goroutines.
	workers, workerCtx := errgroup.WithContext(ctx)

	// Start the installer.
	workers.Go(func() error {
		s.installer.run(workerCtx)
		return nil
	})

	// Start the loader.
	workers.Go(func() error {
		s.loader.run(workerCtx)
		return nil
	})

	// Wait for all workers to exit.
	return workers.Wait()
}

// handleOpenAIInference handles scheduling and responding to OpenAI inference
// requests, including:
// - POST /{backend}/v1/chat/completions
// - POST /{backend}/v1/completions
// - POST /{backend}/v1/embeddings
func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request) {
	// Determine the requested backend and ensure that it's valid.
	var backend inference.Backend
	if b := r.PathValue("backend"); b == "" {
		backend = s.defaultBackend
	} else {
		backend = s.backends[b]
	}
	if backend == nil {
		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
		return
	}

	// Read the entire request body. We put some basic size constraints in place
	// to avoid DoS attacks. We do this early to avoid client write timeouts.
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
	if err != nil {
		if _, ok := err.(*http.MaxBytesError); ok {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "unknown error", http.StatusInternalServerError)
		}
		return
	}

	// Wait for the corresponding backend installation to complete or fail. We
	// don't allow any requests to be scheduled for a backend until it has
	// completed installation.
	if err := s.installer.wait(r.Context(), backend.Name()); err != nil {
		if errors.Is(err, ErrBackendNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
		} else if errors.Is(err, errInstallerNotStarted) {
			http.Error(w, err.Error(), http.StatusServiceUnavailable)
		} else if errors.Is(err, context.Canceled) {
			// This could be due to the client aborting the request (in which
			// case this response will be ignored) or the inference service
			// shutting down (since that will also cancel the request context).
			// Either way, provide a response, even if it's ignored.
			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
		} else {
			http.Error(w, fmt.Errorf("backend installation failed: %w", err).Error(), http.StatusServiceUnavailable)
		}
		return
	}

	// Determine the backend operation mode.
	backendMode, ok := backendModeForRequest(r.URL.Path)
	if !ok {
		http.Error(w, "unknown request path", http.StatusInternalServerError)
		return
	}

	// Decode the model specification portion of the request body.
	var request OpenAIInferenceRequest
	if err := json.Unmarshal(body, &request); err != nil {
		http.Error(w, "invalid request", http.StatusBadRequest)
		return
	}
	if request.Model == "" {
		http.Error(w, "model is required", http.StatusBadRequest)
		return
	}

	// Check if the shared model manager has the requested model available.
	if !backend.UsesExternalModelManagement() {
		model, err := s.modelManager.GetModel(request.Model)
		if err != nil {
			if errors.Is(err, distribution.ErrModelNotFound) {
				http.Error(w, err.Error(), http.StatusNotFound)
			} else {
				http.Error(w, "model unavailable", http.StatusInternalServerError)
			}
			return
		}
		// Non-blocking call to track the model usage.
		s.tracker.TrackModel(model, r.UserAgent())
	}
	modelID := s.modelManager.ResolveModelID(request.Model)

	// Request a runner to execute the request and defer its release.
	runner, err := s.loader.load(r.Context(), backend.Name(), modelID, request.Model, backendMode)
	if err != nil {
		http.Error(w, fmt.Errorf("unable to load runner: %w", err).Error(), http.StatusInternalServerError)
		return
	}
	defer s.loader.release(runner)

	// Record the request in the OpenAI recorder.
	recordID := s.openAIRecorder.RecordRequest(request.Model, r, body)
	w = s.openAIRecorder.NewResponseRecorder(w)
	defer func() {
		// Record the response in the OpenAI recorder.
		s.openAIRecorder.RecordResponse(recordID, request.Model, w)
	}()

	// Create a request with the body replaced for forwarding upstream.
	upstreamRequest := r.Clone(r.Context())
	upstreamRequest.Body = io.NopCloser(bytes.NewReader(body))

	// Perform the request.
	runner.ServeHTTP(w, upstreamRequest)
}

// GetBackendStatus returns the status of each registered inference backend.
func (s *Scheduler) GetBackendStatus(w http.ResponseWriter, r *http.Request) {
	status := make(map[string]string)
	for backendName, backend := range s.backends {
		status[backendName] = backend.Status()
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(status)
}

// ResetInstaller replaces the backend installer with a new one that uses the
// provided HTTP client.
func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
	s.installer = newInstaller(s.log, s.backends, httpClient)
}

// GetRunningBackends returns information about all running backends.
func (s *Scheduler) GetRunningBackends(w http.ResponseWriter, r *http.Request) {
	runningBackends := s.getLoaderStatus(r.Context())
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(runningBackends); err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

// getLoaderStatus returns information about all running backends managed by
// the loader.
func (s *Scheduler) getLoaderStatus(ctx context.Context) []BackendStatus {
	if !s.loader.lock(ctx) {
		return []BackendStatus{}
	}
	defer s.loader.unlock()

	result := make([]BackendStatus, 0, len(s.loader.runners))
	for key, runnerInfo := range s.loader.runners {
		if s.loader.slots[runnerInfo.slot] != nil {
			status := BackendStatus{
				BackendName: key.backend,
				ModelName:   runnerInfo.modelRef,
				Mode:        key.mode.String(),
				LastUsed:    time.Time{},
			}
			if s.loader.references[runnerInfo.slot] == 0 {
				status.LastUsed = s.loader.timestamps[runnerInfo.slot]
			}
			result = append(result, status)
		}
	}
	return result
}

// GetDiskUsage reports the disk usage of stored models and of the default
// backend.
func (s *Scheduler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
	modelsDiskUsage, err, httpCode := s.modelManager.GetDiskUsage()
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get models disk usage: %v", err), httpCode)
		return
	}

	// TODO: Get disk usage for each backend once the backends are implemented.
	defaultBackendDiskUsage, err := s.defaultBackend.GetDiskUsage()
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get disk usage for %s: %v", s.defaultBackend.Name(), err), http.StatusInternalServerError)
		return
	}

	diskUsage := DiskUsage{modelsDiskUsage, defaultBackendDiskUsage}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(diskUsage); err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

// Unload unloads the specified (backend, model) runners from memory.
// Currently, this doesn't work for runners that are handling an OpenAI request.
func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
	if err != nil {
		if _, ok := err.(*http.MaxBytesError); ok {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "unknown error", http.StatusInternalServerError)
		}
		return
	}
	var unloadRequest UnloadRequest
	if err := json.Unmarshal(body, &unloadRequest); err != nil {
		http.Error(w, "invalid request", http.StatusBadRequest)
		return
	}
	unloadedRunners := UnloadResponse{s.loader.Unload(r.Context(), unloadRequest)}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(unloadedRunners); err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

// Configure updates the runner configuration (context size and runtime flags)
// for the specified backend and model.
func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
	// Determine the requested backend and ensure that it's valid.
	var backend inference.Backend
	if b := r.PathValue("backend"); b == "" {
		backend = s.defaultBackend
	} else {
		backend = s.backends[b]
	}
	if backend == nil {
		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
		return
	}

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
	if err != nil {
		if _, ok := err.(*http.MaxBytesError); ok {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "unknown error", http.StatusInternalServerError)
		}
		return
	}

	configureRequest := ConfigureRequest{
		ContextSize: -1,
	}
	if err := json.Unmarshal(body, &configureRequest); err != nil {
		http.Error(w, "invalid request", http.StatusBadRequest)
		return
	}

	var runtimeFlags []string
	if len(configureRequest.RuntimeFlags) > 0 {
		runtimeFlags = configureRequest.RuntimeFlags
	} else {
		rawFlags, err := shellwords.Parse(configureRequest.RawRuntimeFlags)
		if err != nil {
			http.Error(w, "invalid request", http.StatusBadRequest)
			return
		}
		runtimeFlags = rawFlags
	}

	var runnerConfig inference.BackendConfiguration
	runnerConfig.ContextSize = configureRequest.ContextSize
	runnerConfig.RuntimeFlags = runtimeFlags

	if model, err := s.modelManager.GetModel(configureRequest.Model); err == nil {
		// Configure is called by compose for each model.
		s.tracker.TrackModel(model, r.UserAgent())
	}

	modelID := s.modelManager.ResolveModelID(configureRequest.Model)
	if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), modelID, inference.BackendModeCompletion, runnerConfig); err != nil {
		s.log.Warnf("Failed to configure %s runner for %s (%s): %s", backend.Name(), configureRequest.Model, modelID, err)
		if errors.Is(err, errRunnerAlreadyActive) {
			http.Error(w, err.Error(), http.StatusConflict)
		} else {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
		return
	}
	w.WriteHeader(http.StatusAccepted)
}

// GetAllActiveRunners returns information about all active runners.
func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner {
	runningBackends := s.getLoaderStatus(context.Background())

	var activeRunners []metrics.ActiveRunner
	if !s.loader.lock(context.Background()) {
		return activeRunners
	}
	defer s.loader.unlock()

	for _, backend := range runningBackends {
		// Find the runner slot for this backend/model combination.
		key := runnerKey{
			backend: backend.BackendName,
			modelID: backend.ModelName,
			mode:    parseBackendMode(backend.Mode),
		}
		if runnerInfo, exists := s.loader.runners[key]; exists {
			socket, err := RunnerSocketPath(runnerInfo.slot)
			if err != nil {
				s.log.Warnf("Failed to get socket path for runner %s/%s (%s): %v", backend.BackendName, backend.ModelName, key.modelID, err)
				continue
			}
			activeRunners = append(activeRunners, metrics.ActiveRunner{
				BackendName: backend.BackendName,
				ModelName:   backend.ModelName,
				Mode:        backend.Mode,
				Socket:      socket,
			})
		}
	}
	return activeRunners
}

// GetLlamaCppSocket returns the Unix socket path for an active llama.cpp runner.
func (s *Scheduler) GetLlamaCppSocket() (string, error) {
	runningBackends := s.getLoaderStatus(context.Background())

	if !s.loader.lock(context.Background()) {
		return "", errors.New("failed to acquire loader lock")
	}
	defer s.loader.unlock()

	// Look for an active llama.cpp backend.
	for _, backend := range runningBackends {
		if backend.BackendName == "llama.cpp" {
			// Find the runner slot for this backend/model combination.
			key := runnerKey{
				backend: backend.BackendName,
				modelID: backend.ModelName,
				mode:    parseBackendMode(backend.Mode),
			}
			if runnerInfo, exists := s.loader.runners[key]; exists {
				// Use the RunnerSocketPath function to get the socket path.
				return RunnerSocketPath(runnerInfo.slot)
			}
		}
	}
	return "", errors.New("no active llama.cpp backend found")
}

// parseBackendMode converts a string mode to a BackendMode.
func parseBackendMode(mode string) inference.BackendMode {
	switch mode {
	case "completion":
		return inference.BackendModeCompletion
	case "embedding":
		return inference.BackendModeEmbedding
	default:
		return inference.BackendModeCompletion
	}
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	s.router.ServeHTTP(w, r)
}
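
// Usage sketch (illustrative only, not part of this package's API): the
// scheduler implements http.Handler and runs its installer and loader workers
// from Run, so a caller could wire it up roughly as follows. The log, backends,
// defaultBackend, modelManager, tracker, and sysMemInfo values are assumed to
// be constructed elsewhere, and the listen address is arbitrary.
//
//	sched := scheduling.NewScheduler(log, backends, defaultBackend, modelManager,
//		http.DefaultClient, nil, tracker, sysMemInfo)
//	go func() { _ = sched.Run(ctx) }()
//	server := &http.Server{Addr: "localhost:8080", Handler: sched}
//	_ = server.ListenAndServe()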