package llamacpp

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"runtime"
	"strings"

	"github.com/docker/model-runner/pkg/internal/dockerhub"
	"github.com/docker/model-runner/pkg/logging"
)
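
// hubNamespace and hubRepo identify the Docker Hub repository from which the
// prebuilt llama.cpp server images are pulled.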
const (
	hubNamespace = "docker"
	hubRepo      = "docker-model-backend-llamacpp"
)

var (
	// ShouldUseGPUVariant indicates whether the GPU build of llama.cpp should be used on this system.
	ShouldUseGPUVariant bool
	// errLlamaCppUpToDate signals that the bundled llama.cpp already matches the desired version.
	errLlamaCppUpToDate = errors.New("bundled llama.cpp version is up to date, no need to update")
)
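
// downloadLatestLlamaCpp resolves the desired llama.cpp version and variant to
// an image digest on Docker Hub and, unless the bundled or previously
// installed binary already matches that digest, downloads the image and
// installs its bin/ and lib/ contents alongside llamaCppPath.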
func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
	llamaCppPath, vendoredServerStoragePath, desiredVersion, desiredVariant string,
) error {
	// List the repository's tags via the Docker Hub API, honoring ctx for cancellation.
	url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags", hubNamespace, hubRepo)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read response body: %w", err)
	}
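
	// Decode just the tag name and digest from the tag-listing response; see: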
	// https://docs.docker.com/reference/api/hub/latest/#tag/repositories/paths/~1v2~1namespaces~1%7Bnamespace%7D~1repositories~1%7Brepository%7D~1tags/get
	var response struct {
		Results []struct {
			Name   string `json:"name"`
			Digest string `json:"digest"`
		}
	}

	if err := json.Unmarshal(body, &response); err != nil {
		return fmt.Errorf("failed to unmarshal response body: %w", err)
	}

	desiredTag := desiredVersion + "-" + desiredVariant
	var latest string
	for _, tag := range response.Results {
		if tag.Name == desiredTag {
			latest = tag.Digest
			break
		}
	}
	if latest == "" {
		return fmt.Errorf("could not find the %s tag", desiredTag)
	}
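
	// If the llama.cpp server bundled with the application already matches the
	// resolved digest, keep using it and report that no update is needed.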
	bundledVersionFile := filepath.Join(vendoredServerStoragePath, "com.docker.llama-server.digest")
	currentVersionFile := filepath.Join(filepath.Dir(llamaCppPath), ".llamacpp_version")

	data, err := os.ReadFile(bundledVersionFile)
	if err != nil {
		return fmt.Errorf("failed to read bundled llama.cpp version: %w", err)
	} else if strings.TrimSpace(string(data)) == latest {
		l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
			desiredTag, latest, getLlamaCppVersion(log, filepath.Join(vendoredServerStoragePath, "com.docker.llama-server")))
		return errLlamaCppUpToDate
	}
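
	// Otherwise compare against the digest recorded for the most recently
	// installed binary and skip the download if it is current and still on disk.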
	data, err = os.ReadFile(currentVersionFile)
	if err != nil {
		log.Warnf("failed to read current llama.cpp version: %v", err)
		log.Warnf("proceeding to update llama.cpp binary")
	} else if strings.TrimSpace(string(data)) == latest {
		log.Infoln("current llama.cpp version is already up to date")
		if _, err := os.Stat(llamaCppPath); err == nil {
			l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
				desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
			return nil
		}
		log.Infoln("llama.cpp binary is missing, proceeding to reinstall it")
	} else {
		log.Infof("current llama.cpp version is outdated: %s vs %s, proceeding to update it", strings.TrimSpace(string(data)), latest)
	}
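
	// Pull the image by digest and extract its contents for the current
	// platform into a temporary directory.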
	image := fmt.Sprintf("registry-1.docker.io/%s/%s@%s", hubNamespace, hubRepo, latest)
	downloadDir, err := os.MkdirTemp("", "llamacpp-install")
	if err != nil {
		return fmt.Errorf("could not create temporary directory: %w", err)
	}
	defer os.RemoveAll(downloadDir)

	l.status = fmt.Sprintf("downloading %s (%s) variant of llama.cpp", desiredTag, latest)
	if err := extractFromImage(ctx, log, image, runtime.GOOS, runtime.GOARCH, downloadDir); err != nil {
		return fmt.Errorf("could not extract image: %w", err)
	}
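
	// Remove any previously installed binary and libraries, then recreate the
	// parent directory for the new artifacts.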
	if err := os.RemoveAll(filepath.Dir(llamaCppPath)); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("failed to clear inference binary dir: %w", err)
	}
	if err := os.RemoveAll(filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("failed to clear inference library dir: %w", err)
	}

	if err := os.MkdirAll(filepath.Dir(filepath.Dir(llamaCppPath)), 0o755); err != nil {
		return fmt.Errorf("could not create directory for llama.cpp artifacts: %w", err)
	}
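
	// The extracted image lays out the server under
	// com.docker.llama-server.native.<os>.<variant>.<arch>/{bin,lib}; move the
	// binaries into place and make the server executable.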
	rootDir := fmt.Sprintf("com.docker.llama-server.native.%s.%s.%s", runtime.GOOS, desiredVariant, runtime.GOARCH)
	if err := os.Rename(filepath.Join(downloadDir, rootDir, "bin"), filepath.Dir(llamaCppPath)); err != nil {
		return fmt.Errorf("could not move llama.cpp binary: %w", err)
	}
	if err := os.Chmod(llamaCppPath, 0o755); err != nil {
		return fmt.Errorf("could not chmod llama.cpp binary: %w", err)
	}

	// If the image also ships a lib directory, install it next to the bin directory.
	libDir := filepath.Join(downloadDir, rootDir, "lib")
	fi, err := os.Stat(libDir)
	if err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("failed to stat llama.cpp lib dir: %w", err)
	}
	if err == nil && fi.IsDir() {
		if err := os.Rename(libDir, filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil {
			return fmt.Errorf("could not move llama.cpp libs: %w", err)
		}
	}
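
	// Report the newly installed version and remember its digest so future
	// runs can detect that it is already current.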
	log.Infoln("successfully updated llama.cpp binary")
	l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
	log.Infoln(l.status)

	if err := os.WriteFile(currentVersionFile, []byte(latest), 0o644); err != nil {
		log.Warnf("failed to save llama.cpp version: %v", err)
	}

	return nil
}
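
// extractFromImage pulls the given image for the requested operating system
// and architecture, saves it as a tarball, and extracts its contents into
// destination.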
func extractFromImage(ctx context.Context, log logging.Logger, image, requiredOs, requiredArch, destination string) error {
	log.Infof("Extracting image %q to %q", image, destination)
	tmpDir, err := os.MkdirTemp("", "docker-tar-extract")
	if err != nil {
		return err
	}
	// Clean up the temporary tarball once extraction has finished.
	defer os.RemoveAll(tmpDir)
	imageTar := filepath.Join(tmpDir, "save.tar")
	if err := dockerhub.PullPlatform(ctx, image, imageTar, requiredOs, requiredArch); err != nil {
		return err
	}
	return dockerhub.Extract(imageTar, requiredArch, requiredOs, destination)
}
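
// getLlamaCppVersion invokes the llama.cpp server binary with --version and
// parses the build identifier from its output, returning "unknown" if the
// version cannot be determined.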
func getLlamaCppVersion(log logging.Logger, llamaCpp string) string {
	output, err := exec.Command(llamaCpp, "--version").CombinedOutput()
	if err != nil {
		log.Warnf("could not get llama.cpp version: %v", err)
		return "unknown"
	}
	// The output is expected to contain a line such as "version: 1234 (abcdef0)";
	// the value in parentheses identifies the build.
	re := regexp.MustCompile(`version: \d+ \((\w+)\)`)
	matches := re.FindStringSubmatch(string(output))
	if len(matches) == 2 {
		return matches[1]
	}
	log.Warnf("failed to parse llama.cpp version from output:\n%s", strings.TrimSpace(string(output)))
	return "unknown"
}