diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index db28aca..dd2a8d4 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -9,7 +9,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
-	"slices"
+	"strconv"
 
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/models"
@@ -76,7 +76,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 	if runtime.GOOS == "linux" {
 		return errors.New("not implemented")
 	} else if (runtime.GOOS == "darwin" && runtime.GOARCH == "amd64") ||
-		(runtime.GOOS == "windows" && !slices.Contains([]string{"amd64", "arm64"}, runtime.GOARCH)) {
+		(runtime.GOOS == "windows" && !(runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64")) {
 		return errors.New("platform not supported")
 	}
 
@@ -115,6 +115,11 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 		return fmt.Errorf("failed to get model path: %w", err)
 	}
 
+	modelDesc, err := l.modelManager.GetModel(model)
+	if err != nil {
+		return fmt.Errorf("failed to get model: %w", err)
+	}
+
 	if err := os.RemoveAll(socket); err != nil {
 		l.log.Warnln("failed to remove socket file %s: %w", socket, err)
 		l.log.Warnln("llama.cpp may not be able to start")
@@ -124,10 +129,28 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 	if l.updatedLlamaCpp {
 		binPath = l.updatedServerStoragePath
 	}
-	llamaCppArgs := []string{"--model", modelPath, "--jinja", "-ngl", "100"}
+	llamaCppArgs := []string{"--model", modelPath, "--jinja"}
 	if mode == inference.BackendModeEmbedding {
 		llamaCppArgs = append(llamaCppArgs, "--embeddings")
 	}
+	if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
+		// Using a thread count equal to the core count results in bad performance, and there
+		// seems to be little to no gain in going beyond core_count/2 threads.
+		// TODO(p1-0tr): dig into why the defaults don't work well on windows/arm64
+		nThreads := max(2, runtime.NumCPU()/2)
+		llamaCppArgs = append(llamaCppArgs, "--threads", strconv.Itoa(nThreads))
+
+		modelConfig, err := modelDesc.Config()
+		if err != nil {
+			return fmt.Errorf("failed to get model config: %w", err)
+		}
+		// The Adreno OpenCL implementation currently only supports Q4_0 quantization.
+		if modelConfig.Quantization == "Q4_0" {
+			llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
+		}
+	} else {
+		llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
+	}
 	llamaCppProcess := exec.CommandContext(
 		ctx,
 		filepath.Join(binPath, "com.docker.llama-server"),