model-runner/pkg/inference/models/manager.go

704 lines
21 KiB
Go

package models
import (
"context"
"encoding/json"
"errors"
"fmt"
"html"
"net/http"
"path"
"strconv"
"strings"
"sync"
"github.com/docker/model-distribution/distribution"
"github.com/docker/model-distribution/registry"
"github.com/docker/model-distribution/types"
"github.com/sirupsen/logrus"
"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/logging"
)
const (
// maximumConcurrentModelPulls is the maximum number of concurrent model
// pulls that a model manager will allow.
maximumConcurrentModelPulls = 2
)
// Manager manages inference model pulls and storage.
type Manager struct {
// log is the associated logger.
log logging.Logger
// pullTokens is a semaphore used to restrict the maximum number of
// concurrent pull requests.
pullTokens chan struct{}
// router is the HTTP request router.
router *http.ServeMux
// distributionClient is the client for model distribution.
distributionClient *distribution.Client
// registryClient is the client for model registry.
registryClient *registry.Client
// lock is used to synchronize access to the models manager's router.
lock sync.RWMutex
}
type ClientConfig struct {
// StoreRootPath is the root path for the model store.
StoreRootPath string
// Logger is the logger to use.
Logger *logrus.Entry
// Transport is the HTTP transport to use.
Transport http.RoundTripper
// UserAgent is the user agent to use.
UserAgent string
}
// NewManager creates a new model's manager.
func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Manager {
// Create the model distribution client.
distributionClient, err := distribution.NewClient(
distribution.WithStoreRootPath(c.StoreRootPath),
distribution.WithLogger(c.Logger),
distribution.WithTransport(c.Transport),
distribution.WithUserAgent(c.UserAgent),
)
if err != nil {
log.Errorf("Failed to create distribution client: %v", err)
// Continue without distribution client. The model manager will still
// respond to requests, but may return errors if the client is required.
}
// Create the model registry client.
registryClient := registry.NewClient(
registry.WithTransport(c.Transport),
registry.WithUserAgent(c.UserAgent),
)
// Create the manager.
m := &Manager{
log: log,
pullTokens: make(chan struct{}, maximumConcurrentModelPulls),
router: http.NewServeMux(),
distributionClient: distributionClient,
registryClient: registryClient,
}
// Register routes.
m.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
})
for route, handler := range m.routeHandlers(allowedOrigins) {
m.router.HandleFunc(route, handler)
}
// Populate the pull concurrency semaphore.
for i := 0; i < maximumConcurrentModelPulls; i++ {
m.pullTokens <- struct{}{}
}
// Manager successfully initialized.
return m
}
func (m *Manager) RebuildRoutes(allowedOrigins []string) {
m.lock.Lock()
defer m.lock.Unlock()
// Clear existing routes and re-register them.
m.router = http.NewServeMux()
// Register routes.
m.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "not found", http.StatusNotFound)
})
for route, handler := range m.routeHandlers(allowedOrigins) {
m.router.HandleFunc(route, handler)
}
}
func (m *Manager) routeHandlers(allowedOrigins []string) map[string]http.HandlerFunc {
handlers := map[string]http.HandlerFunc{
"POST " + inference.ModelsPrefix + "/create": m.handleCreateModel,
"POST " + inference.ModelsPrefix + "/load": m.handleLoadModel,
"GET " + inference.ModelsPrefix: m.handleGetModels,
"GET " + inference.ModelsPrefix + "/{name...}": m.handleGetModel,
"DELETE " + inference.ModelsPrefix + "/{name...}": m.handleDeleteModel,
"POST " + inference.ModelsPrefix + "/{nameAndAction...}": m.handleModelAction,
"GET " + inference.InferencePrefix + "/{backend}/v1/models": m.handleOpenAIGetModels,
"GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel,
"GET " + inference.InferencePrefix + "/v1/models": m.handleOpenAIGetModels,
"GET " + inference.InferencePrefix + "/v1/models/{name...}": m.handleOpenAIGetModel,
}
for route, handler := range handlers {
if strings.HasPrefix(route, "GET ") {
handlers[route] = inference.CorsMiddleware(allowedOrigins, handler).ServeHTTP
}
}
return handlers
}
func (m *Manager) GetRoutes() []string {
routeHandlers := m.routeHandlers(nil)
routes := make([]string, 0, len(routeHandlers))
for route := range routeHandlers {
routes = append(routes, route)
}
return routes
}
// handleCreateModel handles POST <inference-prefix>/models/create requests.
func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
// Decode the request.
var request ModelCreateRequest
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}
// Pull the model. In the future, we may support additional operations here
// besides pulling (such as model building).
if err := m.PullModel(request.From, r, w); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
m.log.Infof("Request canceled/timed out while pulling model %q", request.From)
return
}
if errors.Is(err, registry.ErrInvalidReference) {
m.log.Warnf("Invalid model reference %q: %v", request.From, err)
http.Error(w, "Invalid model reference", http.StatusBadRequest)
return
}
if errors.Is(err, registry.ErrUnauthorized) {
m.log.Warnf("Unauthorized to pull model %q: %v", request.From, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if errors.Is(err, registry.ErrModelNotFound) {
m.log.Warnf("Failed to pull model %q: %v", request.From, err)
http.Error(w, "Model not found", http.StatusNotFound)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// handleLoadModel handles POST <inference-prefix>/models/load requests.
func (m *Manager) handleLoadModel(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
if _, err := m.distributionClient.LoadModel(r.Body, w); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
return
}
// handleGetModels handles GET <inference-prefix>/models requests.
func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
// Query models.
models, err := m.distributionClient.ListModels()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
apiModels := make([]*Model, len(models))
for i, model := range models {
apiModels[i], err = ToModel(model)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// Write the response.
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(apiModels); err != nil {
m.log.Warnln("Error while encoding model listing response:", err)
}
}
// handleGetModel handles GET <inference-prefix>/models/{name} requests.
func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
// Parse remote query parameter
remote := false
if r.URL.Query().Has("remote") {
if val, err := strconv.ParseBool(r.URL.Query().Get("remote")); err != nil {
m.log.Warnln("Error while parsing remote query parameter:", err)
} else {
remote = val
}
}
if remote && m.registryClient == nil {
http.Error(w, "registry client unavailable", http.StatusServiceUnavailable)
return
}
var apiModel *Model
var err error
if remote {
apiModel, err = getRemoteModel(r.Context(), m, r.PathValue("name"))
} else {
apiModel, err = getLocalModel(m, r.PathValue("name"))
}
if err != nil {
if errors.Is(err, distribution.ErrModelNotFound) || errors.Is(err, registry.ErrModelNotFound) {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Write the response.
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(apiModel); err != nil {
m.log.Warnln("Error while encoding model response:", err)
}
}
// ResolveModelID resolves a model reference to a model ID. If resolution fails, it returns the original ref.
func (m *Manager) ResolveModelID(modelRef string) string {
// Sanitize modelRef to prevent log forgery
sanitizedModelRef := strings.ReplaceAll(modelRef, "\n", "")
sanitizedModelRef = strings.ReplaceAll(sanitizedModelRef, "\r", "")
model, err := m.GetModel(sanitizedModelRef)
if err != nil {
m.log.Warnf("Failed to resolve model ref %s to ID: %v", sanitizedModelRef, err)
return sanitizedModelRef
}
modelID, err := model.ID()
if err != nil {
m.log.Warnf("Failed to get model ID for ref %s: %v", sanitizedModelRef, err)
return sanitizedModelRef
}
return modelID
}
func getLocalModel(m *Manager, name string) (*Model, error) {
if m.distributionClient == nil {
return nil, errors.New("model distribution service unavailable")
}
// Query the model.
model, err := m.GetModel(name)
if err != nil {
return nil, err
}
return ToModel(model)
}
func getRemoteModel(ctx context.Context, m *Manager, name string) (*Model, error) {
if m.registryClient == nil {
return nil, errors.New("registry client unavailable")
}
m.log.Infoln("Getting remote model:", name)
model, err := m.registryClient.Model(ctx, name)
if err != nil {
return nil, err
}
id, err := model.ID()
if err != nil {
return nil, err
}
descriptor, err := model.Descriptor()
if err != nil {
return nil, err
}
config, err := model.Config()
if err != nil {
return nil, err
}
apiModel := &Model{
ID: id,
Tags: nil,
Created: descriptor.Created.Unix(),
Config: config,
}
return apiModel, nil
}
// handleDeleteModel handles DELETE <inference-prefix>/models/{name} requests.
// query params:
// - force: if true, delete the model even if it has multiple tags
func (m *Manager) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
// TODO: We probably want the manager to have a lock / unlock mechanism for
// models so that active runners can retain / release a model, analogous to
// a container blocking the release of an image. However, unlike containers,
// runners are only evicted when idle or when memory is needed, so users
// won't be able to release the images manually. Perhaps we can unlink the
// corresponding GGUF files from disk and allow the OS to clean them up once
// the runner process exits (though this won't work for Windows, where we
// might need some separate cleanup process).
var force bool
if r.URL.Query().Has("force") {
if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil {
m.log.Warnln("Error while parsing force query parameter:", err)
} else {
force = val
}
}
resp, err := m.distributionClient.DeleteModel(r.PathValue("name"), force)
if err != nil {
if errors.Is(err, distribution.ErrModelNotFound) {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
if errors.Is(err, distribution.ErrConflict) {
http.Error(w, err.Error(), http.StatusConflict)
return
}
m.log.Warnln("Error while deleting model:", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, fmt.Sprintf("error writing response: %v", err), http.StatusInternalServerError)
}
}
// handleOpenAIGetModels handles GET <inference-prefix>/<backend>/v1/models and
// GET /<inference-prefix>/v1/models requests.
func (m *Manager) handleOpenAIGetModels(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
// Query models.
available, err := m.distributionClient.ListModels()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
models, err := ToOpenAIList(available)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Write the response.
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(models); err != nil {
m.log.Warnln("Error while encoding OpenAI model listing response:", err)
}
}
// handleOpenAIGetModel handles GET <inference-prefix>/<backend>/v1/models/{name}
// and GET <inference-prefix>/v1/models/{name} requests.
func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
// Query the model.
model, err := m.GetModel(r.PathValue("name"))
if err != nil {
if errors.Is(err, distribution.ErrModelNotFound) {
http.Error(w, err.Error(), http.StatusNotFound)
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
}
// Write the response.
w.Header().Set("Content-Type", "application/json")
openaiModel, err := ToOpenAI(model)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(openaiModel); err != nil {
m.log.Warnln("Error while encoding OpenAI model response:", err)
}
}
// handleTagModel handles POST <inference-prefix>/models/{nameAndAction} requests.
// Action is one of:
// - tag: tag the model with a repository and tag (e.g. POST <inference-prefix>/models/my-org/my-repo:latest/tag})
// - push: pushes a tagged model to the registry
func (m *Manager) handleModelAction(w http.ResponseWriter, r *http.Request) {
model, action := path.Split(r.PathValue("nameAndAction"))
model = strings.TrimRight(model, "/")
switch action {
case "tag":
m.handleTagModel(w, r, model)
case "push":
m.handlePushModel(w, r, model)
default:
http.Error(w, fmt.Sprintf("unknown action %q", action), http.StatusNotFound)
}
}
// handleTagModel handles POST <inference-prefix>/models/{name}/tag requests.
// The query parameters are:
// - repo: the repository to tag the model with (required)
// - tag: the tag to apply to the model (required)
func (m *Manager) handleTagModel(w http.ResponseWriter, r *http.Request, model string) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
// Extract query parameters.
repo := r.URL.Query().Get("repo")
tag := r.URL.Query().Get("tag")
// Validate query parameters.
if repo == "" || tag == "" {
http.Error(w, "missing repo or tag query parameter", http.StatusBadRequest)
return
}
// Construct the target string.
target := fmt.Sprintf("%s:%s", repo, tag)
// Call the Tag method on the distribution client with source and modelName.
if err := m.distributionClient.Tag(model, target); err != nil {
m.log.Warnf("Failed to apply tag %q to model %q: %v", target, model, err)
if errors.Is(err, distribution.ErrModelNotFound) {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Respond with success.
w.WriteHeader(http.StatusCreated)
w.Write([]byte(fmt.Sprintf("Model %q tagged successfully with %q", model, target)))
}
// handlePushModel handles POST <inference-prefix>/models/{name}/push requests.
func (m *Manager) handlePushModel(w http.ResponseWriter, r *http.Request, model string) {
if m.distributionClient == nil {
http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
return
}
// Call the PushModel method on the distribution client.
if err := m.PushModel(model, r, w); err != nil {
if errors.Is(err, distribution.ErrInvalidReference) {
m.log.Warnf("Invalid model reference %q: %v", model, err)
http.Error(w, "Invalid model reference", http.StatusBadRequest)
return
}
if errors.Is(err, distribution.ErrModelNotFound) {
m.log.Warnf("Failed to push model %q: %v", model, err)
http.Error(w, "Model not found", http.StatusNotFound)
return
}
if errors.Is(err, registry.ErrUnauthorized) {
m.log.Warnf("Unauthorized to push model %q: %v", model, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// GetDiskUsage returns the disk usage of the model store.
func (m *Manager) GetDiskUsage() (int64, error, int) {
if m.distributionClient == nil {
return 0, errors.New("model distribution service unavailable"), http.StatusServiceUnavailable
}
storePath := m.distributionClient.GetStorePath()
size, err := diskusage.Size(storePath)
if err != nil {
return 0, fmt.Errorf("error while getting store size: %v", err), http.StatusInternalServerError
}
return size, nil, http.StatusOK
}
// ServeHTTP implement net/http.Handler.ServeHTTP.
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.lock.RLock()
defer m.lock.RUnlock()
m.router.ServeHTTP(w, r)
}
// GetModel returns a single model.
func (m *Manager) GetModel(ref string) (types.Model, error) {
model, err := m.distributionClient.GetModel(ref)
if err != nil {
return nil, fmt.Errorf("error while getting model: %w", err)
}
return model, err
}
// GetBundle returns model bundle.
func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {
bundle, err := m.distributionClient.GetBundle(ref)
if err != nil {
return nil, fmt.Errorf("error while getting model bundle: %w", err)
}
return bundle, err
}
// PullModel pulls a model to local storage. Any error it returns is suitable
// for writing back to the client.
func (m *Manager) PullModel(model string, r *http.Request, w http.ResponseWriter) error {
// Restrict model pull concurrency.
select {
case <-m.pullTokens:
case <-r.Context().Done():
return context.Canceled
}
defer func() {
m.pullTokens <- struct{}{}
}()
// Set up response headers for streaming
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Transfer-Encoding", "chunked")
// Check Accept header to determine content type
acceptHeader := r.Header.Get("Accept")
isJSON := acceptHeader == "application/json"
if isJSON {
w.Header().Set("Content-Type", "application/json")
} else {
// Defaults to text/plain
w.Header().Set("Content-Type", "text/plain")
}
// Create a flusher to ensure chunks are sent immediately
flusher, ok := w.(http.Flusher)
if !ok {
return fmt.Errorf("streaming not supported")
}
// Create a progress writer that writes to the response
progressWriter := &progressResponseWriter{
writer: w,
flusher: flusher,
isJSON: isJSON,
}
// Pull the model using the Docker model distribution client
m.log.Infoln("Pulling model:", model)
err := m.distributionClient.PullModel(r.Context(), model, progressWriter)
if err != nil {
return fmt.Errorf("error while pulling model: %w", err)
}
return nil
}
// PushModel pushes a model from the store to the registry.
func (m *Manager) PushModel(model string, r *http.Request, w http.ResponseWriter) error {
// Set up response headers for streaming
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Transfer-Encoding", "chunked")
// Check Accept header to determine content type
acceptHeader := r.Header.Get("Accept")
isJSON := acceptHeader == "application/json"
if isJSON {
w.Header().Set("Content-Type", "application/json")
} else {
w.Header().Set("Content-Type", "text/plain")
}
// Create a flusher to ensure chunks are sent immediately
flusher, ok := w.(http.Flusher)
if !ok {
return errors.New("streaming not supported")
}
// Create a progress writer that writes to the response
progressWriter := &progressResponseWriter{
writer: w,
flusher: flusher,
isJSON: isJSON,
}
// Pull the model using the Docker model distribution client
m.log.Infoln("Pushing model:", model)
err := m.distributionClient.PushModel(r.Context(), model, progressWriter)
if err != nil {
return fmt.Errorf("error while pushing model: %w", err)
}
return nil
}
// progressResponseWriter implements io.Writer to write progress updates to the HTTP response
type progressResponseWriter struct {
writer http.ResponseWriter
flusher http.Flusher
isJSON bool
}
func (w *progressResponseWriter) Write(p []byte) (n int, err error) {
var data []byte
if w.isJSON {
// For JSON, write the raw bytes without escaping
data = p
} else {
// For plain text, escape HTML
escapedData := html.EscapeString(string(p))
data = []byte(escapedData)
}
n, err = w.writer.Write(data)
if err != nil {
return 0, err
}
// Flush the response to ensure the chunk is sent immediately
if w.flusher != nil {
w.flusher.Flush()
}
return n, nil
}