Using model distribution client

Ignasi 2025-03-04 22:00:57 +01:00 committed by Jacob Howard
parent f1c25bd9a2
commit b69d84f8aa
2 changed files with 27 additions and 13 deletions


@@ -71,6 +71,7 @@ func (l *llamaCpp) Install(_ context.Context, _ *http.Client) error {
 // Run implements inference.Backend.Run.
 func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
 	modelPath, err := l.modelManager.GetModelPath(model)
+	log.Infof("Model path: %s", modelPath)
 	if err != nil {
 		return fmt.Errorf("failed to get model path: %w", err)
 	}
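
Note that the added log.Infof call runs before the error from GetModelPath is checked, so a failed lookup would log an empty path. A minimal standalone sketch of the more conventional ordering is below; resolvePath is a hypothetical stand-in for modelManager.GetModelPath, not code from this change.

package main

import (
	"fmt"
	"log"
)

// resolvePath is a hypothetical stand-in for a lookup such as modelManager.GetModelPath.
func resolvePath(model string) (string, error) {
	if model == "" {
		return "", fmt.Errorf("no model specified")
	}
	return "/models/blobs/sha256/deadbeef", nil
}

func main() {
	modelPath, err := resolvePath("example/model")
	if err != nil {
		log.Fatalf("failed to get model path: %v", err)
	}
	// Log only after the error check, so a failed lookup never reports an empty path.
	log.Printf("Model path: %s", modelPath)
}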


@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"strings"
 
 	"github.com/docker/model-distribution/pkg/distribution"
 	"github.com/docker/model-runner/pkg/logger"
@@ -128,7 +129,13 @@ func (m *Manager) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
 	// corresponding GGUF files from disk and allow the OS to clean them up once
 	// the runner process exits (though this won't work for Windows, where we
 	// might need some separate cleanup process).
-	http.Error(w, "not implemented", http.StatusNotImplemented)
+	err := m.distributionClient.DeleteModel(r.PathValue("namespace") + "/" + r.PathValue("name"))
+	if err != nil {
+		m.log.Warnln("Error while deleting model:", err)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
 }
 
 // handleOpenAIGetModels handles GET /ml/{backend}/v1/models and
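
The handler joins the {namespace} and {name} path values into a single reference before handing it to DeleteModel on the distribution client. A self-contained sketch of the same wiring is below; the route pattern, the in-memory store, and the 204 response are illustrative assumptions, not the repository's actual router setup.

package main

import (
	"fmt"
	"log"
	"net/http"
	"sync"
)

// memoryStore is a hypothetical stand-in for the distribution client used by the real manager.
type memoryStore struct {
	mu     sync.Mutex
	models map[string]struct{}
}

func (s *memoryStore) DeleteModel(ref string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if _, ok := s.models[ref]; !ok {
		return fmt.Errorf("model %q not found", ref)
	}
	delete(s.models, ref)
	return nil
}

func main() {
	store := &memoryStore{models: map[string]struct{}{"example/model": {}}}

	mux := http.NewServeMux()
	// Hypothetical route pattern; Go 1.22+ ServeMux exposes the wildcards via r.PathValue.
	mux.HandleFunc("DELETE /models/{namespace}/{name}", func(w http.ResponseWriter, r *http.Request) {
		ref := r.PathValue("namespace") + "/" + r.PathValue("name")
		if err := store.DeleteModel(ref); err != nil {
			log.Println("Error while deleting model:", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.WriteHeader(http.StatusNoContent)
	})

	log.Fatal(http.ListenAndServe("localhost:8080", mux))
}
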
@@ -174,17 +181,17 @@ func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	m.router.ServeHTTP(w, r)
 }
 
-// getModels returns a list of all models or (if model is a non-empty string) a
+// getModels returns a list of all models or (if ref is a non-empty string) a
 // list containing only a specific model. If no models exist or the specific
 // model can't be found, then an empty (but non-nil) list is returned. Any error
 // it returns is suitable for writing back to the client.
-func (m *Manager) getModels(model string) (ModelList, error) {
+func (m *Manager) getModels(ref string) (ModelList, error) {
 	// Initialize the model list. We always want to return a non-nil list (even
 	// if it's empty) so that it can be encoded directly to JSON.
 	models := make(ModelList, 0)
-	if model != "" {
-		model, err := m.distributionClient.GetModel(model)
+	if ref != "" {
+		model, err := m.distributionClient.GetModel(ref)
 		if err != nil {
 			return nil, fmt.Errorf("error while getting model: %w", err)
 		}
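
The comment about always returning a non-nil list matters for the JSON encoding: a nil slice marshals to null, while an empty slice marshals to []. A small standalone illustration, where the ModelList element type is chosen only for the example:

package main

import (
	"encoding/json"
	"fmt"
)

// ModelList stands in for the manager's list type; the element type is illustrative.
type ModelList []string

func main() {
	var nilList ModelList           // nil slice
	emptyList := make(ModelList, 0) // empty, but non-nil

	a, _ := json.Marshal(nilList)
	b, _ := json.Marshal(emptyList)
	fmt.Println(string(a)) // null
	fmt.Println(string(b)) // []
}
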
@@ -218,22 +225,28 @@ func (m *Manager) getModels(model string) (ModelList, error) {
 
 // GetModel looks up and returns a single model. It returns ErrModelNotFound if
 // the model could not be located.
-func (m *Manager) GetModel(model string) (*Model, error) {
-	models, err := m.getModels(model)
+func (m *Manager) GetModel(ref string) (*Model, error) {
+	models, err := m.getModels(ref)
 	if err != nil {
 		return nil, err
-	} else if len(models) == 0 {
+	}
+	if len(models) == 0 {
 		return nil, ErrModelNotFound
 	}
 	return models[0], nil
 }
 
-func (m *Manager) GetModelPath(model string) (string, error) {
-	models, err := m.getModels(model)
+func (m *Manager) GetModelPath(ref string) (string, error) {
+	model, err := m.GetModel(ref)
 	if err != nil {
 		return "", err
-	} else if len(models) == 0 {
-		return "", ErrModelNotFound
 	}
-	return models[0].Files[0], nil
+	// TODO: Handle multiple files
+	// Convert <algorithm>:<digest> to <algorithm>/<digest>
+	blobName := model.Files[0]
+	parts := strings.Split(blobName, ":")
+	if len(parts) != 2 {
+		return "", fmt.Errorf("invalid blob format: %s", blobName)
+	}
+	return fmt.Sprintf("%s/blobs/%s/%s", m.distributionClient.GetStorePath(), parts[0], parts[1]), nil
 }
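
The new GetModelPath maps a content-addressed file reference of the form <algorithm>:<digest> onto a path inside the distribution store's blobs directory. A standalone sketch of that conversion is below; the store path and digest are invented example values, and the real store path comes from distributionClient.GetStorePath().

package main

import (
	"fmt"
	"strings"
)

// blobPath converts "<algorithm>:<digest>" into "<storePath>/blobs/<algorithm>/<digest>",
// mirroring the conversion performed in GetModelPath.
func blobPath(storePath, blobName string) (string, error) {
	parts := strings.Split(blobName, ":")
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid blob format: %s", blobName)
	}
	return fmt.Sprintf("%s/blobs/%s/%s", storePath, parts[0], parts[1]), nil
}

func main() {
	// Example values only.
	p, err := blobPath("/var/lib/model-runner/models", "sha256:0123abcd")
	if err != nil {
		panic(err)
	}
	fmt.Println(p) // /var/lib/model-runner/models/blobs/sha256/0123abcd
}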