From 880818f741a04a6e3e9a75a213da32086041ae71 Mon Sep 17 00:00:00 2001
From: Piotr Stankiewicz
Date: Wed, 30 Jul 2025 13:11:57 +0200
Subject: [PATCH] inference: Support memory estimation for remote models

Signed-off-by: Piotr Stankiewicz
---
 pkg/inference/backend.go                    |  2 +-
 pkg/inference/backends/llamacpp/llamacpp.go | 97 +++++++++++++++++----
 pkg/inference/backends/mlx/mlx.go           |  2 +-
 pkg/inference/backends/vllm/vllm.go         |  2 +-
 pkg/inference/models/manager.go             | 33 +++++++
 pkg/inference/scheduling/loader.go          |  2 +-
 6 files changed, 119 insertions(+), 19 deletions(-)

diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go
index 3d857f8..0a6cae7 100644
--- a/pkg/inference/backend.go
+++ b/pkg/inference/backend.go
@@ -88,5 +88,5 @@ type Backend interface {
 	GetDiskUsage() (int64, error)
 	// GetRequiredMemoryForModel returns the required working memory for a given
 	// model.
-	GetRequiredMemoryForModel(model string, config *BackendConfiguration) (*RequiredMemory, error)
+	GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (*RequiredMemory, error)
 }
diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index f53b26a..272fa25 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -15,8 +15,10 @@ import (
 	"runtime"
 	"strings"

+	v1 "github.com/google/go-containerregistry/pkg/v1"
 	parser "github.com/gpustack/gguf-parser-go"

+	"github.com/docker/model-distribution/types"
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/config"
@@ -223,23 +225,23 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
 	return size, nil
 }

-func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-	mdl, err := l.modelManager.GetModel(model)
+func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+	var mdlGguf *parser.GGUFFile
+	var mdlConfig types.Config
+	inStore, err := l.modelManager.IsModelInStore(model)
 	if err != nil {
-		return nil, fmt.Errorf("getting model(%s): %w", model, err)
+		return nil, fmt.Errorf("checking if model is in local store: %w", err)
 	}
-	mdlPath, err := mdl.GGUFPath()
-	if err != nil {
-		return nil, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
-	}
-	mdlGguf, err := parser.ParseGGUFFile(mdlPath)
-	if err != nil {
-		l.log.Warnf("Failed to parse gguf(%s): %s", mdlPath, err)
-		return nil, inference.ErrGGUFParse
-	}
-	mdlConfig, err := mdl.Config()
-	if err != nil {
-		return nil, fmt.Errorf("accessing model(%s) config: %w", model, err)
+	if inStore {
+		mdlGguf, mdlConfig, err = l.parseLocalModel(model)
+		if err != nil {
+			return nil, fmt.Errorf("parsing local gguf: %w", err)
+		}
+	} else {
+		mdlGguf, mdlConfig, err = l.parseRemoteModel(ctx, model)
+		if err != nil {
+			return nil, fmt.Errorf("parsing remote model: %w", err)
+		}
 	}

 	contextSize := GetContextSize(&mdlConfig, config)
@@ -277,6 +279,71 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
 	}, nil
 }

+func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
+	mdl, err := l.modelManager.GetModel(model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
+	}
+	mdlPath, err := mdl.GGUFPath()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFile(mdlPath)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", mdlPath, err)
+	}
+	mdlConfig, err := mdl.Config()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("accessing model(%s) config: %w", model, err)
+	}
+	return mdlGguf, mdlConfig, nil
+}
+
+func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
+	mdl, err := l.modelManager.GetRemoteModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting remote model(%s): %w", model, err)
+	}
+	layers, err := mdl.Layers()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting layers of model(%s): %w", model, err)
+	}
+	var ggufDigest v1.Hash
+	for _, layer := range layers {
+		mt, err := layer.MediaType()
+		if err != nil {
+			return nil, types.Config{}, fmt.Errorf("getting media type of model(%s) layer: %w", model, err)
+		}
+		if mt == types.MediaTypeGGUF {
+			ggufDigest, err = layer.Digest()
+			if err != nil {
+				return nil, types.Config{}, fmt.Errorf("getting digest of GGUF layer for model(%s): %w", model, err)
+			}
+			break
+		}
+	}
+	if ggufDigest.String() == "" {
+		return nil, types.Config{}, fmt.Errorf("model(%s) has no GGUF layer", model)
+	}
+	blobURL, err := l.modelManager.GetRemoteModelBlobURL(model, ggufDigest)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting GGUF blob URL for model(%s): %w", model, err)
+	}
+	tok, err := l.modelManager.BearerTokenForModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting bearer token for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(tok))
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing GGUF for model(%s): %w", model, err)
+	}
+	config, err := mdl.Config()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting config for model(%s): %w", model, err)
+	}
+	return mdlGguf, config, nil
+}
+
 func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool {
 	binPath := l.vendoredServerStoragePath
 	if l.updatedLlamaCpp {
diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go
index 2bae367..267a9c8 100644
--- a/pkg/inference/backends/mlx/mlx.go
+++ b/pkg/inference/backends/mlx/mlx.go
@@ -63,6 +63,6 @@ func (m *mlx) GetDiskUsage() (int64, error) {
 	return 0, nil
 }

-func (m *mlx) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }
diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go
index 86334d4..8bafb8e 100644
--- a/pkg/inference/backends/vllm/vllm.go
+++ b/pkg/inference/backends/vllm/vllm.go
@@ -63,6 +63,6 @@ func (v *vLLM) GetDiskUsage() (int64, error) {
 	return 0, nil
 }

-func (v *vLLM) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }
diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go
index b0cf68b..894cb81 100644
--- a/pkg/inference/models/manager.go
+++ b/pkg/inference/models/manager.go
@@ -18,6 +18,7 @@ import (
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/logging"
+	v1 "github.com/google/go-containerregistry/pkg/v1"
 	"github.com/sirupsen/logrus"
 )

@@ -562,6 +563,11 @@ func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	m.router.ServeHTTP(w, r)
 }

+// IsModelInStore checks if a given model is in the local store.
+func (m *Manager) IsModelInStore(ref string) (bool, error) {
+	return m.distributionClient.IsModelInStore(ref)
+}
+
 // GetModel returns a single model.
 func (m *Manager) GetModel(ref string) (types.Model, error) {
 	model, err := m.distributionClient.GetModel(ref)
@@ -571,6 +577,33 @@ func (m *Manager) GetModel(ref string) (types.Model, error) {
 	return model, err
 }

+// GetRemoteModel returns a single remote model.
+func (m *Manager) GetRemoteModel(ctx context.Context, ref string) (types.ModelArtifact, error) {
+	model, err := m.registryClient.Model(ctx, ref)
+	if err != nil {
+		return nil, fmt.Errorf("error while getting remote model: %w", err)
+	}
+	return model, nil
+}
+
+// GetRemoteModelBlobURL returns the URL of a given model blob.
+func (m *Manager) GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error) {
+	blobURL, err := m.registryClient.BlobURL(ref, digest)
+	if err != nil {
+		return "", fmt.Errorf("error while getting remote model blob URL: %w", err)
+	}
+	return blobURL, nil
+}
+
+// BearerTokenForModel returns the bearer token needed to pull a given model.
+func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string, error) {
+	tok, err := m.registryClient.BearerToken(ctx, ref)
+	if err != nil {
+		return "", fmt.Errorf("error while getting bearer token for model: %w", err)
+	}
+	return tok, nil
+}
+
 // GetModelPath returns the path to a model's files.
 func (m *Manager) GetModelPath(ref string) (string, error) {
 	model, err := m.GetModel(ref)
diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go
index ebbdd33..d5a82b1 100644
--- a/pkg/inference/scheduling/loader.go
+++ b/pkg/inference/scheduling/loader.go
@@ -420,7 +420,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 	if rc, ok := l.runnerConfigs[runnerKey{backendName, modelID, mode}]; ok {
 		runnerConfig = &rc
 	}
-	memory, err := backend.GetRequiredMemoryForModel(modelID, runnerConfig)
+	memory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
 	if errors.Is(err, inference.ErrGGUFParse) {
 		// TODO(p1-0tr): For now override memory checks in case model can't be parsed
 		// e.g. model is too new for gguf-parser-go to know. We should provide a cleaner
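
Usage sketch (illustrative only, not part of the patch): with this change Backend.GetRequiredMemoryForModel takes a context, which the llama.cpp backend uses to reach the registry when a model has not been pulled into the local store. The package name and the helper estimateModelMemory below are assumptions made for illustration; only the Backend interface and the inference types come from this change.

package example

import (
	"context"
	"fmt"

	"github.com/docker/model-runner/pkg/inference"
)

// estimateModelMemory asks a backend how much working memory a model needs.
// Because the call now takes a context, the estimate also works for models
// that are not in the local store: the llama.cpp backend parses the GGUF
// metadata directly from the registry blob. Passing a nil configuration
// mirrors the scheduler path in loader.go, which passes a possibly-nil
// runnerConfig when no per-runner configuration has been registered.
func estimateModelMemory(ctx context.Context, backend inference.Backend, model string) (*inference.RequiredMemory, error) {
	mem, err := backend.GetRequiredMemoryForModel(ctx, model, nil)
	if err != nil {
		return nil, fmt.Errorf("estimating memory for %s: %w", model, err)
	}
	return mem, nil
}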