model-runner/pkg/inference/scheduling/scheduler.go

package scheduling

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/docker/model-distribution/distribution"
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/memory"
	"github.com/docker/model-runner/pkg/inference/models"
	"github.com/docker/model-runner/pkg/logging"
	"github.com/docker/model-runner/pkg/metrics"
	"github.com/mattn/go-shellwords"
	"golang.org/x/sync/errgroup"
)

// Scheduler is used to coordinate inference scheduling across multiple backends
// and models.
type Scheduler struct {
	// log is the associated logger.
	log logging.Logger
	// backends are the supported inference backends.
	backends map[string]inference.Backend
	// defaultBackend is the default inference backend. It may be nil.
	defaultBackend inference.Backend
	// modelManager is the shared model manager.
	modelManager *models.Manager
	// installer is the backend installer.
	installer *installer
	// loader is the backend loader.
	loader *loader
	// router is the HTTP request router.
	router *http.ServeMux
	// tracker is the metrics tracker.
	tracker *metrics.Tracker
	// openAIRecorder is used to record OpenAI API inference requests and responses.
	openAIRecorder *metrics.OpenAIRecorder
	// lock is used to synchronize access to the scheduler's router.
	lock sync.RWMutex
}

// NewScheduler creates a new inference scheduler.
func NewScheduler(
	log logging.Logger,
	backends map[string]inference.Backend,
	defaultBackend inference.Backend,
	modelManager *models.Manager,
	httpClient *http.Client,
	allowedOrigins []string,
	tracker *metrics.Tracker,
	sysMemInfo memory.SystemMemoryInfo,
) *Scheduler {
	openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)
	// Create the scheduler.
	s := &Scheduler{
		log: log,
		backends: backends,
		defaultBackend: defaultBackend,
		modelManager: modelManager,
		installer: newInstaller(log, backends, httpClient),
		loader: newLoader(log, backends, modelManager, openAIRecorder, sysMemInfo),
		router: http.NewServeMux(),
		tracker: tracker,
		openAIRecorder: openAIRecorder,
	}
	// Register routes.
	s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "not found", http.StatusNotFound)
	})
	for route, handler := range s.routeHandlers(allowedOrigins) {
		s.router.HandleFunc(route, handler)
	}
	// Scheduler successfully initialized.
	return s
}

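// RebuildRoutes replaces the scheduler's router and re-registers all routes
// using the provided allowed origins for CORS.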
func (s *Scheduler) RebuildRoutes(allowedOrigins []string) {
	s.lock.Lock()
	defer s.lock.Unlock()
	// Clear existing routes and re-register them.
	s.router = http.NewServeMux()
	// Register routes.
	s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "not found", http.StatusNotFound)
	})
	for route, handler := range s.routeHandlers(allowedOrigins) {
		s.router.HandleFunc(route, handler)
	}
}

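// routeHandlers builds the scheduler's route table, wrapping the OpenAI
// inference endpoints (and their CORS preflight OPTIONS counterparts) with
// CORS middleware for the given allowed origins.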
func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
	openAIRoutes := []string{
		"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
		"POST " + inference.InferencePrefix + "/{backend}/v1/completions",
		"POST " + inference.InferencePrefix + "/{backend}/v1/embeddings",
		"POST " + inference.InferencePrefix + "/v1/chat/completions",
		"POST " + inference.InferencePrefix + "/v1/completions",
		"POST " + inference.InferencePrefix + "/v1/embeddings",
	}
	m := make(map[string]http.HandlerFunc)
	for _, route := range openAIRoutes {
		m[route] = inference.CorsMiddleware(allowedOrigins, http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
		// Register OPTIONS for CORS preflight.
		optionsRoute := "OPTIONS " + route[strings.Index(route, " "):]
		m[optionsRoute] = inference.CorsMiddleware(allowedOrigins, http.HandlerFunc(s.handleOpenAIInference)).ServeHTTP
	}
	m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
	m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
	m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
	m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
	m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
	m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsByModelHandler()
	return m
}

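// GetRoutes returns the route patterns registered by the scheduler.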
func (s *Scheduler) GetRoutes() []string {
	routeHandlers := s.routeHandlers(nil)
	routes := make([]string, 0, len(routeHandlers))
	for route := range routeHandlers {
		routes = append(routes, route)
	}
	return routes
}

// Run is the scheduler's main run loop. By the time it returns, all inference
// backends will have been unloaded from memory.
func (s *Scheduler) Run(ctx context.Context) error {
	// Create an error group to track worker Goroutines.
	workers, workerCtx := errgroup.WithContext(ctx)
	// Start the installer.
	workers.Go(func() error {
		s.installer.run(workerCtx)
		return nil
	})
	// Start the loader.
	workers.Go(func() error {
		s.loader.run(workerCtx)
		return nil
	})
	// Wait for all workers to exit.
	return workers.Wait()
}

// handleOpenAIInference handles scheduling and responding to OpenAI inference
// requests, including:
// - POST <inference-prefix>/{backend}/v1/chat/completions
// - POST <inference-prefix>/{backend}/v1/completions
// - POST <inference-prefix>/{backend}/v1/embeddings
func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request) {
	// Determine the requested backend and ensure that it's valid.
	var backend inference.Backend
	if b := r.PathValue("backend"); b == "" {
		backend = s.defaultBackend
	} else {
		backend = s.backends[b]
	}
	if backend == nil {
		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
		return
	}
	// Read the entire request body. We put some basic size constraints in place
	// to avoid DoS attacks. We do this early to avoid client write timeouts.
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
	if err != nil {
		if _, ok := err.(*http.MaxBytesError); ok {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "unknown error", http.StatusInternalServerError)
		}
		return
	}
	// Wait for the corresponding backend installation to complete or fail. We
	// don't allow any requests to be scheduled for a backend until it has
	// completed installation.
	if err := s.installer.wait(r.Context(), backend.Name()); err != nil {
		if errors.Is(err, ErrBackendNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
		} else if errors.Is(err, errInstallerNotStarted) {
			http.Error(w, err.Error(), http.StatusServiceUnavailable)
		} else if errors.Is(err, context.Canceled) {
			// This could be due to the client aborting the request (in which
			// case this response will be ignored) or the inference service
			// shutting down (since that will also cancel the request context).
			// Either way, provide a response, even if it's ignored.
			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
		} else {
			http.Error(w, fmt.Errorf("backend installation failed: %w", err).Error(), http.StatusServiceUnavailable)
		}
		return
	}
	// Determine the backend operation mode.
	backendMode, ok := backendModeForRequest(r.URL.Path)
	if !ok {
		http.Error(w, "unknown request path", http.StatusInternalServerError)
		return
	}
	// Decode the model specification portion of the request body.
	var request OpenAIInferenceRequest
	if err := json.Unmarshal(body, &request); err != nil {
		http.Error(w, "invalid request", http.StatusBadRequest)
		return
	}
	if request.Model == "" {
		http.Error(w, "model is required", http.StatusBadRequest)
		return
	}
	// Check if the shared model manager has the requested model available.
	if !backend.UsesExternalModelManagement() {
		model, err := s.modelManager.GetModel(request.Model)
		if err != nil {
			if errors.Is(err, distribution.ErrModelNotFound) {
				http.Error(w, err.Error(), http.StatusNotFound)
			} else {
				http.Error(w, "model unavailable", http.StatusInternalServerError)
			}
			return
		}
		// Non-blocking call to track the model usage.
		s.tracker.TrackModel(model, r.UserAgent())
	}
	modelID := s.modelManager.ResolveModelID(request.Model)
	// Request a runner to execute the request and defer its release.
	runner, err := s.loader.load(r.Context(), backend.Name(), modelID, request.Model, backendMode)
	if err != nil {
		http.Error(w, fmt.Errorf("unable to load runner: %w", err).Error(), http.StatusInternalServerError)
		return
	}
	defer s.loader.release(runner)
	// Record the request in the OpenAI recorder.
	recordID := s.openAIRecorder.RecordRequest(request.Model, r, body)
	w = s.openAIRecorder.NewResponseRecorder(w)
	defer func() {
		// Record the response in the OpenAI recorder.
		s.openAIRecorder.RecordResponse(recordID, request.Model, w)
	}()
	// Create a request with the body replaced for forwarding upstream.
	upstreamRequest := r.Clone(r.Context())
	upstreamRequest.Body = io.NopCloser(bytes.NewReader(body))
	// Perform the request.
	runner.ServeHTTP(w, upstreamRequest)
}

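// GetBackendStatus reports the status of every registered backend as a JSON
// object keyed by backend name.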
func (s *Scheduler) GetBackendStatus(w http.ResponseWriter, r *http.Request) {
	status := make(map[string]string)
	for backendName, backend := range s.backends {
		status[backendName] = backend.Status()
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(status)
}

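// ResetInstaller replaces the scheduler's backend installer with a fresh one
// that uses the provided HTTP client.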
func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
	s.installer = newInstaller(s.log, s.backends, httpClient)
}

// GetRunningBackends returns information about all running backends.
func (s *Scheduler) GetRunningBackends(w http.ResponseWriter, r *http.Request) {
	runningBackends := s.getLoaderStatus(r.Context())
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(runningBackends); err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

// getLoaderStatus returns information about all running backends managed by the loader.
func (s *Scheduler) getLoaderStatus(ctx context.Context) []BackendStatus {
	if !s.loader.lock(ctx) {
		return []BackendStatus{}
	}
	defer s.loader.unlock()
	result := make([]BackendStatus, 0, len(s.loader.runners))
	for key, runnerInfo := range s.loader.runners {
		if s.loader.slots[runnerInfo.slot] != nil {
			status := BackendStatus{
				BackendName: key.backend,
				ModelName: runnerInfo.modelRef,
				Mode: key.mode.String(),
				LastUsed: time.Time{},
			}
			if s.loader.references[runnerInfo.slot] == 0 {
				status.LastUsed = s.loader.timestamps[runnerInfo.slot]
			}
			result = append(result, status)
		}
	}
	return result
}

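// GetDiskUsage reports the disk usage of the model store and the default
// backend as JSON.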
func (s *Scheduler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
	modelsDiskUsage, err, httpCode := s.modelManager.GetDiskUsage()
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get models disk usage: %v", err), httpCode)
		return
	}
	// TODO: Get disk usage for each backend once the backends are implemented.
	defaultBackendDiskUsage, err := s.defaultBackend.GetDiskUsage()
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to get disk usage for %s: %v", s.defaultBackend.Name(), err), http.StatusInternalServerError)
		return
	}
	diskUsage := DiskUsage{modelsDiskUsage, defaultBackendDiskUsage}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(diskUsage); err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

// Unload unloads the specified runners (backend, model) from the backend.
// Currently, this doesn't work for runners that are handling an OpenAI request.
func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
	if err != nil {
		if _, ok := err.(*http.MaxBytesError); ok {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "unknown error", http.StatusInternalServerError)
		}
		return
	}
	var unloadRequest UnloadRequest
	if err := json.Unmarshal(body, &unloadRequest); err != nil {
		http.Error(w, "invalid request", http.StatusBadRequest)
		return
	}
	unloadedRunners := UnloadResponse{s.loader.Unload(r.Context(), unloadRequest)}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(unloadedRunners); err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

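// Configure applies per-runner configuration (context size and runtime flags)
// for the requested backend and model.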
func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
	// Determine the requested backend and ensure that it's valid.
	var backend inference.Backend
	if b := r.PathValue("backend"); b == "" {
		backend = s.defaultBackend
	} else {
		backend = s.backends[b]
	}
	if backend == nil {
		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
		return
	}
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
	if err != nil {
		if _, ok := err.(*http.MaxBytesError); ok {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "unknown error", http.StatusInternalServerError)
		}
		return
	}
	configureRequest := ConfigureRequest{
		ContextSize: -1,
	}
	if err := json.Unmarshal(body, &configureRequest); err != nil {
		http.Error(w, "invalid request", http.StatusBadRequest)
		return
	}
	var runtimeFlags []string
	if len(configureRequest.RuntimeFlags) > 0 {
		runtimeFlags = configureRequest.RuntimeFlags
	} else {
		rawFlags, err := shellwords.Parse(configureRequest.RawRuntimeFlags)
		if err != nil {
			http.Error(w, "invalid request", http.StatusBadRequest)
			return
		}
		runtimeFlags = rawFlags
	}
	var runnerConfig inference.BackendConfiguration
	runnerConfig.ContextSize = configureRequest.ContextSize
	runnerConfig.RuntimeFlags = runtimeFlags
	if model, err := s.modelManager.GetModel(configureRequest.Model); err == nil {
		// Configure is called by compose for each model.
		s.tracker.TrackModel(model, r.UserAgent())
	}
	modelID := s.modelManager.ResolveModelID(configureRequest.Model)
	if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), modelID, inference.BackendModeCompletion, runnerConfig); err != nil {
		s.log.Warnf("Failed to configure %s runner for %s (%s): %s", backend.Name(), configureRequest.Model, modelID, err)
		if errors.Is(err, errRunnerAlreadyActive) {
			http.Error(w, err.Error(), http.StatusConflict)
		} else {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
		return
	}
	w.WriteHeader(http.StatusAccepted)
}

// GetAllActiveRunners returns information about all active runners.
func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner {
	runningBackends := s.getLoaderStatus(context.Background())
	var activeRunners []metrics.ActiveRunner
	if !s.loader.lock(context.Background()) {
		return activeRunners
	}
	defer s.loader.unlock()
	for _, backend := range runningBackends {
		// Find the runner slot for this backend/model combination.
		key := runnerKey{
			backend: backend.BackendName,
			modelID: backend.ModelName,
			mode: parseBackendMode(backend.Mode),
		}
		if runnerInfo, exists := s.loader.runners[key]; exists {
			socket, err := RunnerSocketPath(runnerInfo.slot)
			if err != nil {
				s.log.Warnf("Failed to get socket path for runner %s/%s (%s): %v", backend.BackendName, backend.ModelName, key.modelID, err)
				continue
			}
			activeRunners = append(activeRunners, metrics.ActiveRunner{
				BackendName: backend.BackendName,
				ModelName: backend.ModelName,
				Mode: backend.Mode,
				Socket: socket,
			})
		}
	}
	return activeRunners
}

// GetLlamaCppSocket returns the Unix socket path for an active llama.cpp runner.
func (s *Scheduler) GetLlamaCppSocket() (string, error) {
	runningBackends := s.getLoaderStatus(context.Background())
	if !s.loader.lock(context.Background()) {
		return "", errors.New("failed to acquire loader lock")
	}
	defer s.loader.unlock()
	// Look for an active llama.cpp backend.
	for _, backend := range runningBackends {
		if backend.BackendName == "llama.cpp" {
			// Find the runner slot for this backend/model combination.
			key := runnerKey{
				backend: backend.BackendName,
				modelID: backend.ModelName,
				mode: parseBackendMode(backend.Mode),
			}
			if runnerInfo, exists := s.loader.runners[key]; exists {
				// Use the RunnerSocketPath function to get the socket path.
				return RunnerSocketPath(runnerInfo.slot)
			}
		}
	}
	return "", errors.New("no active llama.cpp backend found")
}

// parseBackendMode converts a string mode to a BackendMode.
func parseBackendMode(mode string) inference.BackendMode {
	switch mode {
	case "completion":
		return inference.BackendModeCompletion
	case "embedding":
		return inference.BackendModeEmbedding
	default:
		return inference.BackendModeCompletion
	}
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	s.router.ServeHTTP(w, r)
}
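
// Usage sketch (illustrative only; the surrounding daemon wiring, including the
// logger, backends, model manager, tracker, and system memory info, is assumed
// to exist in the caller):
//
//	sched := scheduling.NewScheduler(log, backends, defaultBackend, modelManager,
//		http.DefaultClient, nil, tracker, sysMemInfo)
//	go func() { _ = sched.Run(ctx) }()   // start the installer and loader workers
//	_ = http.ListenAndServe(addr, sched) // Scheduler implements http.Handler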