This reverts commit 1898347e1a.
This commit is contained in:
parent 423e1d6d1b
commit 23896c491b
Makefile | 10
@@ -7,6 +7,7 @@ BASE_IMAGE := ubuntu:24.04
DOCKER_IMAGE := docker/model-runner:latest
PORT := 8080
MODELS_PATH := $(shell pwd)/models-store
LLAMA_ARGS ?=

# Main targets
.PHONY: build run clean test docker-build docker-run help

@@ -20,6 +21,7 @@ build:

# Run the application locally
run: build
    LLAMA_ARGS="$(LLAMA_ARGS)" \
    ./$(APP_NAME)

# Clean build artifacts

@@ -55,6 +57,7 @@ docker-run: docker-build
    -e MODEL_RUNNER_PORT=$(PORT) \
    -e LLAMA_SERVER_PATH=/app/bin \
    -e MODELS_PATH=/models \
    -e LLAMA_ARGS="$(LLAMA_ARGS)" \
    $(DOCKER_IMAGE)

# Show help

@@ -67,3 +70,10 @@ help:
    @echo " docker-build - Build Docker image"
    @echo " docker-run - Run in Docker container with TCP port access and mounted model storage"
    @echo " help - Show this help message"
    @echo ""
    @echo "Backend configuration options:"
    @echo " LLAMA_ARGS - Arguments for llama.cpp (e.g., \"--verbose --jinja -ngl 100 --ctx-size 2048\")"
    @echo ""
    @echo "Example usage:"
    @echo " make run LLAMA_ARGS=\"--verbose --jinja -ngl 100 --ctx-size 2048\""
    @echo " make docker-run LLAMA_ARGS=\"--verbose --jinja -ngl 100 --threads 4 --ctx-size 2048\""
main.go | 62
@@ -7,10 +7,12 @@ import (
    "os"
    "os/signal"
    "path/filepath"
    "strings"
    "syscall"

    "github.com/docker/model-runner/pkg/inference"
    "github.com/docker/model-runner/pkg/inference/backends/llamacpp"
    "github.com/docker/model-runner/pkg/inference/config"
    "github.com/docker/model-runner/pkg/inference/models"
    "github.com/docker/model-runner/pkg/inference/scheduling"
    "github.com/docker/model-runner/pkg/routing"

@@ -57,6 +59,9 @@ func main() {

    log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)

    // Create llama.cpp configuration from environment variables
    llamaCppConfig := createLlamaCppConfigFromEnv()

    llamaCppBackend, err := llamacpp.New(
        log,
        modelManager,

@@ -68,6 +73,7 @@ func main() {
            _ = os.MkdirAll(d, 0o755)
            return d
        }(),
        llamaCppConfig,
    )
    if err != nil {
        log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)

@@ -141,3 +147,59 @@ func main() {
    }
    log.Infoln("Docker Model Runner stopped")
}

// createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables
func createLlamaCppConfigFromEnv() config.BackendConfig {
    // Check if any configuration environment variables are set
    argsStr := os.Getenv("LLAMA_ARGS")

    // If no environment variables are set, use default configuration
    if argsStr == "" {
        return nil // nil will cause the backend to use its default configuration
    }

    // Split the string by spaces, respecting quoted arguments
    args := splitArgs(argsStr)

    // Check for disallowed arguments
    disallowedArgs := []string{"--model", "--host", "--embeddings", "--mmproj"}
    for _, arg := range args {
        for _, disallowed := range disallowedArgs {
            if arg == disallowed {
                log.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)
            }
        }
    }

    log.Infof("Using custom arguments: %v", args)
    return &llamacpp.Config{
        Args: args,
    }
}

// splitArgs splits a string into arguments, respecting quoted arguments
func splitArgs(s string) []string {
    var args []string
    var currentArg strings.Builder
    inQuotes := false

    for _, r := range s {
        switch {
        case r == '"' || r == '\'':
            inQuotes = !inQuotes
        case r == ' ' && !inQuotes:
            if currentArg.Len() > 0 {
                args = append(args, currentArg.String())
                currentArg.Reset()
            }
        default:
            currentArg.WriteRune(r)
        }
    }

    if currentArg.Len() > 0 {
        args = append(args, currentArg.String())
    }

    return args
}
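For illustration only, not part of the commit: given the quote handling in splitArgs above, quotes group space-containing tokens and the quote characters themselves are dropped. A minimal sketch of the expected tokenization, assuming it sits alongside splitArgs in package main (the file and test names here are hypothetical):

// splitargs_example_test.go (hypothetical)
package main

import (
    "reflect"
    "testing"
)

// TestSplitArgsQuoting sketches the expected tokenization of a quoted
// LLAMA_ARGS value, based on the splitArgs implementation above.
func TestSplitArgsQuoting(t *testing.T) {
    got := splitArgs(`--prompt "Hello, world!" --threads 4`)
    want := []string{"--prompt", "Hello, world!", "--threads", "4"}
    if !reflect.DeepEqual(got, want) {
        t.Errorf("splitArgs() = %v, want %v", got, want)
    }
}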

@@ -0,0 +1,108 @@
package main

import (
    "os"
    "testing"

    "github.com/docker/model-runner/pkg/inference/backends/llamacpp"
    "github.com/sirupsen/logrus"
)

func TestCreateLlamaCppConfigFromEnv(t *testing.T) {
    tests := []struct {
        name      string
        llamaArgs string
        wantErr   bool
    }{
        {
            name:      "empty args",
            llamaArgs: "",
            wantErr:   false,
        },
        {
            name:      "valid args",
            llamaArgs: "--threads 4 --ctx-size 2048",
            wantErr:   false,
        },
        {
            name:      "disallowed model arg",
            llamaArgs: "--model test.gguf",
            wantErr:   true,
        },
        {
            name:      "disallowed host arg",
            llamaArgs: "--host localhost:8080",
            wantErr:   true,
        },
        {
            name:      "disallowed embeddings arg",
            llamaArgs: "--embeddings",
            wantErr:   true,
        },
        {
            name:      "disallowed mmproj arg",
            llamaArgs: "--mmproj test.mmproj",
            wantErr:   true,
        },
        {
            name:      "multiple disallowed args",
            llamaArgs: "--model test.gguf --host localhost:8080",
            wantErr:   true,
        },
        {
            name:      "quoted args",
            llamaArgs: "--prompt \"Hello, world!\" --threads 4",
            wantErr:   false,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            // Set up environment
            if tt.llamaArgs != "" {
                os.Setenv("LLAMA_ARGS", tt.llamaArgs)
                defer os.Unsetenv("LLAMA_ARGS")
            }

            // Create a test logger that captures fatal errors
            originalLog := log
            defer func() { log = originalLog }()

            // Create a new logger that will exit with a special exit code
            testLog := logrus.New()
            var exitCode int
            testLog.ExitFunc = func(code int) {
                exitCode = code
            }
            log = testLog

            config := createLlamaCppConfigFromEnv()

            if tt.wantErr {
                if exitCode != 1 {
                    t.Errorf("Expected exit code 1, got %d", exitCode)
                }
            } else {
                if exitCode != 0 {
                    t.Errorf("Expected exit code 0, got %d", exitCode)
                }
                if tt.llamaArgs == "" {
                    if config != nil {
                        t.Error("Expected nil config for empty args")
                    }
                } else {
                    llamaConfig, ok := config.(*llamacpp.Config)
                    if !ok {
                        t.Errorf("Expected *llamacpp.Config, got %T", config)
                    }
                    if llamaConfig == nil {
                        t.Error("Expected non-nil config")
                    }
                    if len(llamaConfig.Args) == 0 {
                        t.Error("Expected non-empty args")
                    }
                }
            }
        })
    }
}
@@ -10,10 +10,10 @@ import (
    "os/exec"
    "path/filepath"
    "runtime"
    "strconv"

    "github.com/docker/model-runner/pkg/diskusage"
    "github.com/docker/model-runner/pkg/inference"
    "github.com/docker/model-runner/pkg/inference/config"
    "github.com/docker/model-runner/pkg/inference/models"
    "github.com/docker/model-runner/pkg/logging"
)

@@ -39,6 +39,8 @@ type llamaCpp struct {
    updatedServerStoragePath string
    // status is the state in which the llama.cpp backend is in.
    status string
    // config is the configuration for the llama.cpp backend.
    config config.BackendConfig
}

// New creates a new llama.cpp-based backend.

@@ -48,13 +50,20 @@ func New(
    serverLog logging.Logger,
    vendoredServerStoragePath string,
    updatedServerStoragePath string,
    conf config.BackendConfig,
) (inference.Backend, error) {
    // If no config is provided, use the default configuration
    if conf == nil {
        conf = NewDefaultLlamaCppConfig()
    }

    return &llamaCpp{
        log:                       log,
        modelManager:              modelManager,
        serverLog:                 serverLog,
        vendoredServerStoragePath: vendoredServerStoragePath,
        updatedServerStoragePath:  updatedServerStoragePath,
        config:                    conf,
    }, nil
}

@@ -115,11 +124,6 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
        return fmt.Errorf("failed to get model path: %w", err)
    }

    modelDesc, err := l.modelManager.GetModel(model)
    if err != nil {
        return fmt.Errorf("failed to get model: %w", err)
    }

    if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
        l.log.Warnf("failed to remove socket file %s: %w\n", socket, err)
        l.log.Warnln("llama.cpp may not be able to start")

@@ -129,32 +133,13 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
    if l.updatedLlamaCpp {
        binPath = l.updatedServerStoragePath
    }
    llamaCppArgs := []string{"--model", modelPath, "--jinja", "--host", socket}
    if mode == inference.BackendModeEmbedding {
        llamaCppArgs = append(llamaCppArgs, "--embeddings")
    }
    if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
        // Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
        // in going beyond core_count/2.
        // TODO(p1-0tr): dig into why the defaults don't work well on windows/arm64
        nThreads := min(2, runtime.NumCPU()/2)
        llamaCppArgs = append(llamaCppArgs, "--threads", strconv.Itoa(nThreads))

        modelConfig, err := modelDesc.Config()
        if err != nil {
            return fmt.Errorf("failed to get model config: %w", err)
        }
        // The Adreno OpenCL implementation currently only supports Q4_0
        if modelConfig.Quantization == "Q4_0" {
            llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
        }
    } else {
        llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
    }
    args := l.config.GetArgs(modelPath, socket, mode)
    l.log.Infof("llamaCppArgs: %v", args)
    llamaCppProcess := exec.CommandContext(
        ctx,
        filepath.Join(binPath, "com.docker.llama-server"),
        llamaCppArgs...,
        args...,
    )
    llamaCppProcess.Cancel = func() error {
        if runtime.GOOS == "windows" {
@@ -0,0 +1,59 @@
package llamacpp

import (
    "runtime"
    "strconv"

    "github.com/docker/model-runner/pkg/inference"
)

// Config is the configuration for the llama.cpp backend.
type Config struct {
    // Args are the base arguments that are always included.
    Args []string
}

// NewDefaultLlamaCppConfig creates a new LlamaCppConfig with default values.
func NewDefaultLlamaCppConfig() *Config {
    args := append([]string{"--jinja", "-ngl", "100"})

    // Special case for Windows ARM64
    if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
        // Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
        // in going beyond core_count/2.
        if !containsArg(args, "--threads") {
            nThreads := min(2, runtime.NumCPU()/2)
            args = append(args, "--threads", strconv.Itoa(nThreads))
        }
    }

    return &Config{
        Args: args,
    }
}

// GetArgs implements BackendConfig.GetArgs.
func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
    // Start with the arguments from LlamaCppConfig
    args := append([]string{}, c.Args...)

    // Add model and socket arguments
    args = append(args, "--model", modelPath, "--host", socket)

    // Add mode-specific arguments
    if mode == inference.BackendModeEmbedding {
        args = append(args, "--embeddings")
    }

    return args
}

// containsArg checks if the given argument is already in the args slice.
func containsArg(args []string, arg string) bool {
    for _, a := range args {
        if a == arg {
            return true
        }
    }
    return false
}
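For illustration only, not part of the commit: a rough sketch of how the default configuration above expands into a llama.cpp command line, assuming the GetArgs implementation shown in this file; the model path, socket, and binary path are placeholders mirroring the exec.CommandContext call in the Run() hunk earlier in this diff.

package main

import (
    "fmt"
    "path/filepath"

    "github.com/docker/model-runner/pkg/inference"
    "github.com/docker/model-runner/pkg/inference/backends/llamacpp"
)

func main() {
    conf := llamacpp.NewDefaultLlamaCppConfig()
    // Placeholder values; in the real backend these come from the scheduler.
    args := conf.GetArgs("/models/example.gguf", "/tmp/llama.sock", inference.BackendModeCompletion)
    // binPath is a placeholder standing in for the vendored/updated server path.
    fmt.Println(filepath.Join("/app/bin", "com.docker.llama-server"), args)
    // Expected (completion mode, non-windows/arm64):
    // /app/bin/com.docker.llama-server [--jinja -ngl 100 --model /models/example.gguf --host /tmp/llama.sock]
}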

@@ -0,0 +1,171 @@
package llamacpp

import (
    "runtime"
    "strconv"
    "testing"

    "github.com/docker/model-runner/pkg/inference"
)

func TestNewDefaultLlamaCppConfig(t *testing.T) {
    config := NewDefaultLlamaCppConfig()

    // Test default arguments
    if !containsArg(config.Args, "--jinja") {
        t.Error("Expected --jinja argument to be present")
    }

    // Test -ngl argument and its value
    nglIndex := -1
    for i, arg := range config.Args {
        if arg == "-ngl" {
            nglIndex = i
            break
        }
    }
    if nglIndex == -1 {
        t.Error("Expected -ngl argument to be present")
    }
    if nglIndex+1 >= len(config.Args) {
        t.Error("No value found after -ngl argument")
    }
    if config.Args[nglIndex+1] != "100" {
        t.Errorf("Expected -ngl value to be 100, got %s", config.Args[nglIndex+1])
    }

    // Test Windows ARM64 specific case
    if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
        if !containsArg(config.Args, "--threads") {
            t.Error("Expected --threads argument to be present on Windows ARM64")
        }
        threadsIndex := -1
        for i, arg := range config.Args {
            if arg == "--threads" {
                threadsIndex = i
                break
            }
        }
        if threadsIndex == -1 {
            t.Error("Could not find --threads argument")
        }
        if threadsIndex+1 >= len(config.Args) {
            t.Error("No value found after --threads argument")
        }
        threads, err := strconv.Atoi(config.Args[threadsIndex+1])
        if err != nil {
            t.Errorf("Failed to parse thread count: %v", err)
        }
        if threads > runtime.NumCPU()/2 {
            t.Errorf("Thread count %d exceeds maximum allowed value of %d", threads, runtime.NumCPU()/2)
        }
        if threads < 1 {
            t.Error("Thread count is less than 1")
        }
    }
}

func TestGetArgs(t *testing.T) {
    config := NewDefaultLlamaCppConfig()
    modelPath := "/path/to/model"
    socket := "unix:///tmp/socket"

    tests := []struct {
        name     string
        mode     inference.BackendMode
        expected []string
    }{
        {
            name: "completion mode",
            mode: inference.BackendModeCompletion,
            expected: []string{
                "--jinja",
                "-ngl", "100",
                "--model", modelPath,
                "--host", socket,
            },
        },
        {
            name: "embedding mode",
            mode: inference.BackendModeEmbedding,
            expected: []string{
                "--jinja",
                "-ngl", "100",
                "--model", modelPath,
                "--host", socket,
                "--embeddings",
            },
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            args := config.GetArgs(modelPath, socket, tt.mode)

            // Check that all expected arguments are present and in the correct order
            expectedIndex := 0
            for i := 0; i < len(args); i++ {
                if expectedIndex >= len(tt.expected) {
                    t.Errorf("Unexpected extra argument: %s", args[i])
                    continue
                }

                if args[i] != tt.expected[expectedIndex] {
                    t.Errorf("Expected argument %s at position %d, got %s", tt.expected[expectedIndex], i, args[i])
                    continue
                }

                // If this is a flag that takes a value, check the next argument
                if i+1 < len(args) && (args[i] == "-ngl" || args[i] == "--model" || args[i] == "--host") {
                    expectedIndex++
                    if args[i+1] != tt.expected[expectedIndex] {
                        t.Errorf("Expected value %s for flag %s, got %s", tt.expected[expectedIndex], args[i], args[i+1])
                    }
                    i++ // Skip the value in the next iteration
                }
                expectedIndex++
            }

            if expectedIndex != len(tt.expected) {
                t.Errorf("Missing expected arguments. Got %d arguments, expected %d", expectedIndex, len(tt.expected))
            }
        })
    }
}

func TestContainsArg(t *testing.T) {
    tests := []struct {
        name     string
        args     []string
        arg      string
        expected bool
    }{
        {
            name:     "argument exists",
            args:     []string{"--arg1", "--arg2", "--arg3"},
            arg:      "--arg2",
            expected: true,
        },
        {
            name:     "argument does not exist",
            args:     []string{"--arg1", "--arg2", "--arg3"},
            arg:      "--arg4",
            expected: false,
        },
        {
            name:     "empty args slice",
            args:     []string{},
            arg:      "--arg1",
            expected: false,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := containsArg(tt.args, tt.arg)
            if result != tt.expected {
                t.Errorf("containsArg(%v, %s) = %v, want %v", tt.args, tt.arg, result, tt.expected)
            }
        })
    }
}
@@ -0,0 +1,15 @@
package config

import (
    "github.com/docker/model-runner/pkg/inference"
)

// BackendConfig is the interface implemented by backend configurations.
// It provides methods to get command-line arguments for a backend based on
// the model path, socket, and mode.
type BackendConfig interface {
    // GetArgs returns the command-line arguments for the backend.
    // It takes the model path, socket, and mode as input and returns
    // the appropriate arguments for the backend.
    GetArgs(modelPath, socket string, mode inference.BackendMode) []string
}
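For illustration only, not part of the commit: the interface above is what lets main.go hand a backend-specific argument builder to the scheduler without depending on backend internals. A minimal hypothetical implementation for some other backend might look like the sketch below; the type name and the "--listen"/"--embedding-only" flags are invented for the example.

package config

import "github.com/docker/model-runner/pkg/inference"

// staticConfig is a hypothetical BackendConfig that returns a fixed flag set
// plus the model/socket arguments. It is not part of this commit.
type staticConfig struct {
    baseArgs []string
}

func (c *staticConfig) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
    args := append([]string{}, c.baseArgs...)
    args = append(args, "--model", modelPath, "--listen", socket) // "--listen" is an invented flag
    if mode == inference.BackendModeEmbedding {
        args = append(args, "--embedding-only") // invented flag
    }
    return args
}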