model-runner/pkg/inference/models/manager.go

package models

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"html"
	"net/http"

	"github.com/docker/model-distribution/distribution"
	"github.com/docker/model-distribution/types"
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/logging"
	"github.com/sirupsen/logrus"
)

const (
	// maximumConcurrentModelPulls is the maximum number of concurrent model
	// pulls that a model manager will allow.
	maximumConcurrentModelPulls = 2
)

// Manager manages inference model pulls and storage.
type Manager struct {
	// log is the associated logger.
	log logging.Logger
	// pullTokens is a semaphore used to restrict the maximum number of
	// concurrent pull requests.
	pullTokens chan struct{}
	// router is the HTTP request router.
	router *http.ServeMux
	// distributionClient is the client for model distribution.
	distributionClient *distribution.Client
}
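
// ClientConfig holds the configuration used to construct the model
// distribution client.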
type ClientConfig struct {
	// StoreRootPath is the root path for the model store.
	StoreRootPath string
	// Logger is the logger to use.
	Logger *logrus.Entry
	// Transport is the HTTP transport to use.
	Transport http.RoundTripper
	// UserAgent is the user agent to use.
	UserAgent string
}

// NewManager creates a new model manager.
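//
// A minimal, illustrative construction (the store path, logger setup, and
// user agent below are placeholder assumptions, not values this package
// prescribes; Manager itself implements http.Handler):
//
//	mgr := NewManager(logger, ClientConfig{
//		StoreRootPath: "/var/lib/model-runner/models",
//		Logger:        logrus.NewEntry(logrus.StandardLogger()),
//		Transport:     http.DefaultTransport,
//		UserAgent:     "model-runner/dev",
//	})
//	_ = http.ListenAndServe(":8080", mgr)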
func NewManager(log logging.Logger, c ClientConfig) *Manager {
	// Create the model distribution client.
	distributionClient, err := distribution.NewClient(
		distribution.WithStoreRootPath(c.StoreRootPath),
		distribution.WithLogger(c.Logger),
		distribution.WithTransport(c.Transport),
		distribution.WithUserAgent(c.UserAgent),
	)
	if err != nil {
		log.Errorf("Failed to create distribution client: %v", err)
		// Continue without a distribution client. The model manager will still
		// respond to requests, but may return errors if the client is required.
	}
	// Create the manager.
	m := &Manager{
		log:                log,
		pullTokens:         make(chan struct{}, maximumConcurrentModelPulls),
		router:             http.NewServeMux(),
		distributionClient: distributionClient,
	}
	// Register routes.
	m.router.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "not found", http.StatusNotFound)
	})
	for route, handler := range m.routeHandlers() {
		m.router.HandleFunc(route, handler)
	}
	// Populate the pull concurrency semaphore.
	for i := 0; i < maximumConcurrentModelPulls; i++ {
		m.pullTokens <- struct{}{}
	}
	// Manager successfully initialized.
	return m
}
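
// routeHandlers returns the handler for each HTTP route pattern served by the
// manager. The patterns use Go 1.22+ net/http wildcards: "{name...}" matches
// the remainder of the request path, which permits model names that contain
// slashes.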
func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
	return map[string]http.HandlerFunc{
		"POST " + inference.ModelsPrefix + "/create":                          m.handleCreateModel,
		"GET " + inference.ModelsPrefix:                                       m.handleGetModels,
		"GET " + inference.ModelsPrefix + "/{name...}":                        m.handleGetModel,
		"DELETE " + inference.ModelsPrefix + "/{name...}":                     m.handleDeleteModel,
		"POST " + inference.ModelsPrefix + "/{name}/tag":                      m.handleTagModel,
		"GET " + inference.InferencePrefix + "/{backend}/v1/models":           m.handleOpenAIGetModels,
		"GET " + inference.InferencePrefix + "/{backend}/v1/models/{name...}": m.handleOpenAIGetModel,
		"GET " + inference.InferencePrefix + "/v1/models":                     m.handleOpenAIGetModels,
		"GET " + inference.InferencePrefix + "/v1/models/{name...}":           m.handleOpenAIGetModel,
	}
}
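
// GetRoutes returns the route patterns registered by the manager.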
func (m *Manager) GetRoutes() []string {
	routeHandlers := m.routeHandlers()
	routes := make([]string, 0, len(routeHandlers))
	for route := range routeHandlers {
		routes = append(routes, route)
	}
	return routes
}

// handleCreateModel handles POST <inference-prefix>/models/create requests.
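//
// The request body is JSON decoded into ModelCreateRequest; assuming its From
// field marshals as "from", an illustrative body is:
//
//	{"from": "ai/llama3.2"}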
func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}
	// Decode the request.
	var request ModelCreateRequest
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		http.Error(w, "invalid request body", http.StatusBadRequest)
		return
	}
	// Pull the model. In the future, we may support additional operations here
	// besides pulling (such as model building).
	if err := m.PullModel(r.Context(), request.From, w); err != nil {
		if errors.Is(err, distribution.ErrInvalidReference) {
			m.log.Warnf("Invalid model reference %q: %v", request.From, err)
			http.Error(w, "Invalid model reference", http.StatusBadRequest)
			return
		}
		if errors.Is(err, distribution.ErrUnauthorized) || errors.Is(err, distribution.ErrModelNotFound) {
			m.log.Warnf("Failed to pull model %q: %v", request.From, err)
			http.Error(w, "Model not found", http.StatusNotFound)
			return
		}
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
}

// handleGetModels handles GET <inference-prefix>/models requests.
func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}
	// Query models.
	models, err := m.distributionClient.ListModels()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	apiModels := make([]*Model, len(models))
	for i, model := range models {
		apiModels[i], err = ToModel(model)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
	}
	// Write the response.
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(apiModels); err != nil {
		m.log.Warnln("Error while encoding model listing response:", err)
	}
}

// handleGetModel handles GET <inference-prefix>/models/{name} requests.
func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}
	// Query the model.
	model, err := m.GetModel(r.PathValue("name"))
	if err != nil {
		if errors.Is(err, distribution.ErrModelNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
		} else {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
		return
	}
	apiModel, err := ToModel(model)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Write the response.
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(apiModel); err != nil {
		m.log.Warnln("Error while encoding model response:", err)
	}
}

// handleDeleteModel handles DELETE <inference-prefix>/models/{name} requests.
func (m *Manager) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}
	// TODO: We probably want the manager to have a lock / unlock mechanism for
	// models so that active runners can retain / release a model, analogous to
	// a container blocking the release of an image. However, unlike containers,
	// runners are only evicted when idle or when memory is needed, so users
	// won't be able to release the images manually. Perhaps we can unlink the
	// corresponding GGUF files from disk and allow the OS to clean them up once
	// the runner process exits (though this won't work for Windows, where we
	// might need some separate cleanup process).
	err := m.distributionClient.DeleteModel(r.PathValue("name"))
	if err != nil {
		if errors.Is(err, distribution.ErrModelNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
			return
		}
		m.log.Warnln("Error while deleting model:", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
}

// handleOpenAIGetModels handles GET <inference-prefix>/<backend>/v1/models
// and GET <inference-prefix>/v1/models requests.
func (m *Manager) handleOpenAIGetModels(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}
	// Query models.
	available, err := m.distributionClient.ListModels()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	models, err := ToOpenAIList(available)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Write the response.
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(models); err != nil {
		m.log.Warnln("Error while encoding OpenAI model listing response:", err)
	}
}

// handleOpenAIGetModel handles GET <inference-prefix>/<backend>/v1/models/{name}
// and GET <inference-prefix>/v1/models/{name} requests.
func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}
	// Query the model.
	model, err := m.GetModel(r.PathValue("name"))
	if err != nil {
		if errors.Is(err, distribution.ErrModelNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
		} else {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
		return
	}
	// Convert the model to its OpenAI representation.
	openaiModel, err := ToOpenAI(model)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Write the response.
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(openaiModel); err != nil {
		m.log.Warnln("Error while encoding OpenAI model response:", err)
	}
}

// handleTagModel handles POST <inference-prefix>/models/{name}/tag requests.
// The query parameters are:
// - repo: the repository to tag the model with (required)
// - tag: the tag to tag the model with (optional, defaults to "latest")
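//
// An illustrative request (the names here are placeholders):
//
//	POST <inference-prefix>/models/mymodel/tag?repo=myorg/mymodel&tag=v2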
func (m *Manager) handleTagModel(w http.ResponseWriter, r *http.Request) {
	if m.distributionClient == nil {
		http.Error(w, "model distribution service unavailable", http.StatusServiceUnavailable)
		return
	}
	// Extract the model name from the request path.
	model := r.PathValue("name")
	// Extract query parameters.
	repo := r.URL.Query().Get("repo")
	tag := r.URL.Query().Get("tag")
	// Validate query parameters.
	if repo == "" {
		http.Error(w, "missing repo query parameter", http.StatusBadRequest)
		return
	}
	if tag == "" {
		tag = "latest"
	}
	// Construct the target reference.
	target := fmt.Sprintf("%s:%s", repo, tag)
	// Tag the model with the target reference.
	if err := m.distributionClient.Tag(model, target); err != nil {
		m.log.Warnf("Failed to tag model %q: %v", model, err)
		if errors.Is(err, distribution.ErrModelNotFound) {
			http.Error(w, err.Error(), http.StatusNotFound)
			return
		}
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Respond with success.
	w.WriteHeader(http.StatusCreated)
	w.Write([]byte(fmt.Sprintf("Model %q tagged successfully as %q", model, target)))
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	m.router.ServeHTTP(w, r)
}

// GetModel returns a single model.
func (m *Manager) GetModel(ref string) (types.Model, error) {
	model, err := m.distributionClient.GetModel(ref)
	if err != nil {
		return nil, fmt.Errorf("error while getting model: %w", err)
	}
	return model, nil
}

// GetModelPath returns the path to a model's files.
func (m *Manager) GetModelPath(ref string) (string, error) {
	model, err := m.GetModel(ref)
	if err != nil {
		return "", err
	}
	path, err := model.GGUFPath()
	if err != nil {
		return "", fmt.Errorf("error while getting model path: %w", err)
	}
	return path, nil
}

// PullModel pulls a model to local storage. Any error it returns is suitable
// for writing back to the client.
func (m *Manager) PullModel(ctx context.Context, model string, w http.ResponseWriter) error {
	// Restrict model pull concurrency.
	select {
	case <-m.pullTokens:
	case <-ctx.Done():
		return ctx.Err()
	}
	defer func() {
		m.pullTokens <- struct{}{}
	}()
	// Set up response headers for streaming.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Transfer-Encoding", "chunked")
	// Grab a flusher so that chunks are sent to the client immediately.
	flusher, ok := w.(http.Flusher)
	if !ok {
		return fmt.Errorf("streaming not supported")
	}
	// Create a progress writer that streams updates to the response.
	progressWriter := &progressResponseWriter{
		writer:  w,
		flusher: flusher,
	}
	// Pull the model using the Docker model distribution client.
	m.log.Infoln("Pulling model:", model)
	if err := m.distributionClient.PullModel(ctx, model, progressWriter); err != nil {
		return fmt.Errorf("error while pulling model: %w", err)
	}
	return nil
}

// progressResponseWriter implements io.Writer to write progress updates to
// the HTTP response.
type progressResponseWriter struct {
	writer  http.ResponseWriter
	flusher http.Flusher
}
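
// Write HTML-escapes the chunk, writes it to the underlying response, and
// flushes so the client receives progress updates immediately.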
func (w *progressResponseWriter) Write(p []byte) (n int, err error) {
	escapedData := html.EscapeString(string(p))
	if _, err := w.writer.Write([]byte(escapedData)); err != nil {
		return 0, err
	}
	// Flush the response to ensure the chunk is sent immediately.
	if w.flusher != nil {
		w.flusher.Flush()
	}
	// Report the input length as consumed: escaping can lengthen the data, and
	// io.Writer requires n <= len(p).
	return len(p), nil
}
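
// Compile-time check that Manager implements http.Handler.
var _ http.Handler = (*Manager)(nil)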