Revert "Revert "Revert "Revert "configure backend args""" (#54)" (#55)

This reverts commit 1898347e1a.
Ignasi 2025-05-30 16:51:34 +02:00 committed by GitHub
parent 423e1d6d1b
commit 23896c491b
7 changed files with 438 additions and 28 deletions

Makefile

@@ -7,6 +7,7 @@ BASE_IMAGE := ubuntu:24.04
DOCKER_IMAGE := docker/model-runner:latest
PORT := 8080
MODELS_PATH := $(shell pwd)/models-store
LLAMA_ARGS ?=
# Main targets
.PHONY: build run clean test docker-build docker-run help
@@ -20,6 +21,7 @@ build:
# Run the application locally
run: build
LLAMA_ARGS="$(LLAMA_ARGS)" \
./$(APP_NAME)
# Clean build artifacts
@@ -55,6 +57,7 @@ docker-run: docker-build
-e MODEL_RUNNER_PORT=$(PORT) \
-e LLAMA_SERVER_PATH=/app/bin \
-e MODELS_PATH=/models \
-e LLAMA_ARGS="$(LLAMA_ARGS)" \
$(DOCKER_IMAGE)
# Show help
@@ -67,3 +70,10 @@ help:
@echo " docker-build - Build Docker image"
@echo " docker-run - Run in Docker container with TCP port access and mounted model storage"
@echo " help - Show this help message"
@echo ""
@echo "Backend configuration options:"
@echo " LLAMA_ARGS - Arguments for llama.cpp (e.g., \"--verbose --jinja -ngl 100 --ctx-size 2048\")"
@echo ""
@echo "Example usage:"
@echo " make run LLAMA_ARGS=\"--verbose --jinja -ngl 100 --ctx-size 2048\""
@echo " make docker-run LLAMA_ARGS=\"--verbose --jinja -ngl 100 --threads 4 --ctx-size 2048\""

main.go

@@ -7,10 +7,12 @@ import (
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/docker/model-runner/pkg/inference/config"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/docker/model-runner/pkg/routing"
@@ -57,6 +59,9 @@ func main() {
log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
// Create llama.cpp configuration from environment variables
llamaCppConfig := createLlamaCppConfigFromEnv()
llamaCppBackend, err := llamacpp.New(
log,
modelManager,
@@ -68,6 +73,7 @@
_ = os.MkdirAll(d, 0o755)
return d
}(),
llamaCppConfig,
)
if err != nil {
log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
@@ -141,3 +147,59 @@ func main() {
}
log.Infoln("Docker Model Runner stopped")
}
// createLlamaCppConfigFromEnv creates a LlamaCppConfig from environment variables
func createLlamaCppConfigFromEnv() config.BackendConfig {
// Check if any configuration environment variables are set
argsStr := os.Getenv("LLAMA_ARGS")
// If no environment variables are set, use default configuration
if argsStr == "" {
return nil // nil will cause the backend to use its default configuration
}
// Split the string by spaces, respecting quoted arguments
args := splitArgs(argsStr)
// Check for disallowed arguments
disallowedArgs := []string{"--model", "--host", "--embeddings", "--mmproj"}
for _, arg := range args {
for _, disallowed := range disallowedArgs {
if arg == disallowed {
log.Fatalf("LLAMA_ARGS cannot override the %s argument as it is controlled by the model runner", disallowed)
}
}
}
log.Infof("Using custom arguments: %v", args)
return &llamacpp.Config{
Args: args,
}
}
// splitArgs splits a string into arguments, respecting quoted arguments
func splitArgs(s string) []string {
var args []string
var currentArg strings.Builder
inQuotes := false
for _, r := range s {
switch {
case r == '"' || r == '\'':
inQuotes = !inQuotes
case r == ' ' && !inQuotes:
if currentArg.Len() > 0 {
args = append(args, currentArg.String())
currentArg.Reset()
}
default:
currentArg.WriteRune(r)
}
}
if currentArg.Len() > 0 {
args = append(args, currentArg.String())
}
return args
}
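For illustration, a hedged test-style sketch (not part of this commit) of the splitting behaviour the two helpers above are expected to have. It assumes placement alongside main_test.go in package main, Go 1.17+ for t.Setenv, an added reflect import, and a made-up LLAMA_ARGS value.

// Hypothetical sketch, not part of this commit: quoted arguments should survive
// splitting as a single element.
func TestSplitQuotedArgsSketch(t *testing.T) {
    t.Setenv("LLAMA_ARGS", `--verbose --prompt "Hello, world!" --threads 4`)

    raw := createLlamaCppConfigFromEnv()
    cfg, ok := raw.(*llamacpp.Config)
    if !ok {
        t.Fatalf("expected *llamacpp.Config, got %T", raw)
    }
    want := []string{"--verbose", "--prompt", "Hello, world!", "--threads", "4"}
    if !reflect.DeepEqual(cfg.Args, want) {
        t.Errorf("Args = %v, want %v", cfg.Args, want)
    }
}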

main_test.go (new file)

@@ -0,0 +1,108 @@
package main
import (
"os"
"testing"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/sirupsen/logrus"
)
func TestCreateLlamaCppConfigFromEnv(t *testing.T) {
tests := []struct {
name string
llamaArgs string
wantErr bool
}{
{
name: "empty args",
llamaArgs: "",
wantErr: false,
},
{
name: "valid args",
llamaArgs: "--threads 4 --ctx-size 2048",
wantErr: false,
},
{
name: "disallowed model arg",
llamaArgs: "--model test.gguf",
wantErr: true,
},
{
name: "disallowed host arg",
llamaArgs: "--host localhost:8080",
wantErr: true,
},
{
name: "disallowed embeddings arg",
llamaArgs: "--embeddings",
wantErr: true,
},
{
name: "disallowed mmproj arg",
llamaArgs: "--mmproj test.mmproj",
wantErr: true,
},
{
name: "multiple disallowed args",
llamaArgs: "--model test.gguf --host localhost:8080",
wantErr: true,
},
{
name: "quoted args",
llamaArgs: "--prompt \"Hello, world!\" --threads 4",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up environment
if tt.llamaArgs != "" {
os.Setenv("LLAMA_ARGS", tt.llamaArgs)
defer os.Unsetenv("LLAMA_ARGS")
}
// Create a test logger that captures fatal errors
originalLog := log
defer func() { log = originalLog }()
// Create a new logger that will exit with a special exit code
testLog := logrus.New()
var exitCode int
testLog.ExitFunc = func(code int) {
exitCode = code
}
log = testLog
config := createLlamaCppConfigFromEnv()
if tt.wantErr {
if exitCode != 1 {
t.Errorf("Expected exit code 1, got %d", exitCode)
}
} else {
if exitCode != 0 {
t.Errorf("Expected exit code 0, got %d", exitCode)
}
if tt.llamaArgs == "" {
if config != nil {
t.Error("Expected nil config for empty args")
}
} else {
llamaConfig, ok := config.(*llamacpp.Config)
if !ok {
t.Fatalf("Expected *llamacpp.Config, got %T", config)
}
if llamaConfig == nil {
t.Fatal("Expected non-nil config")
}
if len(llamaConfig.Args) == 0 {
t.Error("Expected non-empty args")
}
}
}
})
}
}


@@ -10,10 +10,10 @@ import (
"os/exec"
"path/filepath"
"runtime"
"strconv"
"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/config"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/logging"
)
@@ -39,6 +39,8 @@ type llamaCpp struct {
updatedServerStoragePath string
// status is the state that the llama.cpp backend is in.
status string
// config is the configuration for the llama.cpp backend.
config config.BackendConfig
}
// New creates a new llama.cpp-based backend.
@@ -48,13 +50,20 @@ func New(
serverLog logging.Logger,
vendoredServerStoragePath string,
updatedServerStoragePath string,
conf config.BackendConfig,
) (inference.Backend, error) {
// If no config is provided, use the default configuration
if conf == nil {
conf = NewDefaultLlamaCppConfig()
}
return &llamaCpp{
log: log,
modelManager: modelManager,
serverLog: serverLog,
vendoredServerStoragePath: vendoredServerStoragePath,
updatedServerStoragePath: updatedServerStoragePath,
config: conf,
}, nil
}
@@ -115,11 +124,6 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
return fmt.Errorf("failed to get model path: %w", err)
}
modelDesc, err := l.modelManager.GetModel(model)
if err != nil {
return fmt.Errorf("failed to get model: %w", err)
}
if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
l.log.Warnf("failed to remove socket file %s: %v\n", socket, err)
l.log.Warnln("llama.cpp may not be able to start")
@@ -129,32 +133,13 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
if l.updatedLlamaCpp {
binPath = l.updatedServerStoragePath
}
llamaCppArgs := []string{"--model", modelPath, "--jinja", "--host", socket}
if mode == inference.BackendModeEmbedding {
llamaCppArgs = append(llamaCppArgs, "--embeddings")
}
if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
// Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
// in going beyond core_count/2.
// TODO(p1-0tr): dig into why the defaults don't work well on windows/arm64
nThreads := min(2, runtime.NumCPU()/2)
llamaCppArgs = append(llamaCppArgs, "--threads", strconv.Itoa(nThreads))
modelConfig, err := modelDesc.Config()
if err != nil {
return fmt.Errorf("failed to get model config: %w", err)
}
// The Adreno OpenCL implementation currently only supports Q4_0
if modelConfig.Quantization == "Q4_0" {
llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
}
} else {
llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
}
args := l.config.GetArgs(modelPath, socket, mode)
l.log.Infof("llamaCppArgs: %v", args)
llamaCppProcess := exec.CommandContext(
ctx,
filepath.Join(binPath, "com.docker.llama-server"),
llamaCppArgs...,
args...,
)
llamaCppProcess.Cancel = func() error {
if runtime.GOOS == "windows" {


@@ -0,0 +1,59 @@
package llamacpp
import (
"runtime"
"strconv"
"github.com/docker/model-runner/pkg/inference"
)
// Config is the configuration for the llama.cpp backend.
type Config struct {
// Args are the base arguments that are always included.
Args []string
}
// NewDefaultLlamaCppConfig creates a new LlamaCppConfig with default values.
func NewDefaultLlamaCppConfig() *Config {
args := []string{"--jinja", "-ngl", "100"}
// Special case for Windows ARM64
if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
// Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
// in going beyond core_count/2.
if !containsArg(args, "--threads") {
nThreads := min(2, runtime.NumCPU()/2)
args = append(args, "--threads", strconv.Itoa(nThreads))
}
}
return &Config{
Args: args,
}
}
// GetArgs implements BackendConfig.GetArgs.
func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
// Start with the arguments from LlamaCppConfig
args := append([]string{}, c.Args...)
// Add model and socket arguments
args = append(args, "--model", modelPath, "--host", socket)
// Add mode-specific arguments
if mode == inference.BackendModeEmbedding {
args = append(args, "--embeddings")
}
return args
}
// containsArg checks if the given argument is already in the args slice.
func containsArg(args []string, arg string) bool {
for _, a := range args {
if a == arg {
return true
}
}
return false
}


@@ -0,0 +1,171 @@
package llamacpp
import (
"runtime"
"strconv"
"testing"
"github.com/docker/model-runner/pkg/inference"
)
func TestNewDefaultLlamaCppConfig(t *testing.T) {
config := NewDefaultLlamaCppConfig()
// Test default arguments
if !containsArg(config.Args, "--jinja") {
t.Error("Expected --jinja argument to be present")
}
// Test -ngl argument and its value
nglIndex := -1
for i, arg := range config.Args {
if arg == "-ngl" {
nglIndex = i
break
}
}
if nglIndex == -1 {
t.Error("Expected -ngl argument to be present")
}
if nglIndex+1 >= len(config.Args) {
t.Error("No value found after -ngl argument")
}
if config.Args[nglIndex+1] != "100" {
t.Errorf("Expected -ngl value to be 100, got %s", config.Args[nglIndex+1])
}
// Test Windows ARM64 specific case
if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
if !containsArg(config.Args, "--threads") {
t.Error("Expected --threads argument to be present on Windows ARM64")
}
threadsIndex := -1
for i, arg := range config.Args {
if arg == "--threads" {
threadsIndex = i
break
}
}
if threadsIndex == -1 {
t.Error("Could not find --threads argument")
}
if threadsIndex+1 >= len(config.Args) {
t.Error("No value found after --threads argument")
}
threads, err := strconv.Atoi(config.Args[threadsIndex+1])
if err != nil {
t.Errorf("Failed to parse thread count: %v", err)
}
if threads > runtime.NumCPU()/2 {
t.Errorf("Thread count %d exceeds maximum allowed value of %d", threads, runtime.NumCPU()/2)
}
if threads < 1 {
t.Error("Thread count is less than 1")
}
}
}
func TestGetArgs(t *testing.T) {
config := NewDefaultLlamaCppConfig()
modelPath := "/path/to/model"
socket := "unix:///tmp/socket"
tests := []struct {
name string
mode inference.BackendMode
expected []string
}{
{
name: "completion mode",
mode: inference.BackendModeCompletion,
expected: []string{
"--jinja",
"-ngl", "100",
"--model", modelPath,
"--host", socket,
},
},
{
name: "embedding mode",
mode: inference.BackendModeEmbedding,
expected: []string{
"--jinja",
"-ngl", "100",
"--model", modelPath,
"--host", socket,
"--embeddings",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
args := config.GetArgs(modelPath, socket, tt.mode)
// Check that all expected arguments are present and in the correct order
expectedIndex := 0
for i := 0; i < len(args); i++ {
if expectedIndex >= len(tt.expected) {
t.Errorf("Unexpected extra argument: %s", args[i])
continue
}
if args[i] != tt.expected[expectedIndex] {
t.Errorf("Expected argument %s at position %d, got %s", tt.expected[expectedIndex], i, args[i])
continue
}
// If this is a flag that takes a value, check the next argument
if i+1 < len(args) && (args[i] == "-ngl" || args[i] == "--model" || args[i] == "--host") {
expectedIndex++
if args[i+1] != tt.expected[expectedIndex] {
t.Errorf("Expected value %s for flag %s, got %s", tt.expected[expectedIndex], args[i], args[i+1])
}
i++ // Skip the value in the next iteration
}
expectedIndex++
}
if expectedIndex != len(tt.expected) {
t.Errorf("Missing expected arguments. Got %d arguments, expected %d", expectedIndex, len(tt.expected))
}
})
}
}
func TestContainsArg(t *testing.T) {
tests := []struct {
name string
args []string
arg string
expected bool
}{
{
name: "argument exists",
args: []string{"--arg1", "--arg2", "--arg3"},
arg: "--arg2",
expected: true,
},
{
name: "argument does not exist",
args: []string{"--arg1", "--arg2", "--arg3"},
arg: "--arg4",
expected: false,
},
{
name: "empty args slice",
args: []string{},
arg: "--arg1",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := containsArg(tt.args, tt.arg)
if result != tt.expected {
t.Errorf("containsArg(%v, %s) = %v, want %v", tt.args, tt.arg, result, tt.expected)
}
})
}
}


@@ -0,0 +1,15 @@
package config
import (
"github.com/docker/model-runner/pkg/inference"
)
// BackendConfig is the interface implemented by backend configurations.
// It provides methods to get command-line arguments for a backend based on
// the model path, socket, and mode.
type BackendConfig interface {
// GetArgs returns the command-line arguments for the backend.
// It takes the model path, socket, and mode as input and returns
// the appropriate arguments for the backend.
GetArgs(modelPath, socket string, mode inference.BackendMode) []string
}
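To illustrate the seam this interface creates, a hedged sketch (not part of this commit) of a minimal alternative implementation; the staticConfig type and its fixed arguments are hypothetical.

// Hypothetical sketch: a trivial BackendConfig whose extra arguments are fixed
// up front; backend-specific defaulting (as in llamacpp.Config) is omitted.
type staticConfig struct {
    extra []string
}

var _ BackendConfig = (*staticConfig)(nil) // compile-time interface check

func (c *staticConfig) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
    args := append([]string{"--model", modelPath, "--host", socket}, c.extra...)
    if mode == inference.BackendModeEmbedding {
        args = append(args, "--embeddings")
    }
    return args
}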