// model-runner/pkg/inference/backends/llamacpp/llamacpp_config_test.go
//
// 277 lines
// 6.2 KiB
// Go

package llamacpp
import (
"runtime"
"strconv"
"testing"
"github.com/docker/model-distribution/types"
"github.com/docker/model-runner/pkg/inference"
)
// TestNewDefaultLlamaCppConfig checks the argument list produced by
// NewDefaultLlamaCppConfig: --jinja must be present, -ngl must be followed by
// "100", and on Windows ARM64 --threads must be followed by an integer in the
// range [1, runtime.NumCPU()/2].
func TestNewDefaultLlamaCppConfig(t *testing.T) {
	config := NewDefaultLlamaCppConfig()

	// Test default arguments
	if !containsArg(config.Args, "--jinja") {
		t.Error("Expected --jinja argument to be present")
	}

	// Test -ngl argument and its value
	nglIndex := -1
	for i, arg := range config.Args {
		if arg == "-ngl" {
			nglIndex = i
			break
		}
	}
	// These two checks are fatal because the value check below indexes
	// config.Args relative to nglIndex; continuing would either inspect an
	// unrelated element or panic with an index out of range.
	if nglIndex == -1 {
		t.Fatal("Expected -ngl argument to be present")
	}
	if nglIndex+1 >= len(config.Args) {
		t.Fatal("No value found after -ngl argument")
	}
	if got := config.Args[nglIndex+1]; got != "100" {
		t.Errorf("Expected -ngl value to be 100, got %s", got)
	}

	// Test Windows ARM64 specific case
	if runtime.GOOS != "windows" || runtime.GOARCH != "arm64" {
		return
	}

	threadsIndex := -1
	for i, arg := range config.Args {
		if arg == "--threads" {
			threadsIndex = i
			break
		}
	}
	// Fatal for the same reason as above: the value lookup depends on a valid
	// index.
	if threadsIndex == -1 {
		t.Fatal("Expected --threads argument to be present on Windows ARM64")
	}
	if threadsIndex+1 >= len(config.Args) {
		t.Fatal("No value found after --threads argument")
	}

	threads, err := strconv.Atoi(config.Args[threadsIndex+1])
	if err != nil {
		// Stop here: threads would be 0 and the range checks below would
		// report a misleading secondary failure.
		t.Fatalf("Failed to parse thread count: %v", err)
	}
	if threads > runtime.NumCPU()/2 {
		t.Errorf("Thread count %d exceeds maximum allowed value of %d", threads, runtime.NumCPU()/2)
	}
	if threads < 1 {
		t.Error("Thread count is less than 1")
	}
}
// TestGetArgs is a table-driven test for GetArgs on the default llama.cpp
// configuration. Each case supplies a fake model bundle, a backend mode and an
// optional backend configuration, and states the exact argument list (content
// and order) that GetArgs is expected to return.
func TestGetArgs(t *testing.T) {
	config := NewDefaultLlamaCppConfig()
	modelPath := "/path/to/model"
	socket := "unix:///tmp/socket"

	tests := []struct {
		name     string
		bundle   types.ModelBundle
		mode     inference.BackendMode
		config   *inference.BackendConfiguration
		expected []string
	}{
		{
			name: "completion mode",
			mode: inference.BackendModeCompletion,
			bundle: &fakeBundle{
				ggufPath: modelPath,
			},
			expected: []string{
				"--jinja",
				"-ngl", "100",
				"--metrics",
				"--model", modelPath,
				"--host", socket,
				"--ctx-size", "4096",
			},
		},
		{
			name: "embedding mode",
			mode: inference.BackendModeEmbedding,
			bundle: &fakeBundle{
				ggufPath: modelPath,
			},
			expected: []string{
				"--jinja",
				"-ngl", "100",
				"--metrics",
				"--model", modelPath,
				"--host", socket,
				"--embeddings",
				"--ctx-size", "4096",
			},
		},
		{
			name: "context size from backend config",
			mode: inference.BackendModeEmbedding,
			bundle: &fakeBundle{
				ggufPath: modelPath,
			},
			config: &inference.BackendConfiguration{
				ContextSize: 1234,
			},
			expected: []string{
				"--jinja",
				"-ngl", "100",
				"--metrics",
				"--model", modelPath,
				"--host", socket,
				"--embeddings",
				"--ctx-size", "1234", // should add this flag
			},
		},
		{
			name: "context size from model config",
			mode: inference.BackendModeEmbedding,
			bundle: &fakeBundle{
				ggufPath: modelPath,
				config: types.Config{
					ContextSize: uint64ptr(2096),
				},
			},
			config: &inference.BackendConfiguration{
				ContextSize: 1234,
			},
			expected: []string{
				"--jinja",
				"-ngl", "100",
				"--metrics",
				"--model", modelPath,
				"--host", socket,
				"--embeddings",
				"--ctx-size", "2096", // model config takes precedence
			},
		},
		{
			name: "raw flags from backend config",
			mode: inference.BackendModeEmbedding,
			bundle: &fakeBundle{
				ggufPath: modelPath,
			},
			config: &inference.BackendConfiguration{
				RuntimeFlags: []string{"--some", "flag"},
			},
			expected: []string{
				"--jinja",
				"-ngl", "100",
				"--metrics",
				"--model", modelPath,
				"--host", socket,
				"--embeddings",
				"--ctx-size", "4096",
				"--some", "flag", // raw runtime flags are expected at the end
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			args, err := config.GetArgs(tt.bundle, socket, tt.mode, tt.config)
			if err != nil {
				// Nothing meaningful can be compared once GetArgs has failed.
				t.Fatalf("GetArgs() error = %v", err)
			}

			// The argument list must match exactly, in content and order.
			// Checking the length first keeps the element-wise comparison
			// below within bounds of both slices.
			if len(args) != len(tt.expected) {
				t.Fatalf("GetArgs() returned %d arguments %q, expected %d arguments %q",
					len(args), args, len(tt.expected), tt.expected)
			}
			for i, want := range tt.expected {
				if args[i] != want {
					t.Errorf("Expected argument %s at position %d, got %s", want, i, args[i])
				}
			}
		})
	}
}
// TestContainsArg exercises containsArg with a slice that holds the target, a
// slice that does not, and an empty slice, comparing the reported result with
// the expected boolean for each case.
func TestContainsArg(t *testing.T) {
	type testCase struct {
		name     string
		args     []string
		arg      string
		expected bool
	}

	cases := []testCase{
		{name: "argument exists", args: []string{"--arg1", "--arg2", "--arg3"}, arg: "--arg2", expected: true},
		{name: "argument does not exist", args: []string{"--arg1", "--arg2", "--arg3"}, arg: "--arg4", expected: false},
		{name: "empty args slice", args: []string{}, arg: "--arg1", expected: false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			if got := containsArg(tc.args, tc.arg); got != tc.expected {
				t.Errorf("containsArg(%v, %s) = %v, want %v", tc.args, tc.arg, got, tc.expected)
			}
		})
	}
}
// fakeBundle is a minimal in-memory stand-in for types.ModelBundle used by the
// tests in this file. It only carries a GGUF path and a runtime configuration.
type fakeBundle struct {
	ggufPath string
	config   types.Config
}

// Compile-time check that *fakeBundle satisfies types.ModelBundle.
var _ types.ModelBundle = (*fakeBundle)(nil)

// RootDir always panics; the tests in this file do not expect it to be called.
func (b *fakeBundle) RootDir() string {
	panic("shouldn't be called")
}

// GGUFPath returns the GGUF path the fake was constructed with.
func (b *fakeBundle) GGUFPath() string {
	return b.ggufPath
}

// MMPROJPath always returns the empty string.
func (b *fakeBundle) MMPROJPath() string {
	return ""
}

// RuntimeConfig returns the configuration the fake was constructed with.
func (b *fakeBundle) RuntimeConfig() types.Config {
	return b.config
}
// uint64ptr returns a pointer to a newly allocated uint64 holding n. It lets
// callers populate optional *uint64 fields from a literal value.
func uint64ptr(n uint64) *uint64 {
	p := new(uint64)
	*p = n
	return p
}