Merge pull request #133 from docker/shards

Support GGUF shards
Emily Casey, 2025-08-22 11:37:38 -06:00 (committed by GitHub)
commit 5341c9fc29
7 changed files with 45 additions and 69 deletions
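In short: the runner stops treating a model as a single GGUF file behind the `types.Model` interface and instead resolves a `types.ModelBundle`, which can span multiple GGUF shard files. The accessors also get simpler: `GGUFPath()` and `MMPROJPath()` now return plain strings (empty when absent), and `RuntimeConfig()` returns `types.Config` by value. A minimal sketch of the new shape, assuming only the four `ModelBundle` methods visible in the diffs below (the stub type and path are hypothetical):

```go
package main

import (
	"fmt"

	"github.com/docker/model-distribution/types"
)

// stubBundle is a hypothetical in-memory bundle, mirroring the fakeBundle
// used in the tests further down; it implements the four methods this
// commit shows on types.ModelBundle.
type stubBundle struct {
	ggufPath string
	config   types.Config
}

func (s *stubBundle) RootDir() string             { return "" }
func (s *stubBundle) GGUFPath() string            { return s.ggufPath }
func (s *stubBundle) MMPROJPath() string          { return "" }
func (s *stubBundle) RuntimeConfig() types.Config { return s.config }

func main() {
	ctxSize := uint64(4096)
	var bundle types.ModelBundle = &stubBundle{
		// For a sharded model this would presumably name the first shard,
		// e.g. ...-00001-of-00003.gguf (hypothetical path).
		ggufPath: "/models/example-00001-of-00003.gguf",
		config:   types.Config{ContextSize: &ctxSize},
	}
	fmt.Println("gguf:", bundle.GGUFPath())                   // plain string, no error
	fmt.Println("ctx :", *bundle.RuntimeConfig().ContextSize) // config by value
}
```

Dropping the error returns from these accessors is what lets the call sites below shed most of their error-handling branches.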

go.mod (1 addition, 1 deletion)

@@ -5,7 +5,7 @@ go 1.23.7
 require (
     github.com/containerd/containerd/v2 v2.0.4
     github.com/containerd/platforms v1.0.0-rc.1
-    github.com/docker/model-distribution v0.0.0-20250822164750-dcd03ba922e7
+    github.com/docker/model-distribution v0.0.0-20250822172258-8fe9daa4a4da
     github.com/elastic/go-sysinfo v1.15.3
     github.com/google/go-containerregistry v0.20.3
     github.com/gpustack/gguf-parser-go v0.14.1

go.sum (2 additions, 2 deletions)

@@ -38,8 +38,8 @@ github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBi
 github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
 github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo=
 github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M=
-github.com/docker/model-distribution v0.0.0-20250822164750-dcd03ba922e7 h1:dOk1UTVMyDHNG4WFS8jnAtfKdPUE3QaMWNvrzRoK/dI=
-github.com/docker/model-distribution v0.0.0-20250822164750-dcd03ba922e7/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
+github.com/docker/model-distribution v0.0.0-20250822172258-8fe9daa4a4da h1:ml99WBfcLnsy1frXQR4X+5WAC0DoGtwZyGoU/xBsDQM=
+github.com/docker/model-distribution v0.0.0-20250822172258-8fe9daa4a4da/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
 github.com/elastic/go-sysinfo v1.15.3 h1:W+RnmhKFkqPTCRoFq2VCTmsT4p/fwpo+3gKNQsn1XU0=
 github.com/elastic/go-sysinfo v1.15.3/go.mod h1:K/cNrqYTDrSoMh2oDkYEMS2+a72GRxMvNP+GC+vRIlo=
 github.com/elastic/go-windows v1.0.2 h1:yoLLsAsV5cfg9FLhZ9EXZ2n2sQFKeDYrHenkcivY4vI=
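This go.mod/go.sum bump pins the model-distribution revision that, judging by the code below, introduces the GetBundle client API and the types.ModelBundle interface. The hunks that follow (the llama.cpp backend, by its imports and receivers) switch over to it.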


@@ -15,11 +15,10 @@ import (
     "runtime"
     "strings"

+    "github.com/docker/model-distribution/types"
     v1 "github.com/google/go-containerregistry/pkg/v1"
     parser "github.com/gpustack/gguf-parser-go"
-
-    "github.com/docker/model-distribution/types"
     "github.com/docker/model-runner/pkg/diskusage"
     "github.com/docker/model-runner/pkg/inference"
     "github.com/docker/model-runner/pkg/inference/config"
@@ -133,7 +132,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {

 // Run implements inference.Backend.Run.
 func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
-    mdl, err := l.modelManager.GetModel(model)
+    bundle, err := l.modelManager.GetBundle(model)
     if err != nil {
         return fmt.Errorf("failed to get model: %w", err)
     }
@@ -148,7 +147,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
         binPath = l.updatedServerStoragePath
     }

-    args, err := l.config.GetArgs(mdl, socket, mode, config)
+    args, err := l.config.GetArgs(bundle, socket, mode, config)
     if err != nil {
         return fmt.Errorf("failed to get args for llama.cpp: %w", err)
     }
@@ -245,7 +244,7 @@ func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string,
         }
     }

-    contextSize := GetContextSize(&mdlConfig, config)
+    contextSize := GetContextSize(mdlConfig, config)

     ngl := uint64(0)
     if l.gpuSupported {
@@ -281,23 +280,15 @@ func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string,
 }

 func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
-    mdl, err := l.modelManager.GetModel(model)
+    bundle, err := l.modelManager.GetBundle(model)
     if err != nil {
         return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
     }
-    mdlPath, err := mdl.GGUFPath()
+    modelGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
     if err != nil {
-        return nil, types.Config{}, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
+        return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", bundle.GGUFPath(), err)
     }
-    mdlGguf, err := parser.ParseGGUFFile(mdlPath)
-    if err != nil {
-        return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", mdlPath, err)
-    }
-    mdlConfig, err := mdl.Config()
-    if err != nil {
-        return nil, types.Config{}, fmt.Errorf("accessing model(%s) config: %w", model, err)
-    }
-    return mdlGguf, mdlConfig, nil
+    return modelGGUF, bundle.RuntimeConfig(), nil
 }

 func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
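With a bundle in hand, local model parsing collapses from four fallible steps to two: resolve the bundle, then hand its GGUF path straight to the parser. The same flow, sketched outside the backend behind a narrowed interface (`bundleGetter` and `loadLocalGGUF` are hypothetical names; error wording follows the diff):

```go
package example

import (
	"fmt"

	"github.com/docker/model-distribution/types"
	parser "github.com/gpustack/gguf-parser-go"
)

// bundleGetter is a hypothetical, narrowed view of the manager's new
// GetBundle method (the concrete Manager change appears in the last file).
type bundleGetter interface {
	GetBundle(ref string) (types.ModelBundle, error)
}

// loadLocalGGUF mirrors the two-step flow parseLocalModel now performs:
// resolve the bundle, then parse the GGUF it points at. For sharded
// models the path presumably names the first shard.
func loadLocalGGUF(m bundleGetter, ref string) (*parser.GGUFFile, types.Config, error) {
	bundle, err := m.GetBundle(ref)
	if err != nil {
		return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", ref, err)
	}
	gguf, err := parser.ParseGGUFFile(bundle.GGUFPath())
	if err != nil {
		return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", bundle.GGUFPath(), err)
	}
	// RuntimeConfig is infallible now; no (Config, error) pair to unpack.
	return gguf, bundle.RuntimeConfig(), nil
}
```

The hunks that follow rework the argument builder (config.go of the same package, by its content).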


@@ -6,6 +6,7 @@ import (
     "strconv"
+
     "github.com/docker/model-distribution/types"
     "github.com/docker/model-runner/pkg/inference"
 )
@@ -35,18 +36,13 @@ func NewDefaultLlamaCppConfig() *Config {
 }

 // GetArgs implements BackendConfig.GetArgs.
-func (c *Config) GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
+func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
     // Start with the arguments from LlamaCppConfig
     args := append([]string{}, c.Args...)

-    modelPath, err := model.GGUFPath()
-    if err != nil {
-        return nil, fmt.Errorf("get gguf path: %w", err)
-    }
-    modelCfg, err := model.Config()
-    if err != nil {
-        return nil, fmt.Errorf("get model config: %w", err)
+    modelPath := bundle.GGUFPath()
+    if modelPath == "" {
+        return nil, fmt.Errorf("GGUF file required by llama.cpp backend")
     }

     // Add model and socket arguments
@@ -57,7 +53,8 @@ func (c *Config) GetArgs(model types.Model, socket string, mode inference.Backen
         args = append(args, "--embeddings")
     }

-    args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(&modelCfg, config), 10))
+    // Add context size from model config or backend config
+    args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(bundle.RuntimeConfig(), config), 10))

     // Add arguments from backend config
     if config != nil {
@@ -65,17 +62,16 @@ func (c *Config) GetArgs(model types.Model, socket string, mode inference.Backen
     }

     // Add arguments for Multimodal projector
-    path, err := model.MMPROJPath()
-    if path != "" && err == nil {
+    if path := bundle.MMPROJPath(); path != "" {
         args = append(args, "--mmproj", path)
     }

     return args, nil
 }

-func GetContextSize(modelCfg *types.Config, backendCfg *inference.BackendConfiguration) uint64 {
+func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) uint64 {
     // Model config takes precedence
-    if modelCfg != nil && modelCfg.ContextSize != nil {
+    if modelCfg.ContextSize != nil {
         return *modelCfg.ContextSize
     }
     // else use backend config
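GetContextSize now takes `types.Config` by value, so the nil-pointer guard disappears while the precedence is unchanged: the model's config wins, else the backend config. A call-site sketch, assuming an import path inferred from the package name (not shown in this view):

```go
package main

import (
	"fmt"

	"github.com/docker/model-distribution/types"
	// Assumed import path, inferred from the package name in the diff.
	llamacpp "github.com/docker/model-runner/pkg/inference/backends/llamacpp"
)

func main() {
	ctxSize := uint64(8192)
	// Model config wins; a nil backend config is safe here because the
	// model's ContextSize is set and consulted first.
	fmt.Println(llamacpp.GetContextSize(types.Config{ContextSize: &ctxSize}, nil)) // 8192
}
```

The tests below follow suit, swapping fakeModel for a leaner fakeBundle.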


@@ -1,12 +1,12 @@
 package llamacpp

 import (
-    "errors"
     "runtime"
     "strconv"
     "testing"
+
     "github.com/docker/model-distribution/types"
     "github.com/docker/model-runner/pkg/inference"
 )
@@ -74,7 +74,7 @@ func TestGetArgs(t *testing.T) {
     tests := []struct {
         name     string
-        model    types.Model
+        bundle   types.ModelBundle
         mode     inference.BackendMode
         config   *inference.BackendConfiguration
         expected []string
@@ -82,7 +82,7 @@ func TestGetArgs(t *testing.T) {
         {
             name: "completion mode",
             mode: inference.BackendModeCompletion,
-            model: &fakeModel{
+            bundle: &fakeBundle{
                 ggufPath: modelPath,
             },
             expected: []string{
@@ -97,7 +97,7 @@ func TestGetArgs(t *testing.T) {
         {
             name: "embedding mode",
             mode: inference.BackendModeEmbedding,
-            model: &fakeModel{
+            bundle: &fakeBundle{
                 ggufPath: modelPath,
             },
             expected: []string{
@@ -113,7 +113,7 @@ func TestGetArgs(t *testing.T) {
         {
             name: "context size from backend config",
             mode: inference.BackendModeEmbedding,
-            model: &fakeModel{
+            bundle: &fakeBundle{
                 ggufPath: modelPath,
             },
             config: &inference.BackendConfiguration{
@@ -132,7 +132,7 @@ func TestGetArgs(t *testing.T) {
         {
             name: "context size from model config",
             mode: inference.BackendModeEmbedding,
-            model: &fakeModel{
+            bundle: &fakeBundle{
                 ggufPath: modelPath,
                 config: types.Config{
                     ContextSize: uint64ptr(2096),
@@ -154,7 +154,7 @@ func TestGetArgs(t *testing.T) {
         {
             name: "raw flags from backend config",
             mode: inference.BackendModeEmbedding,
-            model: &fakeModel{
+            bundle: &fakeBundle{
                 ggufPath: modelPath,
             },
             config: &inference.BackendConfiguration{
@@ -175,7 +175,7 @@ func TestGetArgs(t *testing.T) {
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            args, err := config.GetArgs(tt.model, socket, tt.mode, tt.config)
+            args, err := config.GetArgs(tt.bundle, socket, tt.mode, tt.config)
             if err != nil {
                 t.Errorf("GetArgs() error = %v", err)
             }
@@ -248,35 +248,27 @@ func TestContainsArg(t *testing.T) {
     }
 }

-var _ types.Model = &fakeModel{}
+var _ types.ModelBundle = &fakeBundle{}

-type fakeModel struct {
+type fakeBundle struct {
     ggufPath string
     config   types.Config
 }

-func (f *fakeModel) MMPROJPath() (string, error) {
-    return "", errors.New("not found")
-}
-
-func (f *fakeModel) ID() (string, error) {
+func (f *fakeBundle) RootDir() string {
     panic("shouldn't be called")
 }

-func (f *fakeModel) GGUFPath() (string, error) {
-    return f.ggufPath, nil
+func (f *fakeBundle) GGUFPath() string {
+    return f.ggufPath
 }

-func (f *fakeModel) Config() (types.Config, error) {
-    return f.config, nil
+func (f *fakeBundle) MMPROJPath() string {
+    return ""
 }

-func (f *fakeModel) Tags() []string {
-    panic("shouldn't be called")
-}
-
-func (f fakeModel) Descriptor() (types.Descriptor, error) {
-    panic("shouldn't be called")
+func (f *fakeBundle) RuntimeConfig() types.Config {
+    return f.config
 }

 func uint64ptr(n uint64) *uint64 {
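The fake shrinks because every bundle accessor is infallible; there is no error to stub anymore. If a test later needs the `--mmproj` branch of GetArgs, one more field would do it. A hypothetical extension, not part of this commit:

```go
// fakeBundleWithProjector embeds fakeBundle and overrides only MMPROJPath,
// enough to drive GetArgs down its "--mmproj" branch in a test.
type fakeBundleWithProjector struct {
	fakeBundle
	mmprojPath string
}

func (f *fakeBundleWithProjector) MMPROJPath() string { return f.mmprojPath }
```

The interface definition that forces these signatures comes next.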


@@ -2,6 +2,7 @@ package config

 import (
     "github.com/docker/model-distribution/types"
+
     "github.com/docker/model-runner/pkg/inference"
 )
@@ -12,5 +13,5 @@ type BackendConfig interface {
     // GetArgs returns the command-line arguments for the backend.
     // It takes the model path, socket, and mode as input and returns
     // the appropriate arguments for the backend.
-    GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
+    GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
 }
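Every backend config now receives the resolved bundle rather than a raw model handle. A hypothetical minimal implementation of the GetArgs method (the type name and flag set are invented, purely illustrative):

```go
package config

import (
	"fmt"

	"github.com/docker/model-distribution/types"

	"github.com/docker/model-runner/pkg/inference"
)

// echoConfig is a hypothetical implementation of the GetArgs method above,
// showing what the new signature hands a backend: a bundle, not a model.
type echoConfig struct{}

func (echoConfig) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, _ *inference.BackendConfiguration) ([]string, error) {
	if bundle.GGUFPath() == "" {
		return nil, fmt.Errorf("backend requires a GGUF file")
	}
	args := []string{"--model", bundle.GGUFPath(), "--socket", socket}
	if mode == inference.BackendModeEmbedding {
		args = append(args, "--embeddings")
	}
	return args, nil
}
```

Finally, the manager below grows the GetBundle accessor that backends call.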


@@ -622,17 +622,13 @@ func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string,
     return tok, nil
 }

-// GetModelPath returns the path to a model's files.
-func (m *Manager) GetModelPath(ref string) (string, error) {
-    model, err := m.GetModel(ref)
+// GetBundle returns the model bundle.
+func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {
+    bundle, err := m.distributionClient.GetBundle(ref)
     if err != nil {
-        return "", err
+        return nil, fmt.Errorf("error while getting model bundle: %w", err)
     }
-    path, err := model.GGUFPath()
-    if err != nil {
-        return "", fmt.Errorf("error while getting model path: %w", err)
-    }
-    return path, nil
+    return bundle, err
 }

 // PullModel pulls a model to local storage. Any error it returns is suitable
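GetModelPath's two-hop lookup (resolve the model, then ask it for a path, each step fallible) collapses into a single delegation to the distribution client. Former GetModelPath callers can recover the path from the bundle; a hypothetical helper in the manager's package:

```go
// modelGGUFPath is a hypothetical stand-in for former GetModelPath callers:
// resolve the bundle, then read the path off it (no second error to check).
func modelGGUFPath(m *Manager, ref string) (string, error) {
	bundle, err := m.GetBundle(ref)
	if err != nil {
		return "", err
	}
	// Empty string means the bundle ships no GGUF; GetArgs treats that as an error.
	return bundle.GGUFPath(), nil
}
```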