From 156686cc6ff83d8e4247caf4a0d88c1a16c9a25b Mon Sep 17 00:00:00 2001
From: Emily Casey
Date: Thu, 21 Aug 2025 22:10:17 -0600
Subject: [PATCH] Run from bundle

Signed-off-by: Emily Casey
---
 go.mod                                        |  2 +-
 go.sum                                        |  4 +-
 pkg/inference/backends/llamacpp/llamacpp.go   | 25 ++------
 .../backends/llamacpp/llamacpp_config.go      | 25 +++++------
 .../backends/llamacpp/llamacpp_config_test.go | 42 ++++++++-----------
 pkg/inference/config/config.go                |  3 +-
 pkg/inference/models/manager.go               | 17 ++++----
 7 files changed, 49 insertions(+), 69 deletions(-)

diff --git a/go.mod b/go.mod
index 09aeeb0..d2bb0ed 100644
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.23.7
 require (
 	github.com/containerd/containerd/v2 v2.0.4
 	github.com/containerd/platforms v1.0.0-rc.1
-	github.com/docker/model-distribution v0.0.0-20250724114133-a11d745e582c
+	github.com/docker/model-distribution v0.0.0-20250822151640-fca29728e7be
 	github.com/elastic/go-sysinfo v1.15.3
 	github.com/google/go-containerregistry v0.20.3
 	github.com/gpustack/gguf-parser-go v0.14.1
diff --git a/go.sum b/go.sum
index 68f041d..c1ffe2c 100644
--- a/go.sum
+++ b/go.sum
@@ -38,8 +38,8 @@ github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBi
 github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
 github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo=
 github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M=
-github.com/docker/model-distribution v0.0.0-20250724114133-a11d745e582c h1:w9MekYamXmWLe9ZWXWgNXJ7BLDDemXwB8WcF7wzHF5Q=
-github.com/docker/model-distribution v0.0.0-20250724114133-a11d745e582c/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
+github.com/docker/model-distribution v0.0.0-20250822151640-fca29728e7be h1:S6v82p2JPC4HwaZcnM5TmOjMQktIqu7HCvJMbkDIS+U=
+github.com/docker/model-distribution v0.0.0-20250822151640-fca29728e7be/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
 github.com/elastic/go-sysinfo v1.15.3 h1:W+RnmhKFkqPTCRoFq2VCTmsT4p/fwpo+3gKNQsn1XU0=
 github.com/elastic/go-sysinfo v1.15.3/go.mod h1:K/cNrqYTDrSoMh2oDkYEMS2+a72GRxMvNP+GC+vRIlo=
 github.com/elastic/go-windows v1.0.2 h1:yoLLsAsV5cfg9FLhZ9EXZ2n2sQFKeDYrHenkcivY4vI=
diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index f53b26a..15c6625 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -130,7 +130,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 
 // Run implements inference.Backend.Run.
 func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
-	mdl, err := l.modelManager.GetModel(model)
+	bundle, err := l.modelManager.GetBundle(model)
 	if err != nil {
 		return fmt.Errorf("failed to get model: %w", err)
 	}
@@ -145,7 +145,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 		binPath = l.updatedServerStoragePath
 	}
 
-	args, err := l.config.GetArgs(mdl, socket, mode, config)
+	args, err := l.config.GetArgs(bundle, socket, mode, config)
 	if err != nil {
 		return fmt.Errorf("failed to get args for llama.cpp: %w", err)
 	}
@@ -224,29 +224,22 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
 }
 
 func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-	mdl, err := l.modelManager.GetModel(model)
+	bundle, err := l.modelManager.GetBundle(model)
 	if err != nil {
 		return nil, fmt.Errorf("getting model(%s): %w", model, err)
 	}
-	mdlPath, err := mdl.GGUFPath()
+
+	mdlGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
 	if err != nil {
-		return nil, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
-	}
-	mdlGguf, err := parser.ParseGGUFFile(mdlPath)
-	if err != nil {
-		l.log.Warnf("Failed to parse gguf(%s): %s", mdlPath, err)
+		l.log.Warnf("Failed to parse gguf(%s): %s", bundle.GGUFPath(), err)
 		return nil, inference.ErrGGUFParse
 	}
-
-	mdlConfig, err := mdl.Config()
-	if err != nil {
-		return nil, fmt.Errorf("accessing model(%s) config: %w", model, err)
-	}
-	contextSize := GetContextSize(&mdlConfig, config)
+
+	contextSize := GetContextSize(bundle.RuntimeConfig(), config)
 
 	ngl := uint64(0)
 	if l.gpuSupported {
-		if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && mdlConfig.Quantization != "Q4_0" {
+		if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && strings.TrimSpace(mdlGGUF.Metadata().FileType.String()) != "Q4_0" {
 			ngl = 0 // only Q4_0 models can be accelerated on Adreno
 		}
 		ngl = 100
@@ -255,7 +248,7 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
 	// TODO(p1-0tr): for now assume we are running on GPU (single one) - Devices[1];
 	// sum up weights + kv cache + context for an estimate of total GPU memory needed
 	// while running inference with the given model
-	estimate := mdlGguf.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
+	estimate := mdlGGUF.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
 		// TODO(p1-0tr): add logic for resolving other param values, instead of hardcoding them
 		parser.WithLLaMACppLogicalBatchSize(2048),
 		parser.WithLLaMACppOffloadLayers(ngl))
diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go
index becc3a1..d35858c 100644
--- a/pkg/inference/backends/llamacpp/llamacpp_config.go
+++ b/pkg/inference/backends/llamacpp/llamacpp_config.go
@@ -6,6 +6,7 @@ import (
 	"strconv"
 
 	"github.com/docker/model-distribution/types"
+
 	"github.com/docker/model-runner/pkg/inference"
 )
 
@@ -35,18 +36,13 @@ func NewDefaultLlamaCppConfig() *Config {
 }
 
 // GetArgs implements BackendConfig.GetArgs.
-func (c *Config) GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
+func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
 	// Start with the arguments from LlamaCppConfig
 	args := append([]string{}, c.Args...)
 
-	modelPath, err := model.GGUFPath()
-	if err != nil {
-		return nil, fmt.Errorf("get gguf path: %w", err)
-	}
-
-	modelCfg, err := model.Config()
-	if err != nil {
-		return nil, fmt.Errorf("get model config: %w", err)
+	modelPath := bundle.GGUFPath()
+	if modelPath == "" {
+		return nil, fmt.Errorf("GGUF file required by llama.cpp backend")
 	}
 
 	// Add model and socket arguments
@@ -57,7 +53,8 @@ func (c *Config) GetArgs(model types.Model, socket string, mode inference.Backen
 		args = append(args, "--embeddings")
 	}
 
-	args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(&modelCfg, config), 10))
+	// Add context size from model config or backend config
+	args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(bundle.RuntimeConfig(), config), 10))
 
 	// Add arguments from backend config
 	if config != nil {
@@ -65,17 +62,17 @@ func (c *Config) GetArgs(model types.Model, socket string, mode inference.Backen
 	}
 
 	// Add arguments for Multimodal projector
-	path, err := model.MMPROJPath()
-	if path != "" && err == nil {
+	path := bundle.MMPROJPath()
+	if path != "" {
 		args = append(args, "--mmproj", path)
 	}
 
 	return args, nil
 }
 
-func GetContextSize(modelCfg *types.Config, backendCfg *inference.BackendConfiguration) uint64 {
+func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) uint64 {
 	// Model config takes precedence
-	if modelCfg != nil && modelCfg.ContextSize != nil {
+	if modelCfg.ContextSize != nil {
 		return *modelCfg.ContextSize
 	}
 	// else use backend config
diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go
index f013009..f6e029a 100644
--- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go
+++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go
@@ -1,12 +1,12 @@
 package llamacpp
 
 import (
-	"errors"
 	"runtime"
 	"strconv"
 	"testing"
 
 	"github.com/docker/model-distribution/types"
+
 	"github.com/docker/model-runner/pkg/inference"
 )
 
@@ -74,7 +74,7 @@ func TestGetArgs(t *testing.T) {
 
 	tests := []struct {
 		name     string
-		model    types.Model
+		bundle   types.ModelBundle
 		mode     inference.BackendMode
 		config   *inference.BackendConfiguration
 		expected []string
@@ -82,7 +82,7 @@ func TestGetArgs(t *testing.T) {
 		{
 			name: "completion mode",
 			mode: inference.BackendModeCompletion,
-			model: &fakeModel{
+			bundle: &fakeBundle{
 				ggufPath: modelPath,
 			},
 			expected: []string{
@@ -97,7 +97,7 @@ func TestGetArgs(t *testing.T) {
 		{
 			name: "embedding mode",
 			mode: inference.BackendModeEmbedding,
-			model: &fakeModel{
+			bundle: &fakeBundle{
 				ggufPath: modelPath,
 			},
 			expected: []string{
@@ -113,7 +113,7 @@ func TestGetArgs(t *testing.T) {
 		{
 			name: "context size from backend config",
 			mode: inference.BackendModeEmbedding,
-			model: &fakeModel{
+			bundle: &fakeBundle{
 				ggufPath: modelPath,
 			},
 			config: &inference.BackendConfiguration{
@@ -132,7 +132,7 @@ func TestGetArgs(t *testing.T) {
 		{
 			name: "context size from model config",
 			mode: inference.BackendModeEmbedding,
-			model: &fakeModel{
+			bundle: &fakeBundle{
 				ggufPath: modelPath,
 				config: types.Config{
					ContextSize: uint64ptr(2096),
@@ -154,7 +154,7 @@ func TestGetArgs(t *testing.T) {
 		{
 			name: "raw flags from backend config",
 			mode: inference.BackendModeEmbedding,
-			model: &fakeModel{
+			bundle: &fakeBundle{
 				ggufPath: modelPath,
 			},
 			config: &inference.BackendConfiguration{
@@ -175,7 +175,7 @@ func TestGetArgs(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			args, err := config.GetArgs(tt.model, socket, tt.mode, tt.config)
+			args, err := config.GetArgs(tt.bundle, socket, tt.mode, tt.config)
 			if err != nil {
 				t.Errorf("GetArgs() error = %v", err)
 			}
@@ -248,35 +248,27 @@ func TestContainsArg(t *testing.T) {
 	}
 }
 
-var _ types.Model = &fakeModel{}
+var _ types.ModelBundle = &fakeBundle{}
 
-type fakeModel struct {
+type fakeBundle struct {
 	ggufPath string
 	config   types.Config
 }
 
-func (f *fakeModel) MMPROJPath() (string, error) {
-	return "", errors.New("not found")
-}
-
-func (f *fakeModel) ID() (string, error) {
+func (f *fakeBundle) RootDir() string {
 	panic("shouldn't be called")
 }
 
-func (f *fakeModel) GGUFPath() (string, error) {
-	return f.ggufPath, nil
+func (f *fakeBundle) GGUFPath() string {
+	return f.ggufPath
 }
 
-func (f *fakeModel) Config() (types.Config, error) {
-	return f.config, nil
+func (f *fakeBundle) MMPROJPath() string {
+	return ""
 }
 
-func (f *fakeModel) Tags() []string {
-	panic("shouldn't be called")
-}
-
-func (f fakeModel) Descriptor() (types.Descriptor, error) {
-	panic("shouldn't be called")
+func (f *fakeBundle) RuntimeConfig() types.Config {
+	return f.config
 }
 
 func uint64ptr(n uint64) *uint64 {
diff --git a/pkg/inference/config/config.go b/pkg/inference/config/config.go
index 8163759..72a22be 100644
--- a/pkg/inference/config/config.go
+++ b/pkg/inference/config/config.go
@@ -2,6 +2,7 @@ package config
 
 import (
 	"github.com/docker/model-distribution/types"
+
 	"github.com/docker/model-runner/pkg/inference"
 )
 
@@ -12,5 +13,5 @@ type BackendConfig interface {
 	// GetArgs returns the command-line arguments for the backend.
 	// It takes the model path, socket, and mode as input and returns
 	// the appropriate arguments for the backend.
-	GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
+	GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
 }
diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go
index b0cf68b..9ea8011 100644
--- a/pkg/inference/models/manager.go
+++ b/pkg/inference/models/manager.go
@@ -15,10 +15,11 @@ import (
 	"github.com/docker/model-distribution/distribution"
 	"github.com/docker/model-distribution/registry"
 	"github.com/docker/model-distribution/types"
+	"github.com/sirupsen/logrus"
+
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/logging"
-	"github.com/sirupsen/logrus"
 )
 
 const (
@@ -571,17 +572,13 @@ func (m *Manager) GetModel(ref string) (types.Model, error) {
 	return model, err
 }
 
-// GetModelPath returns the path to a model's files.
-func (m *Manager) GetModelPath(ref string) (string, error) {
-	model, err := m.GetModel(ref)
+// GetBundle returns the model bundle for the given reference.
+func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {
+	bundle, err := m.distributionClient.GetBundle(ref)
 	if err != nil {
-		return "", err
+		return nil, fmt.Errorf("error while getting model bundle: %w", err)
 	}
-	path, err := model.GGUFPath()
-	if err != nil {
-		return "", fmt.Errorf("error while getting model path: %w", err)
-	}
-	return path, nil
+	return bundle, nil
 }
 
 // PullModel pulls a model to local storage. Any error it returns is suitable
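-- 
Reviewer note, not part of the patch: the change moves the backends from
types.Model (error-returning lookups such as GGUFPath() (string, error) and
Config() (types.Config, error)) to the accessor-style types.ModelBundle,
where a missing file surfaces as an empty string. The sketch below only
reconstructs that surface from the fakeBundle test double above; the
authoritative definition lives in github.com/docker/model-distribution/types,
so treat these signatures as assumptions to verify, not the canonical API.

// Sketch: the bundle surface this patch codes against, and the consuming
// pattern now used in llamacpp_config.go.
package sketch

import "fmt"

// Config mirrors the single field of types.Config that the patch reads.
type Config struct {
	ContextSize *uint64 // nil means no model-level context override
}

// ModelBundle mirrors the methods the patch calls on a bundle. Absence of
// an optional file (e.g. the multimodal projector) is "" rather than an error.
type ModelBundle interface {
	RootDir() string
	GGUFPath() string
	MMPROJPath() string
	RuntimeConfig() Config
}

// buildArgs validates the required GGUF up front, then branches on empty
// strings for the optional pieces, as GetArgs does after this change.
func buildArgs(bundle ModelBundle) ([]string, error) {
	modelPath := bundle.GGUFPath()
	if modelPath == "" {
		return nil, fmt.Errorf("GGUF file required by llama.cpp backend")
	}
	args := []string{"--model", modelPath}
	if path := bundle.MMPROJPath(); path != "" {
		args = append(args, "--mmproj", path)
	}
	if cfg := bundle.RuntimeConfig(); cfg.ContextSize != nil {
		args = append(args, "--ctx-size", fmt.Sprint(*cfg.ContextSize))
	}
	return args, nil
}

The accessor style trades per-call error handling for a single validity
check when the bundle is resolved (Manager.GetBundle), which is why GetArgs
and GetRequiredMemoryForModel above shed most of their error plumbing.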