model-runner/pkg/inference/scheduling/scheduler.go

package scheduling

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"

	"github.com/docker/model-distribution/distribution"
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/models"
	"github.com/docker/model-runner/pkg/logging"
	"golang.org/x/sync/errgroup"
)

// Scheduler is used to coordinate inference scheduling across multiple
// backends and models.
type Scheduler struct {
	// log is the associated logger.
	log logging.Logger
	// backends are the supported inference backends.
	backends map[string]inference.Backend
	// defaultBackend is the default inference backend. It may be nil.
	defaultBackend inference.Backend
	// modelManager is the shared model manager.
	modelManager *models.Manager
	// installer is the backend installer.
	installer *installer
	// loader is the backend loader.
	loader *loader
	// router is the HTTP request router.
	router *http.ServeMux
}

// NewScheduler creates a new inference scheduler.
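//
// A minimal usage sketch (illustrative only; the logger, backend map, model
// manager, and context are assumed to be constructed elsewhere, and the
// "llama.cpp" backend name is hypothetical):
//
//	s := scheduling.NewScheduler(log, backends, backends["llama.cpp"], modelManager, http.DefaultClient)
//	go func() { _ = s.Run(ctx) }()
//	_ = http.ListenAndServe(":8080", s)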
func NewScheduler(
	log logging.Logger,
	backends map[string]inference.Backend,
	defaultBackend inference.Backend,
	modelManager *models.Manager,
	httpClient *http.Client,
) *Scheduler {
	// Create the scheduler.
	s := &Scheduler{
		log:            log,
		backends:       backends,
		defaultBackend: defaultBackend,
		modelManager:   modelManager,
		installer:      newInstaller(log, backends, httpClient),
		loader:         newLoader(log, backends, modelManager),
		router:         http.NewServeMux(),
	}

	// Register routes.
	s.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "not found", http.StatusNotFound)
	})
	for route, handler := range s.routeHandlers() {
		s.router.HandleFunc(route, handler)
	}

	// Scheduler successfully initialized.
	return s
}

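// routeHandlers returns the scheduler's HTTP handlers, keyed by the
// net/http.ServeMux patterns under which they should be registered.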
func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
	openAIRoutes := []string{
		"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
		"POST " + inference.InferencePrefix + "/{backend}/v1/completions",
		"POST " + inference.InferencePrefix + "/{backend}/v1/embeddings",
		"POST " + inference.InferencePrefix + "/v1/chat/completions",
		"POST " + inference.InferencePrefix + "/v1/completions",
		"POST " + inference.InferencePrefix + "/v1/embeddings",
	}
	m := make(map[string]http.HandlerFunc)
	for _, route := range openAIRoutes {
		m[route] = s.handleOpenAIInference
	}
	m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
	return m
}

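// GetRoutes returns the routes registered by the scheduler. Because the
// underlying map is iterated, the order of the returned routes is not
// deterministic.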
func (s *Scheduler) GetRoutes() []string {
	routeHandlers := s.routeHandlers()
	routes := make([]string, 0, len(routeHandlers))
	for route := range routeHandlers {
		routes = append(routes, route)
	}
	return routes
}

// Run is the scheduler's main run loop. By the time it returns, all inference
// backends will have been unloaded from memory.
func (s *Scheduler) Run(ctx context.Context) error {
	// Create an error group to track worker Goroutines.
	workers, workerCtx := errgroup.WithContext(ctx)

	// Start the installer.
	workers.Go(func() error {
		s.installer.run(workerCtx)
		return nil
	})

	// Start the loader.
	workers.Go(func() error {
		s.loader.run(workerCtx)
		return nil
	})

	// Wait for all workers to exit.
	return workers.Wait()
}

// handleOpenAIInference handles scheduling and responding to OpenAI inference
// requests, including:
//   - POST <inference-prefix>/{backend}/v1/chat/completions
//   - POST <inference-prefix>/{backend}/v1/completions
//   - POST <inference-prefix>/{backend}/v1/embeddings
//
// as well as the corresponding default-backend routes without the {backend}
// path component.
func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request) {
	// Determine the requested backend and ensure that it's valid.
	var backend inference.Backend
	if b := r.PathValue("backend"); b == "" {
		backend = s.defaultBackend
	} else {
		backend = s.backends[b]
	}
	if backend == nil {
		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
		return
	}

	// Read the entire request body. We put some basic size constraints in
	// place to avoid DoS attacks. We do this early to avoid client write
	// timeouts.
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
	if err != nil {
		if _, ok := err.(*http.MaxBytesError); ok {
			http.Error(w, "request too large", http.StatusBadRequest)
		} else {
			http.Error(w, "unknown error", http.StatusInternalServerError)
		}
		return
	}

	// Wait for the corresponding backend installation to complete or fail. We
	// don't allow any requests to be scheduled for a backend until it has
	// completed installation.
	if err := s.installer.wait(r.Context(), backend.Name()); err != nil {
		if errors.Is(err, ErrBackendNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
		} else if errors.Is(err, errInstallerNotStarted) {
			http.Error(w, err.Error(), http.StatusServiceUnavailable)
		} else if errors.Is(err, context.Canceled) {
			// This could be due to the client aborting the request (in which
			// case this response will be ignored) or the inference service
			// shutting down (since that will also cancel the request
			// context). Either way, provide a response, even if it's ignored.
			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
		} else {
			http.Error(w, fmt.Errorf("backend installation failed: %w", err).Error(), http.StatusServiceUnavailable)
		}
		return
	}

	// Determine the backend operation mode.
	backendMode, ok := backendModeForRequest(r.URL.Path)
	if !ok {
		http.Error(w, "unknown request path", http.StatusInternalServerError)
		return
	}

	// Decode the model specification portion of the request body.
	var request OpenAIInferenceRequest
	if err := json.Unmarshal(body, &request); err != nil {
		http.Error(w, "invalid request", http.StatusBadRequest)
		return
	}
	if request.Model == "" {
		http.Error(w, "model is required", http.StatusBadRequest)
		return
	}

	// Check if the shared model manager has the requested model available.
	if !backend.UsesExternalModelManagement() {
		if _, err := s.modelManager.GetModel(request.Model); err != nil {
			if errors.Is(err, distribution.ErrModelNotFound) {
				http.Error(w, err.Error(), http.StatusNotFound)
			} else {
				http.Error(w, "model unavailable", http.StatusInternalServerError)
			}
			return
		}
	}

	// Request a runner to execute the request and defer its release.
	runner, err := s.loader.load(r.Context(), backend.Name(), request.Model, backendMode)
	if err != nil {
		http.Error(w, fmt.Errorf("unable to load runner: %w", err).Error(), http.StatusInternalServerError)
		return
	}
	defer s.loader.release(runner)

	// Create a request with the body replaced for forwarding upstream.
	upstreamRequest := r.Clone(r.Context())
	upstreamRequest.Body = io.NopCloser(bytes.NewReader(body))

	// Perform the request.
	runner.ServeHTTP(w, upstreamRequest)
}

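// GetBackendStatus handles GET <inference-prefix>/status, responding with a
// JSON object that maps each backend name to its current status string, for
// example (names and values are illustrative):
//
//	{"llama.cpp": "running"}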
func (s *Scheduler) GetBackendStatus(w http.ResponseWriter, r *http.Request) {
	status := make(map[string]string)
	for backendName, backend := range s.backends {
		status[backendName] = backend.Status()
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(status)
}

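// ResetInstaller replaces the scheduler's installer with a fresh instance
// that uses the provided HTTP client. Note that this simply swaps the field;
// it does not stop or wait for an installer already started by Run.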
func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
	s.installer = newInstaller(s.log, s.backends, httpClient)
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.router.ServeHTTP(w, r)
}