inference: Support memory estimation for remote models

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Piotr Stankiewicz, 2025-07-30 13:11:57 +02:00
parent 59da65a365
commit 880818f741
6 changed files with 119 additions and 19 deletions


@@ -88,5 +88,5 @@ type Backend interface {
 	GetDiskUsage() (int64, error)
 	// GetRequiredMemoryForModel returns the required working memory for a given
 	// model.
-	GetRequiredMemoryForModel(model string, config *BackendConfiguration) (*RequiredMemory, error)
+	GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (*RequiredMemory, error)
 }

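The context parameter exists because estimation may now perform registry I/O for models that have not been pulled. A minimal sketch of a hypothetical caller that bounds the call with a timeout (the package, helper name, and 30-second budget are illustrative, not part of the commit):

package sketch

import (
	"context"
	"time"

	"github.com/docker/model-runner/pkg/inference"
)

// estimateWithTimeout bounds the now potentially network-bound estimation.
// Passing a nil BackendConfiguration mirrors the loader change at the
// bottom of this commit, where no per-runner config may be registered.
func estimateWithTimeout(b inference.Backend, model string) (*inference.RequiredMemory, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	return b.GetRequiredMemoryForModel(ctx, model, nil)
}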

@@ -15,8 +15,10 @@ import (
 	"runtime"
 	"strings"
 
+	v1 "github.com/google/go-containerregistry/pkg/v1"
 	parser "github.com/gpustack/gguf-parser-go"
 
+	"github.com/docker/model-distribution/types"
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/config"
@@ -223,23 +225,23 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
 	return size, nil
 }
 
-func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-	mdl, err := l.modelManager.GetModel(model)
+func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+	var mdlGguf *parser.GGUFFile
+	var mdlConfig types.Config
+	inStore, err := l.modelManager.IsModelInStore(model)
 	if err != nil {
-		return nil, fmt.Errorf("getting model(%s): %w", model, err)
+		return nil, fmt.Errorf("checking if model is in local store: %w", err)
 	}
-	mdlPath, err := mdl.GGUFPath()
-	if err != nil {
-		return nil, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
-	}
-	mdlGguf, err := parser.ParseGGUFFile(mdlPath)
-	if err != nil {
-		l.log.Warnf("Failed to parse gguf(%s): %s", mdlPath, err)
-		return nil, inference.ErrGGUFParse
-	}
-	mdlConfig, err := mdl.Config()
-	if err != nil {
-		return nil, fmt.Errorf("accessing model(%s) config: %w", model, err)
+	if inStore {
+		mdlGguf, mdlConfig, err = l.parseLocalModel(model)
+		if err != nil {
+			return nil, fmt.Errorf("parsing local gguf: %w", err)
+		}
+	} else {
+		mdlGguf, mdlConfig, err = l.parseRemoteModel(ctx, model)
+		if err != nil {
+			return nil, fmt.Errorf("parsing remote model: %w", err)
+		}
 	}
 
 	contextSize := GetContextSize(&mdlConfig, config)
@@ -277,6 +279,71 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
 	}, nil
 }
 
+func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
+	mdl, err := l.modelManager.GetModel(model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
+	}
+	mdlPath, err := mdl.GGUFPath()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFile(mdlPath)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", mdlPath, err)
+	}
+	mdlConfig, err := mdl.Config()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("accessing model(%s) config: %w", model, err)
+	}
+	return mdlGguf, mdlConfig, nil
+}
+
+func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
+	mdl, err := l.modelManager.GetRemoteModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting remote model(%s): %w", model, err)
+	}
+	layers, err := mdl.Layers()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting layers of model(%s): %w", model, err)
+	}
+	var ggufDigest v1.Hash
+	for _, layer := range layers {
+		mt, err := layer.MediaType()
+		if err != nil {
+			return nil, types.Config{}, fmt.Errorf("getting media type of model(%s) layer: %w", model, err)
+		}
+		if mt == types.MediaTypeGGUF {
+			ggufDigest, err = layer.Digest()
+			if err != nil {
+				return nil, types.Config{}, fmt.Errorf("getting digest of GGUF layer for model(%s): %w", model, err)
+			}
+			break
+		}
+	}
+	if ggufDigest.String() == "" {
+		return nil, types.Config{}, fmt.Errorf("model(%s) has no GGUF layer", model)
+	}
+	blobURL, err := l.modelManager.GetRemoteModelBlobURL(model, ggufDigest)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting GGUF blob URL for model(%s): %w", model, err)
+	}
+	tok, err := l.modelManager.BearerTokenForModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting bearer token for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(tok))
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing GGUF for model(%s): %w", model, err)
+	}
+	config, err := mdl.Config()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting config for model(%s): %w", model, err)
+	}
+	return mdlGguf, config, nil
+}
+
 func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool {
 	binPath := l.vendoredServerStoragePath
 	if l.updatedLlamaCpp {

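The remote branch above leans on gguf-parser-go's remote mode, which parses the GGUF header straight from the registry blob URL, so the estimate does not require pulling the model first. A condensed, self-contained sketch of that call (the package and wrapper name are illustrative; in the commit, the URL and token come from the Manager helpers shown further down):

package sketch

import (
	"context"
	"fmt"

	parser "github.com/gpustack/gguf-parser-go"
)

// fetchRemoteGGUF parses a GGUF header directly from a registry blob URL,
// authenticated with a bearer token, without downloading the model into
// the local store.
func fetchRemoteGGUF(ctx context.Context, blobURL, token string) (*parser.GGUFFile, error) {
	gf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(token))
	if err != nil {
		return nil, fmt.Errorf("parsing remote GGUF: %w", err)
	}
	return gf, nil
}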

@@ -63,6 +63,6 @@ func (m *mlx) GetDiskUsage() (int64, error) {
 	return 0, nil
 }
 
-func (m *mlx) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }


@@ -63,6 +63,6 @@ func (v *vLLM) GetDiskUsage() (int64, error) {
 	return 0, nil
 }
 
-func (v *vLLM) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }


@@ -18,6 +18,7 @@ import (
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/logging"
 
+	v1 "github.com/google/go-containerregistry/pkg/v1"
 	"github.com/sirupsen/logrus"
 )
@@ -562,6 +563,11 @@ func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	m.router.ServeHTTP(w, r)
 }
 
+// IsModelInStore checks if a given model is in the local store.
+func (m *Manager) IsModelInStore(ref string) (bool, error) {
+	return m.distributionClient.IsModelInStore(ref)
+}
+
 // GetModel returns a single model.
 func (m *Manager) GetModel(ref string) (types.Model, error) {
 	model, err := m.distributionClient.GetModel(ref)
@@ -571,6 +577,33 @@ func (m *Manager) GetModel(ref string) (types.Model, error) {
 	return model, err
 }
 
+// GetRemoteModel returns a single remote model.
+func (m *Manager) GetRemoteModel(ctx context.Context, ref string) (types.ModelArtifact, error) {
+	model, err := m.registryClient.Model(ctx, ref)
+	if err != nil {
+		return nil, fmt.Errorf("error while getting remote model: %w", err)
+	}
+	return model, nil
+}
+
+// GetRemoteModelBlobURL returns the URL of a given model blob.
+func (m *Manager) GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error) {
+	blobURL, err := m.registryClient.BlobURL(ref, digest)
+	if err != nil {
+		return "", fmt.Errorf("error while getting remote model blob URL: %w", err)
+	}
+	return blobURL, nil
+}
+
+// BearerTokenForModel returns the bearer token needed to pull a given model.
+func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string, error) {
+	tok, err := m.registryClient.BearerToken(ctx, ref)
+	if err != nil {
+		return "", fmt.Errorf("error while getting bearer token for model: %w", err)
+	}
+	return tok, nil
+}
+
 // GetModelPath returns the path to a model's files.
 func (m *Manager) GetModelPath(ref string) (string, error) {
 	model, err := m.GetModel(ref)

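Taken together, IsModelInStore, GetRemoteModel, GetRemoteModelBlobURL, and BearerTokenForModel give a backend everything a remote header parse needs. A hypothetical consumer-side sketch (the package, local interface, and helper are illustrative; the real composition is llamaCpp.parseRemoteModel above):

package sketch

import (
	"context"
	"fmt"

	v1 "github.com/google/go-containerregistry/pkg/v1"
)

// remoteBlobSource is a local interface matching the new Manager methods,
// so a consumer need not depend on the concrete Manager type.
type remoteBlobSource interface {
	GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error)
	BearerTokenForModel(ctx context.Context, ref string) (string, error)
}

// remoteParseInputs gathers what parser.ParseGGUFFileRemote needs for a
// model that is not in the local store.
func remoteParseInputs(ctx context.Context, src remoteBlobSource, ref string, ggufDigest v1.Hash) (blobURL, token string, err error) {
	if blobURL, err = src.GetRemoteModelBlobURL(ref, ggufDigest); err != nil {
		return "", "", fmt.Errorf("getting blob URL: %w", err)
	}
	if token, err = src.BearerTokenForModel(ctx, ref); err != nil {
		return "", "", fmt.Errorf("getting bearer token: %w", err)
	}
	return blobURL, token, nil
}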

@@ -420,7 +420,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 	if rc, ok := l.runnerConfigs[runnerKey{backendName, modelID, mode}]; ok {
 		runnerConfig = &rc
 	}
-	memory, err := backend.GetRequiredMemoryForModel(modelID, runnerConfig)
+	memory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
 	if errors.Is(err, inference.ErrGGUFParse) {
 		// TODO(p1-0tr): For now override memory checks in case model can't be parsed
 		// e.g. model is too new for gguf-parser-go to know. We should provide a cleaner