Basic param tuning on windows/arm64

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
This commit is contained in:
Piotr Stankiewicz 2025-04-25 13:17:11 +02:00
parent c382c0b5c6
commit a47583dc39
1 changed files with 26 additions and 3 deletions

View File

@ -9,7 +9,7 @@ import (
"os/exec"
"path/filepath"
"runtime"
"slices"
"strconv"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/models"
@ -76,7 +76,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
if runtime.GOOS == "linux" {
return errors.New("not implemented")
} else if (runtime.GOOS == "darwin" && runtime.GOARCH == "amd64") ||
(runtime.GOOS == "windows" && !slices.Contains([]string{"amd64", "arm64"}, runtime.GOARCH)) {
(runtime.GOOS == "windows" && !(runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64")) {
return errors.New("platform not supported")
}
@ -115,6 +115,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
return fmt.Errorf("failed to get model path: %w", err)
}
modelDesc, err := l.modelManager.GetModel(model)
if err != nil {
return fmt.Errorf("failed to get model: %w", err)
}
if err := os.RemoveAll(socket); err != nil {
l.log.Warnln("failed to remove socket file %s: %w", socket, err)
l.log.Warnln("llama.cpp may not be able to start")
@ -124,10 +129,28 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
if l.updatedLlamaCpp {
binPath = l.updatedServerStoragePath
}
llamaCppArgs := []string{"--model", modelPath, "--jinja", "-ngl", "100"}
llamaCppArgs := []string{"--model", modelPath, "--jinja"}
if mode == inference.BackendModeEmbedding {
llamaCppArgs = append(llamaCppArgs, "--embeddings")
}
if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
// Using a thread count equal to core count results in bad performance, and there seems to be little to no gain
// in going beyond core_count/2.
// TODO(p1-0tr): dig into why the defaults don't work well on windows/arm64
nThreads := min(2, runtime.NumCPU()/2)
llamaCppArgs = append(llamaCppArgs, "--threads", strconv.Itoa(nThreads))
modelConfig, err := modelDesc.Config()
if err != nil {
return fmt.Errorf("failed to get model config: %w", err)
}
// The Adreno OpenCL implementation currently only supports Q4_0
if modelConfig.Quantization == "Q4_0" {
llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
}
} else {
llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
}
llamaCppProcess := exec.CommandContext(
ctx,
filepath.Join(binPath, "com.docker.llama-server"),