277 lines
6.2 KiB
Go
277 lines
6.2 KiB
Go
package llamacpp
|
|
|
|
import (
|
|
"runtime"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/docker/model-distribution/types"
|
|
|
|
"github.com/docker/model-runner/pkg/inference"
|
|
)
|
|
|
|
func TestNewDefaultLlamaCppConfig(t *testing.T) {
|
|
config := NewDefaultLlamaCppConfig()
|
|
|
|
// Test default arguments
|
|
if !containsArg(config.Args, "--jinja") {
|
|
t.Error("Expected --jinja argument to be present")
|
|
}
|
|
|
|
// Test -ngl argument and its value
|
|
nglIndex := -1
|
|
for i, arg := range config.Args {
|
|
if arg == "-ngl" {
|
|
nglIndex = i
|
|
break
|
|
}
|
|
}
|
|
if nglIndex == -1 {
|
|
t.Error("Expected -ngl argument to be present")
|
|
}
|
|
if nglIndex+1 >= len(config.Args) {
|
|
t.Error("No value found after -ngl argument")
|
|
}
|
|
if config.Args[nglIndex+1] != "100" {
|
|
t.Errorf("Expected -ngl value to be 100, got %s", config.Args[nglIndex+1])
|
|
}
|
|
|
|
// Test Windows ARM64 specific case
|
|
if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
|
|
if !containsArg(config.Args, "--threads") {
|
|
t.Error("Expected --threads argument to be present on Windows ARM64")
|
|
}
|
|
threadsIndex := -1
|
|
for i, arg := range config.Args {
|
|
if arg == "--threads" {
|
|
threadsIndex = i
|
|
break
|
|
}
|
|
}
|
|
if threadsIndex == -1 {
|
|
t.Error("Could not find --threads argument")
|
|
}
|
|
if threadsIndex+1 >= len(config.Args) {
|
|
t.Error("No value found after --threads argument")
|
|
}
|
|
threads, err := strconv.Atoi(config.Args[threadsIndex+1])
|
|
if err != nil {
|
|
t.Errorf("Failed to parse thread count: %v", err)
|
|
}
|
|
if threads > runtime.NumCPU()/2 {
|
|
t.Errorf("Thread count %d exceeds maximum allowed value of %d", threads, runtime.NumCPU()/2)
|
|
}
|
|
if threads < 1 {
|
|
t.Error("Thread count is less than 1")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGetArgs(t *testing.T) {
|
|
config := NewDefaultLlamaCppConfig()
|
|
modelPath := "/path/to/model"
|
|
socket := "unix:///tmp/socket"
|
|
|
|
tests := []struct {
|
|
name string
|
|
bundle types.ModelBundle
|
|
mode inference.BackendMode
|
|
config *inference.BackendConfiguration
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "completion mode",
|
|
mode: inference.BackendModeCompletion,
|
|
bundle: &fakeBundle{
|
|
ggufPath: modelPath,
|
|
},
|
|
expected: []string{
|
|
"--jinja",
|
|
"-ngl", "100",
|
|
"--metrics",
|
|
"--model", modelPath,
|
|
"--host", socket,
|
|
"--ctx-size", "4096",
|
|
},
|
|
},
|
|
{
|
|
name: "embedding mode",
|
|
mode: inference.BackendModeEmbedding,
|
|
bundle: &fakeBundle{
|
|
ggufPath: modelPath,
|
|
},
|
|
expected: []string{
|
|
"--jinja",
|
|
"-ngl", "100",
|
|
"--metrics",
|
|
"--model", modelPath,
|
|
"--host", socket,
|
|
"--embeddings",
|
|
"--ctx-size", "4096",
|
|
},
|
|
},
|
|
{
|
|
name: "context size from backend config",
|
|
mode: inference.BackendModeEmbedding,
|
|
bundle: &fakeBundle{
|
|
ggufPath: modelPath,
|
|
},
|
|
config: &inference.BackendConfiguration{
|
|
ContextSize: 1234,
|
|
},
|
|
expected: []string{
|
|
"--jinja",
|
|
"-ngl", "100",
|
|
"--metrics",
|
|
"--model", modelPath,
|
|
"--host", socket,
|
|
"--embeddings",
|
|
"--ctx-size", "1234", // should add this flag
|
|
},
|
|
},
|
|
{
|
|
name: "context size from model config",
|
|
mode: inference.BackendModeEmbedding,
|
|
bundle: &fakeBundle{
|
|
ggufPath: modelPath,
|
|
config: types.Config{
|
|
ContextSize: uint64ptr(2096),
|
|
},
|
|
},
|
|
config: &inference.BackendConfiguration{
|
|
ContextSize: 1234,
|
|
},
|
|
expected: []string{
|
|
"--jinja",
|
|
"-ngl", "100",
|
|
"--metrics",
|
|
"--model", modelPath,
|
|
"--host", socket,
|
|
"--embeddings",
|
|
"--ctx-size", "2096", // model config takes precedence
|
|
},
|
|
},
|
|
{
|
|
name: "raw flags from backend config",
|
|
mode: inference.BackendModeEmbedding,
|
|
bundle: &fakeBundle{
|
|
ggufPath: modelPath,
|
|
},
|
|
config: &inference.BackendConfiguration{
|
|
RuntimeFlags: []string{"--some", "flag"},
|
|
},
|
|
expected: []string{
|
|
"--jinja",
|
|
"-ngl", "100",
|
|
"--metrics",
|
|
"--model", modelPath,
|
|
"--host", socket,
|
|
"--embeddings",
|
|
"--ctx-size", "4096",
|
|
"--some", "flag", // model config takes precedence
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
args, err := config.GetArgs(tt.bundle, socket, tt.mode, tt.config)
|
|
if err != nil {
|
|
t.Errorf("GetArgs() error = %v", err)
|
|
}
|
|
|
|
// Check that all expected arguments are present and in the correct order
|
|
expectedIndex := 0
|
|
for i := 0; i < len(args); i++ {
|
|
if expectedIndex >= len(tt.expected) {
|
|
t.Errorf("Unexpected extra argument: %s", args[i])
|
|
continue
|
|
}
|
|
|
|
if args[i] != tt.expected[expectedIndex] {
|
|
t.Errorf("Expected argument %s at position %d, got %s", tt.expected[expectedIndex], i, args[i])
|
|
continue
|
|
}
|
|
|
|
// If this is a flag that takes a value, check the next argument
|
|
if i+1 < len(args) && (args[i] == "-ngl" || args[i] == "--model" || args[i] == "--host") {
|
|
expectedIndex++
|
|
if args[i+1] != tt.expected[expectedIndex] {
|
|
t.Errorf("Expected value %s for flag %s, got %s", tt.expected[expectedIndex], args[i], args[i+1])
|
|
}
|
|
i++ // Skip the value in the next iteration
|
|
}
|
|
expectedIndex++
|
|
}
|
|
|
|
if expectedIndex != len(tt.expected) {
|
|
t.Errorf("Missing expected arguments. Got %d arguments, expected %d", expectedIndex, len(tt.expected))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestContainsArg(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
args []string
|
|
arg string
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "argument exists",
|
|
args: []string{"--arg1", "--arg2", "--arg3"},
|
|
arg: "--arg2",
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "argument does not exist",
|
|
args: []string{"--arg1", "--arg2", "--arg3"},
|
|
arg: "--arg4",
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "empty args slice",
|
|
args: []string{},
|
|
arg: "--arg1",
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := containsArg(tt.args, tt.arg)
|
|
if result != tt.expected {
|
|
t.Errorf("containsArg(%v, %s) = %v, want %v", tt.args, tt.arg, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
var _ types.ModelBundle = &fakeBundle{}
|
|
|
|
type fakeBundle struct {
|
|
ggufPath string
|
|
config types.Config
|
|
}
|
|
|
|
func (f *fakeBundle) RootDir() string {
|
|
panic("shouldn't be called")
|
|
}
|
|
|
|
func (f *fakeBundle) GGUFPath() string {
|
|
return f.ggufPath
|
|
}
|
|
|
|
func (f *fakeBundle) MMPROJPath() string {
|
|
return ""
|
|
}
|
|
|
|
func (f *fakeBundle) RuntimeConfig() types.Config {
|
|
return f.config
|
|
}
|
|
|
|
// uint64ptr returns a pointer to a fresh uint64 holding n, for populating
// optional pointer-typed fields in test fixtures.
func uint64ptr(n uint64) *uint64 {
	p := new(uint64)
	*p = n
	return p
}
|