commit
5341c9fc29
2
go.mod
2
go.mod
|
|
@ -5,7 +5,7 @@ go 1.23.7
|
|||
require (
|
||||
github.com/containerd/containerd/v2 v2.0.4
|
||||
github.com/containerd/platforms v1.0.0-rc.1
|
||||
github.com/docker/model-distribution v0.0.0-20250822164750-dcd03ba922e7
|
||||
github.com/docker/model-distribution v0.0.0-20250822172258-8fe9daa4a4da
|
||||
github.com/elastic/go-sysinfo v1.15.3
|
||||
github.com/google/go-containerregistry v0.20.3
|
||||
github.com/gpustack/gguf-parser-go v0.14.1
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -38,8 +38,8 @@ github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBi
|
|||
github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
|
||||
github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo=
|
||||
github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M=
|
||||
github.com/docker/model-distribution v0.0.0-20250822164750-dcd03ba922e7 h1:dOk1UTVMyDHNG4WFS8jnAtfKdPUE3QaMWNvrzRoK/dI=
|
||||
github.com/docker/model-distribution v0.0.0-20250822164750-dcd03ba922e7/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
|
||||
github.com/docker/model-distribution v0.0.0-20250822172258-8fe9daa4a4da h1:ml99WBfcLnsy1frXQR4X+5WAC0DoGtwZyGoU/xBsDQM=
|
||||
github.com/docker/model-distribution v0.0.0-20250822172258-8fe9daa4a4da/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
|
||||
github.com/elastic/go-sysinfo v1.15.3 h1:W+RnmhKFkqPTCRoFq2VCTmsT4p/fwpo+3gKNQsn1XU0=
|
||||
github.com/elastic/go-sysinfo v1.15.3/go.mod h1:K/cNrqYTDrSoMh2oDkYEMS2+a72GRxMvNP+GC+vRIlo=
|
||||
github.com/elastic/go-windows v1.0.2 h1:yoLLsAsV5cfg9FLhZ9EXZ2n2sQFKeDYrHenkcivY4vI=
|
||||
|
|
|
|||
|
|
@ -15,11 +15,10 @@ import (
|
|||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/docker/model-distribution/types"
|
||||
v1 "github.com/google/go-containerregistry/pkg/v1"
|
||||
parser "github.com/gpustack/gguf-parser-go"
|
||||
|
||||
"github.com/docker/model-distribution/types"
|
||||
|
||||
"github.com/docker/model-runner/pkg/diskusage"
|
||||
"github.com/docker/model-runner/pkg/inference"
|
||||
"github.com/docker/model-runner/pkg/inference/config"
|
||||
|
|
@ -133,7 +132,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
|
|||
|
||||
// Run implements inference.Backend.Run.
|
||||
func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
|
||||
mdl, err := l.modelManager.GetModel(model)
|
||||
bundle, err := l.modelManager.GetBundle(model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get model: %w", err)
|
||||
}
|
||||
|
|
@ -148,7 +147,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
|
|||
binPath = l.updatedServerStoragePath
|
||||
}
|
||||
|
||||
args, err := l.config.GetArgs(mdl, socket, mode, config)
|
||||
args, err := l.config.GetArgs(bundle, socket, mode, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get args for llama.cpp: %w", err)
|
||||
}
|
||||
|
|
@ -245,7 +244,7 @@ func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string,
|
|||
}
|
||||
}
|
||||
|
||||
contextSize := GetContextSize(&mdlConfig, config)
|
||||
contextSize := GetContextSize(mdlConfig, config)
|
||||
|
||||
ngl := uint64(0)
|
||||
if l.gpuSupported {
|
||||
|
|
@ -281,23 +280,15 @@ func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string,
|
|||
}
|
||||
|
||||
func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
|
||||
mdl, err := l.modelManager.GetModel(model)
|
||||
bundle, err := l.modelManager.GetBundle(model)
|
||||
if err != nil {
|
||||
return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
|
||||
}
|
||||
mdlPath, err := mdl.GGUFPath()
|
||||
modelGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
|
||||
if err != nil {
|
||||
return nil, types.Config{}, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
|
||||
return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", bundle.GGUFPath(), err)
|
||||
}
|
||||
mdlGguf, err := parser.ParseGGUFFile(mdlPath)
|
||||
if err != nil {
|
||||
return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", mdlPath, err)
|
||||
}
|
||||
mdlConfig, err := mdl.Config()
|
||||
if err != nil {
|
||||
return nil, types.Config{}, fmt.Errorf("accessing model(%s) config: %w", model, err)
|
||||
}
|
||||
return mdlGguf, mdlConfig, nil
|
||||
return modelGGUF, bundle.RuntimeConfig(), nil
|
||||
}
|
||||
|
||||
func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/docker/model-distribution/types"
|
||||
|
||||
"github.com/docker/model-runner/pkg/inference"
|
||||
)
|
||||
|
||||
|
|
@ -35,18 +36,13 @@ func NewDefaultLlamaCppConfig() *Config {
|
|||
}
|
||||
|
||||
// GetArgs implements BackendConfig.GetArgs.
|
||||
func (c *Config) GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
|
||||
func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
|
||||
// Start with the arguments from LlamaCppConfig
|
||||
args := append([]string{}, c.Args...)
|
||||
|
||||
modelPath, err := model.GGUFPath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gguf path: %w", err)
|
||||
}
|
||||
|
||||
modelCfg, err := model.Config()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model config: %w", err)
|
||||
modelPath := bundle.GGUFPath()
|
||||
if modelPath == "" {
|
||||
return nil, fmt.Errorf("GGUF file required by llama.cpp backend")
|
||||
}
|
||||
|
||||
// Add model and socket arguments
|
||||
|
|
@ -57,7 +53,8 @@ func (c *Config) GetArgs(model types.Model, socket string, mode inference.Backen
|
|||
args = append(args, "--embeddings")
|
||||
}
|
||||
|
||||
args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(&modelCfg, config), 10))
|
||||
// Add context size from model config or backend config
|
||||
args = append(args, "--ctx-size", strconv.FormatUint(GetContextSize(bundle.RuntimeConfig(), config), 10))
|
||||
|
||||
// Add arguments from backend config
|
||||
if config != nil {
|
||||
|
|
@ -65,17 +62,16 @@ func (c *Config) GetArgs(model types.Model, socket string, mode inference.Backen
|
|||
}
|
||||
|
||||
// Add arguments for Multimodal projector
|
||||
path, err := model.MMPROJPath()
|
||||
if path != "" && err == nil {
|
||||
if path := bundle.MMPROJPath(); path != "" {
|
||||
args = append(args, "--mmproj", path)
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func GetContextSize(modelCfg *types.Config, backendCfg *inference.BackendConfiguration) uint64 {
|
||||
func GetContextSize(modelCfg types.Config, backendCfg *inference.BackendConfiguration) uint64 {
|
||||
// Model config takes precedence
|
||||
if modelCfg != nil && modelCfg.ContextSize != nil {
|
||||
if modelCfg.ContextSize != nil {
|
||||
return *modelCfg.ContextSize
|
||||
}
|
||||
// else use backend config
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
package llamacpp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/docker/model-distribution/types"
|
||||
|
||||
"github.com/docker/model-runner/pkg/inference"
|
||||
)
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ func TestGetArgs(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
model types.Model
|
||||
bundle types.ModelBundle
|
||||
mode inference.BackendMode
|
||||
config *inference.BackendConfiguration
|
||||
expected []string
|
||||
|
|
@ -82,7 +82,7 @@ func TestGetArgs(t *testing.T) {
|
|||
{
|
||||
name: "completion mode",
|
||||
mode: inference.BackendModeCompletion,
|
||||
model: &fakeModel{
|
||||
bundle: &fakeBundle{
|
||||
ggufPath: modelPath,
|
||||
},
|
||||
expected: []string{
|
||||
|
|
@ -97,7 +97,7 @@ func TestGetArgs(t *testing.T) {
|
|||
{
|
||||
name: "embedding mode",
|
||||
mode: inference.BackendModeEmbedding,
|
||||
model: &fakeModel{
|
||||
bundle: &fakeBundle{
|
||||
ggufPath: modelPath,
|
||||
},
|
||||
expected: []string{
|
||||
|
|
@ -113,7 +113,7 @@ func TestGetArgs(t *testing.T) {
|
|||
{
|
||||
name: "context size from backend config",
|
||||
mode: inference.BackendModeEmbedding,
|
||||
model: &fakeModel{
|
||||
bundle: &fakeBundle{
|
||||
ggufPath: modelPath,
|
||||
},
|
||||
config: &inference.BackendConfiguration{
|
||||
|
|
@ -132,7 +132,7 @@ func TestGetArgs(t *testing.T) {
|
|||
{
|
||||
name: "context size from model config",
|
||||
mode: inference.BackendModeEmbedding,
|
||||
model: &fakeModel{
|
||||
bundle: &fakeBundle{
|
||||
ggufPath: modelPath,
|
||||
config: types.Config{
|
||||
ContextSize: uint64ptr(2096),
|
||||
|
|
@ -154,7 +154,7 @@ func TestGetArgs(t *testing.T) {
|
|||
{
|
||||
name: "raw flags from backend config",
|
||||
mode: inference.BackendModeEmbedding,
|
||||
model: &fakeModel{
|
||||
bundle: &fakeBundle{
|
||||
ggufPath: modelPath,
|
||||
},
|
||||
config: &inference.BackendConfiguration{
|
||||
|
|
@ -175,7 +175,7 @@ func TestGetArgs(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args, err := config.GetArgs(tt.model, socket, tt.mode, tt.config)
|
||||
args, err := config.GetArgs(tt.bundle, socket, tt.mode, tt.config)
|
||||
if err != nil {
|
||||
t.Errorf("GetArgs() error = %v", err)
|
||||
}
|
||||
|
|
@ -248,35 +248,27 @@ func TestContainsArg(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var _ types.Model = &fakeModel{}
|
||||
var _ types.ModelBundle = &fakeBundle{}
|
||||
|
||||
type fakeModel struct {
|
||||
type fakeBundle struct {
|
||||
ggufPath string
|
||||
config types.Config
|
||||
}
|
||||
|
||||
func (f *fakeModel) MMPROJPath() (string, error) {
|
||||
return "", errors.New("not found")
|
||||
}
|
||||
|
||||
func (f *fakeModel) ID() (string, error) {
|
||||
func (f *fakeBundle) RootDir() string {
|
||||
panic("shouldn't be called")
|
||||
}
|
||||
|
||||
func (f *fakeModel) GGUFPath() (string, error) {
|
||||
return f.ggufPath, nil
|
||||
func (f *fakeBundle) GGUFPath() string {
|
||||
return f.ggufPath
|
||||
}
|
||||
|
||||
func (f *fakeModel) Config() (types.Config, error) {
|
||||
return f.config, nil
|
||||
func (f *fakeBundle) MMPROJPath() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (f *fakeModel) Tags() []string {
|
||||
panic("shouldn't be called")
|
||||
}
|
||||
|
||||
func (f fakeModel) Descriptor() (types.Descriptor, error) {
|
||||
panic("shouldn't be called")
|
||||
func (f *fakeBundle) RuntimeConfig() types.Config {
|
||||
return f.config
|
||||
}
|
||||
|
||||
func uint64ptr(n uint64) *uint64 {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package config
|
|||
|
||||
import (
|
||||
"github.com/docker/model-distribution/types"
|
||||
|
||||
"github.com/docker/model-runner/pkg/inference"
|
||||
)
|
||||
|
||||
|
|
@ -12,5 +13,5 @@ type BackendConfig interface {
|
|||
// GetArgs returns the command-line arguments for the backend.
|
||||
// It takes the model path, socket, and mode as input and returns
|
||||
// the appropriate arguments for the backend.
|
||||
GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
|
||||
GetArgs(bundle types.ModelBundle, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -622,17 +622,13 @@ func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string,
|
|||
return tok, nil
|
||||
}
|
||||
|
||||
// GetModelPath returns the path to a model's files.
|
||||
func (m *Manager) GetModelPath(ref string) (string, error) {
|
||||
model, err := m.GetModel(ref)
|
||||
// GetBundle returns model bundle.
|
||||
func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {
|
||||
bundle, err := m.distributionClient.GetBundle(ref)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, fmt.Errorf("error while getting model bundle: %w", err)
|
||||
}
|
||||
path, err := model.GGUFPath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error while getting model path: %w", err)
|
||||
}
|
||||
return path, nil
|
||||
return bundle, err
|
||||
}
|
||||
|
||||
// PullModel pulls a model to local storage. Any error it returns is suitable
|
||||
|
|
|
|||
Loading…
Reference in New Issue