package llamacpp

import (
	"context"
	"errors"
	"fmt"
	"io/fs"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strconv"

	"github.com/docker/model-runner/pkg/diskusage"
	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/models"
	"github.com/docker/model-runner/pkg/logging"
)

const (
	// Name is the backend name.
	Name = "llama.cpp"
)

// llamaCpp is the llama.cpp-based backend implementation.
type llamaCpp struct {
	// log is the associated logger.
	log logging.Logger
	// modelManager is the shared model manager.
	modelManager *models.Manager
	// serverLog is the logger to use for the llama.cpp server process.
	serverLog logging.Logger
	// updatedLlamaCpp indicates whether an updated com.docker.llama-server
	// binary was successfully downloaded and should be used instead of the
	// vendored one.
	updatedLlamaCpp bool
	// vendoredServerStoragePath is the parent path of the vendored version of com.docker.llama-server.
	vendoredServerStoragePath string
	// updatedServerStoragePath is the parent path of the updated version of com.docker.llama-server.
	// It is also where updates will be stored when downloaded.
	updatedServerStoragePath string
	// status is the current state of the llama.cpp backend.
	status string
}

// New creates a new llama.cpp-based backend.
func New(
	log logging.Logger,
	modelManager *models.Manager,
	serverLog logging.Logger,
	vendoredServerStoragePath string,
	updatedServerStoragePath string,
) (inference.Backend, error) {
	return &llamaCpp{
		log:                       log,
		modelManager:              modelManager,
		serverLog:                 serverLog,
		vendoredServerStoragePath: vendoredServerStoragePath,
		updatedServerStoragePath:  updatedServerStoragePath,
	}, nil
}
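
// A minimal wiring sketch (hypothetical logger and path values; the real
// wiring lives in the host application):
//
//	backend, _ := llamacpp.New(log, modelManager, serverLog,
//		"/opt/model-runner/vendored", "/var/lib/model-runner/updated")
//	if err := backend.Install(ctx, http.DefaultClient); err != nil {
//		// handle installation failure
//	}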

// Name implements inference.Backend.Name.
func (l *llamaCpp) Name() string {
	return Name
}

// UsesExternalModelManagement implements
// inference.Backend.UsesExternalModelManagement.
func (l *llamaCpp) UsesExternalModelManagement() bool {
	return false
}

// Install implements inference.Backend.Install.
func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
	l.updatedLlamaCpp = false

	// We don't support this backend on Intel Macs, and likely never will.
	// On Windows, only amd64 and arm64 are supported.
	if (runtime.GOOS == "darwin" && runtime.GOARCH == "amd64") ||
		(runtime.GOOS == "windows" && !(runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64")) {
		return errors.New("platform not supported")
	}

	llamaServerBin := "com.docker.llama-server"
	if runtime.GOOS == "windows" {
		llamaServerBin = "com.docker.llama-server.exe"
	}

	l.status = "installing"

	// Temporary workaround for dynamically downloading llama.cpp from Docker Hub.
	// Internet access and an available docker/docker-model-backend-llamacpp:latest on Docker Hub are
	// required. Even if docker/docker-model-backend-llamacpp:latest has been downloaded before, we still
	// require its digest to be equal to the one on Docker Hub.
	llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin)
	if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil {
		l.log.Infof("failed to ensure latest llama.cpp: %v", err)
		if !errors.Is(err, errLlamaCppUpToDate) {
			l.status = fmt.Sprintf("failed to install llama.cpp: %v", err)
		}
		if errors.Is(err, context.Canceled) {
			return err
		}
		// Other failures are non-fatal: Run will fall back to the vendored
		// server binary.
	} else {
		l.updatedLlamaCpp = true
	}

	return nil
}

// Run implements inference.Backend.Run.
func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
	modelPath, err := l.modelManager.GetModelPath(model)
	if err != nil {
		return fmt.Errorf("failed to get model path: %w", err)
	}
	l.log.Infof("Model path: %s", modelPath)

	modelDesc, err := l.modelManager.GetModel(model)
	if err != nil {
		return fmt.Errorf("failed to get model: %w", err)
	}

	if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
		l.log.Warnf("failed to remove socket file %s: %v", socket, err)
		l.log.Warnln("llama.cpp may not be able to start")
	}
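
	// Prefer the updated server binary when Install downloaded one;
	// otherwise fall back to the vendored copy.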
	binPath := l.vendoredServerStoragePath
	if l.updatedLlamaCpp {
		binPath = l.updatedServerStoragePath
	}
	llamaCppArgs := []string{"--model", modelPath, "--jinja", "--host", socket}
	if mode == inference.BackendModeEmbedding {
		llamaCppArgs = append(llamaCppArgs, "--embeddings")
	}
	if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" {
		// Using a thread count equal to core count results in bad performance, and there seems to be
		// little to no gain in going beyond core_count/2.
		// TODO(p1-0tr): dig into why the defaults don't work well on windows/arm64
		nThreads := max(2, runtime.NumCPU()/2)
		llamaCppArgs = append(llamaCppArgs, "--threads", strconv.Itoa(nThreads))

		modelConfig, err := modelDesc.Config()
		if err != nil {
			return fmt.Errorf("failed to get model config: %w", err)
		}
		// The Adreno OpenCL implementation currently only supports Q4_0.
		if modelConfig.Quantization == "Q4_0" {
			llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
		}
	} else {
		// Offload as many layers as possible to the GPU (-ngl is llama.cpp's
		// --n-gpu-layers flag; 100 covers all layers for typical models).
		llamaCppArgs = append(llamaCppArgs, "-ngl", "100")
	}
	llamaCppProcess := exec.CommandContext(
		ctx,
		filepath.Join(binPath, "com.docker.llama-server"),
		llamaCppArgs...,
	)
	// On context cancellation, signal the server to shut down gracefully.
	// Windows has no os.Interrupt delivery for child processes, so kill the
	// process outright there.
	llamaCppProcess.Cancel = func() error {
		if runtime.GOOS == "windows" {
			return llamaCppProcess.Process.Kill()
		}
		return llamaCppProcess.Process.Signal(os.Interrupt)
	}
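	// Mirror the server's stdout and stderr into the dedicated server log.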
	serverLogStream := l.serverLog.Writer()
	llamaCppProcess.Stdout = serverLogStream
	llamaCppProcess.Stderr = serverLogStream

	if err := llamaCppProcess.Start(); err != nil {
		return fmt.Errorf("unable to start llama.cpp: %w", err)
	}
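
	// From here on the server is running; the rest of Run reaps the process
	// and translates its exit into either a clean shutdown or an error.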
	llamaCppErrors := make(chan error, 1)
	go func() {
		// Reap the process, close the log stream, and clean up the socket
		// file once the server exits.
		llamaCppErr := llamaCppProcess.Wait()
		serverLogStream.Close()
		llamaCppErrors <- llamaCppErr
		close(llamaCppErrors)
		if err := os.Remove(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
			l.log.Warnf("failed to remove socket file %s on exit: %v", socket, err)
		}
	}()
	// Block returning until the server process has actually been reaped.
	defer func() {
		<-llamaCppErrors
	}()

	select {
	case <-ctx.Done():
		return nil
	case llamaCppErr := <-llamaCppErrors:
		// If cancellation raced with the process exiting, treat this as a
		// normal shutdown rather than an unexpected termination.
		select {
		case <-ctx.Done():
			return nil
		default:
		}
		return fmt.Errorf("llama.cpp terminated unexpectedly: %w", llamaCppErr)
	}
}

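// Status returns the backend's current status string (e.g. "installing").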
func (l *llamaCpp) Status() string {
	return l.status
}

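// GetDiskUsage reports the disk space consumed by the updated server
// storage path.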
func (l *llamaCpp) GetDiskUsage() (float64, error) {
	size, err := diskusage.Size(l.updatedServerStoragePath)
	if err != nil {
		return 0, fmt.Errorf("error while getting store size: %w", err)
	}
	return size, nil
}