Merge pull request #13 from docker/ps-win-support

Windows support
Jacob Howard, 2025-04-17 14:08:17 -06:00, committed by GitHub
commit 3b88d22cbd (no known key found for this signature in database; GPG Key ID: B5690EEEBB952194)
13 changed files with 284 additions and 27 deletions

go.mod (6 lines changed)

@@ -6,6 +6,7 @@ require (
github.com/containerd/containerd/v2 v2.0.4
github.com/containerd/platforms v1.0.0-rc.1
github.com/docker/model-distribution v0.0.0-20250411163353-b33595b4e207
github.com/jaypipes/ghw v0.16.0
github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.1.1
github.com/sirupsen/logrus v1.9.3
@@ -14,6 +15,7 @@ require (
require (
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 // indirect
github.com/StackExchange/wmi v1.2.1 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect
@@ -24,9 +26,11 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/google/go-containerregistry v0.20.3 // indirect
github.com/gpustack/gguf-parser-go v0.13.20 // indirect
github.com/henvic/httpretty v0.1.4 // indirect
github.com/jaypipes/pcidb v1.0.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
@@ -48,4 +52,6 @@ require (
golang.org/x/sys v0.31.0 // indirect
golang.org/x/tools v0.29.0 // indirect
gonum.org/v1/gonum v0.15.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
howett.net/plist v1.0.0 // indirect
)

go.sum (22 lines changed)

@@ -8,6 +8,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/Microsoft/hcsshim v0.12.9 h1:2zJy5KA+l0loz1HzEGqyNnjd3fyZA31ZBCGKacp6lLg=
github.com/Microsoft/hcsshim v0.12.9/go.mod h1:fJ0gkFAna6ukt0bLdKB8djt4XIJhF/vEPuoIWYVvZ8Y=
github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA=
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/containerd/cgroups/v3 v3.0.3 h1:S5ByHZ/h9PMe5IOQoN7E+nMc2UcLEM/V48DGDJ9kip0=
@@ -49,8 +51,6 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/docker/model-distribution v0.0.0-20250410151231-bf9b59b512f7 h1:li7LReF/bddqOXbX7IfAnzRmYYAuMjG4C7xCMQLZPlI=
github.com/docker/model-distribution v0.0.0-20250410151231-bf9b59b512f7/go.mod h1:/JWSwYc3pihCpHqFzDUyoiRKegA1srfYESxRh/vJE10=
github.com/docker/model-distribution v0.0.0-20250411163353-b33595b4e207 h1:BZFQGpeo7H4JLeX1Gn+T9P7vSPwbhtH10QeSKdkzKKs=
github.com/docker/model-distribution v0.0.0-20250411163353-b33595b4e207/go.mod h1:/JWSwYc3pihCpHqFzDUyoiRKegA1srfYESxRh/vJE10=
github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I=
@@ -62,6 +62,7 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
@@ -79,10 +80,19 @@ github.com/gpustack/gguf-parser-go v0.13.20 h1:EmONF0H9WSUen16dCsKeqb0J+vAr3jbIv
github.com/gpustack/gguf-parser-go v0.13.20/go.mod h1:GvHh1Kvvq5ojCOsJ5UpwiJJmIjFw3Qk5cW7R+CZ3IJo=
github.com/henvic/httpretty v0.1.4 h1:Jo7uwIRWVFxkqOnErcoYfH90o3ddQyVrSANeS4cxYmU=
github.com/henvic/httpretty v0.1.4/go.mod h1:Dn60sQTZfbt2dYsdUSNsCljyF4AfdqnuJFDLJA1I4AM=
github.com/jaypipes/ghw v0.16.0 h1:3HurCTS38VNpeQLo5fIdZsySuo/qAfpPSJ5t05QBFPM=
github.com/jaypipes/ghw v0.16.0/go.mod h1:In8SsaDqlb1oTyrbmTC14uy+fbBMvp+xdqX51MidlD8=
github.com/jaypipes/pcidb v1.0.1 h1:WB2zh27T3nwg8AE8ei81sNRb9yWBii3JGNJtT7K9Oic=
github.com/jaypipes/pcidb v1.0.1/go.mod h1:6xYUz/yYEyOkIkUt2t2J2folIuZ4Yg6uByCGFXMCeE4=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.9 h1:nWcCbLq1N2v/cpNsy5WvQ37Fb+YElfq20WJ/a8RkpQM=
@@ -122,6 +132,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 h1:18kd+8ZUlt/ARXhljq+14TwAoKa61q6dX8jtwOf6DH8=
github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA=
github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs=
@@ -168,6 +180,7 @@ golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
@@ -184,8 +197,13 @@ google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe0
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0=
gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8=
howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM=
howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=

----------------------------------------

@@ -67,4 +67,6 @@ type Backend interface {
// instead load only the specified model. Backends should still respond to
// OpenAI API requests for other models with a 421 error code.
Run(ctx context.Context, socket, model string, mode BackendMode) error
// Status returns a description of the backend's state.
Status() string
}
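
For illustration, a minimal sketch of a type satisfying the new method follows; the stubBackend name and its status field are assumptions for this example, not part of the diff (the real implementations appear further below):

package example

// stubBackend is a hypothetical backend used only to illustrate the new
// inference.Backend requirement; real backends also implement Install, Run,
// and the other interface methods not shown in this hunk.
type stubBackend struct {
	status string // human-readable state, e.g. "installing" or "not running"
}

// Status implements inference.Backend.Status.
func (b *stubBackend) Status() string {
	return b.status
}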

----------------------------------------

@@ -3,6 +3,7 @@ package llamacpp
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -12,6 +13,7 @@ import (
"regexp"
"runtime"
"strings"
"sync"
"github.com/docker/model-runner/pkg/internal/dockerhub"
"github.com/docker/model-runner/pkg/logging"
@@ -22,7 +24,15 @@ const (
hubRepo = "docker-model-backend-llamacpp"
)
func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client, llamaCppPath string) error {
var (
ShouldUseGPUVariant bool
ShouldUseGPUVariantLock sync.Mutex
errLlamaCppUpToDate = errors.New("bundled llama.cpp version is up to date, no need to update")
)
func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
llamaCppPath, vendoredServerStoragePath, desiredVersion, desiredVariant string,
) error {
url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags", hubNamespace, hubRepo)
resp, err := httpClient.Get(url)
if err != nil {
@@ -47,25 +57,39 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *h
return fmt.Errorf("failed to unmarshal response body: %w", err)
}
desiredTag := desiredVersion + "-" + desiredVariant
var latest string
for _, tag := range response.Results {
if tag.Name == "latest-update" {
if tag.Name == desiredTag {
latest = tag.Digest
break
}
}
if latest == "" {
return fmt.Errorf("could not find any latest-update tag")
return fmt.Errorf("could not find the %s tag", desiredTag)
}
bundledVersionFile := filepath.Join(vendoredServerStoragePath, "com.docker.llama-server.digest")
currentVersionFile := filepath.Join(filepath.Dir(llamaCppPath), ".llamacpp_version")
data, err := os.ReadFile(currentVersionFile)
data, err := os.ReadFile(bundledVersionFile)
if err != nil {
return fmt.Errorf("failed to read bundled llama.cpp version: %w", err)
} else if strings.TrimSpace(string(data)) == latest {
l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
desiredTag, latest, getLlamaCppVersion(log, filepath.Join(vendoredServerStoragePath, "com.docker.llama-server")))
return errLlamaCppUpToDate
}
data, err = os.ReadFile(currentVersionFile)
if err != nil {
log.Warnf("failed to read current llama.cpp version: %v", err)
log.Warnf("proceeding to update llama.cpp binary")
} else if strings.TrimSpace(string(data)) == latest {
log.Infoln("current llama.cpp version is already up to date")
if _, err := os.Stat(llamaCppPath); err == nil {
l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
return nil
}
log.Infoln("llama.cpp binary must be updated, proceeding to update it")
@@ -80,27 +104,44 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *h
}
defer os.RemoveAll(downloadDir)
l.status = fmt.Sprintf("downloading %s (%s) variant of llama.cpp", desiredTag, latest)
if err := extractFromImage(ctx, log, image, runtime.GOOS, runtime.GOARCH, downloadDir); err != nil {
return fmt.Errorf("could not extract image: %w", err)
}
if err := os.MkdirAll(filepath.Dir(llamaCppPath), 0o755); err != nil {
return fmt.Errorf("could not create directory for llama.cpp binary: %w", err)
if err := os.RemoveAll(filepath.Dir(llamaCppPath)); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("failed to clear inference binary dir: %w", err)
}
rootDir := "com.docker.llama-server.native.darwin.metal.arm64"
if err := os.Rename(fmt.Sprintf("%s/%s/%s", downloadDir, rootDir, "bin/com.docker.llama-server"), llamaCppPath); err != nil {
if err := os.RemoveAll(filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("failed to clear inference library dir: %w", err)
}
if err := os.MkdirAll(filepath.Dir(filepath.Dir(llamaCppPath)), 0o755); err != nil {
return fmt.Errorf("could not create directory for llama.cpp artifacts: %w", err)
}
rootDir := fmt.Sprintf("com.docker.llama-server.native.%s.%s.%s", runtime.GOOS, desiredVariant, runtime.GOARCH)
if err := os.Rename(filepath.Join(downloadDir, rootDir, "bin"), filepath.Dir(llamaCppPath)); err != nil {
return fmt.Errorf("could not move llama.cpp binary: %w", err)
}
if err := os.Chmod(llamaCppPath, 0o755); err != nil {
return fmt.Errorf("could not chmod llama.cpp binary: %w", err)
}
if err := os.Rename(fmt.Sprintf("%s/%s/%s", downloadDir, rootDir, "lib"),
filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil {
return fmt.Errorf("could not move llama.cpp libs: %w", err)
libDir := filepath.Join(downloadDir, rootDir, "lib")
fi, err := os.Stat(libDir)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("failed to stat llama.cpp lib dir: %w", err)
}
if err == nil && fi.IsDir() {
if err := os.Rename(libDir, filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil {
return fmt.Errorf("could not move llama.cpp libs: %w", err)
}
}
log.Infoln("successfully updated llama.cpp binary")
log.Infoln("running llama.cpp version:", getLlamaCppVersion(log, llamaCppPath))
l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
log.Infoln(l.status)
if err := os.WriteFile(currentVersionFile, []byte(latest), 0o644); err != nil {
log.Warnf("failed to save llama.cpp version: %v", err)

----------------------------------------

@@ -0,0 +1,17 @@
package llamacpp
import (
"context"
"net/http"
"github.com/docker/model-runner/pkg/logging"
)
func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
llamaCppPath, vendoredServerStoragePath string,
) error {
desiredVersion := "latest"
desiredVariant := "metal"
return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
desiredVariant)
}
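
This darwin variant and the two that follow (an unsupported-platform fallback and a Windows implementation) all define the same method, so they are presumably kept apart by Go build constraints, either _darwin.go/_windows.go filename suffixes or explicit tags like the //go:build !windows one visible on the GPU stub below. A sketch of the tag form, with the exact constraints assumed since file names are not shown in this view:

//go:build darwin

package llamacpp

// (assumed constraint for this file; the fallback file would carry something
// like //go:build !darwin && !windows, and the Windows file //go:build windows)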

----------------------------------------

@@ -0,0 +1,15 @@
package llamacpp
import (
"context"
"errors"
"net/http"
"github.com/docker/model-runner/pkg/logging"
)
func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
llamaCppPath, vendoredServerStoragePath string,
) error {
return errors.New("platform is not supported")
}

----------------------------------------

@@ -0,0 +1,35 @@
package llamacpp
import (
"context"
"fmt"
"net/http"
"path/filepath"
"github.com/docker/model-runner/pkg/logging"
)
func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
llamaCppPath, vendoredServerStoragePath string,
) error {
nvGPUInfoBin := filepath.Join(vendoredServerStoragePath, "com.docker.nv-gpu-info.exe")
var canUseCUDA11 bool
var err error
ShouldUseGPUVariantLock.Lock()
defer ShouldUseGPUVariantLock.Unlock()
if ShouldUseGPUVariant {
canUseCUDA11, err = hasCUDA11CapableGPU(ctx, nvGPUInfoBin)
if err != nil {
l.status = fmt.Sprintf("failed to check CUDA 11 capability: %v", err)
return fmt.Errorf("failed to check CUDA 11 capability: %w", err)
}
}
desiredVersion := "latest"
desiredVariant := "cpu"
if canUseCUDA11 {
desiredVariant = "cuda"
}
l.status = fmt.Sprintf("looking for updates for %s variant", desiredVariant)
return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
desiredVariant)
}
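
Combined with downloadLatestLlamaCpp above, which computes desiredTag as desiredVersion + "-" + desiredVariant, the Windows path resolves to one of two Docker Hub tags. A small sketch (the docker namespace comes from the hubNamespace constant referenced in the comments further down):

package main

import "fmt"

// Prints the Docker Hub tags the Windows update path can request:
// latest-cpu by default, latest-cuda when a CUDA-11-capable NVIDIA GPU
// is detected and the GPU variant is enabled.
func main() {
	for _, variant := range []string{"cpu", "cuda"} {
		fmt.Printf("docker/docker-model-backend-llamacpp:latest-%s\n", variant)
	}
}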

----------------------------------------

@@ -0,0 +1,7 @@
//go:build !windows
package llamacpp
import "context"
func CanUseGPU(context.Context, string) (bool, error) { return false, nil }

----------------------------------------

@@ -0,0 +1,61 @@
package llamacpp
import (
"bufio"
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"github.com/jaypipes/ghw"
)
func hasNVIDIAGPU() (bool, error) {
gpus, err := ghw.GPU()
if err != nil {
return false, err
}
for _, gpu := range gpus.GraphicsCards {
if strings.ToLower(gpu.DeviceInfo.Vendor.Name) == "nvidia" {
return true, nil
}
}
return false, nil
}
func hasCUDA11CapableGPU(ctx context.Context, nvGPUInfoBin string) (bool, error) {
nvGPU, err := hasNVIDIAGPU()
if !nvGPU || err != nil {
return false, err
}
cmd := exec.CommandContext(ctx, nvGPUInfoBin)
out, err := cmd.CombinedOutput()
if err != nil {
return false, err
}
sc := bufio.NewScanner(strings.NewReader(string(out)))
for sc.Scan() {
version, found := strings.CutPrefix(sc.Text(), "driver version:")
if found {
version = strings.TrimSpace(version)
if len(version) != 5 {
return false, fmt.Errorf("unexpected NVIDIA driver version format: %s", version)
}
major, err := strconv.Atoi(version[:3])
if err != nil {
return false, fmt.Errorf("unexpected NVIDIA driver version format: %s", version)
}
minor, err := strconv.Atoi(version[3:5])
if err != nil {
return false, fmt.Errorf("unexpected NVIDIA driver version format: %s", version)
}
return major > 452 || (major == 452 && minor >= 39), nil
}
}
return false, nil
}
func CanUseGPU(ctx context.Context, nvGPUInfoBin string) (bool, error) {
return hasCUDA11CapableGPU(ctx, nvGPUInfoBin)
}
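
The 452.39 cutoff matches NVIDIA's documented minimum Windows driver for CUDA 11, and the len(version) != 5 check implies the helper binary reports versions as five packed digits rather than the dotted "452.39" form (an inference from the code; com.docker.nv-gpu-info's output format is not shown in this diff). A quick sanity check of the threshold logic under that assumption:

package llamacpp

import (
	"strconv"
	"testing"
)

// Sanity check for the threshold above, assuming the packed "MMMmm" version
// form implied by the length check (e.g. "45239" for driver 452.39, NVIDIA's
// documented CUDA 11 minimum on Windows).
func TestCUDA11DriverThreshold(t *testing.T) {
	cases := map[string]bool{
		"45238": false, // 452.38: one step below the minimum
		"45239": true,  // 452.39: exactly the minimum
		"53104": true,  // 531.04: well above it
	}
	for version, want := range cases {
		major, _ := strconv.Atoi(version[:3])
		minor, _ := strconv.Atoi(version[3:5])
		got := major > 452 || (major == 452 && minor >= 39)
		if got != want {
			t.Errorf("version %s: got %v, want %v", version, got, want)
		}
	}
}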

----------------------------------------

@@ -34,6 +34,8 @@ type llamaCpp struct {
// updatedServerStoragePath is the parent path of the updated version of com.docker.llama-server.
// It is also where updates will be stored when downloaded.
updatedServerStoragePath string
// status describes the current state of the llama.cpp backend.
status string
}
// New creates a new llama.cpp-based backend.
@@ -66,21 +68,33 @@ func (l *llamaCpp) UsesExternalModelManagement() bool {
// Install implements inference.Backend.Install.
func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
l.updatedLlamaCpp = false
// We don't currently support this backend on Windows or Linux. We'll likely
// never support it on Intel Macs.
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
if runtime.GOOS == "linux" {
return errors.New("not implemented")
} else if runtime.GOOS == "darwin" && runtime.GOARCH == "amd64" {
} else if (runtime.GOOS == "darwin" && runtime.GOARCH == "amd64") || (runtime.GOOS == "windows" && runtime.GOARCH == "arm64") {
return errors.New("platform not supported")
}
llamaServerBin := "com.docker.llama-server"
if runtime.GOOS == "windows" {
llamaServerBin = "com.docker.llama-server.exe"
}
l.status = "installing"
// Temporary workaround for dynamically downloading llama.cpp from Docker Hub.
// Internet access and an available docker/docker-model-backend-llamacpp:latest-update on Docker Hub are required.
// Even if docker/docker-model-backend-llamacpp:latest-update has been downloaded before, we still require its
// Internet access and an available docker/docker-model-backend-llamacpp:latest on Docker Hub are required.
// Even if docker/docker-model-backend-llamacpp:latest has been downloaded before, we still require its
// digest to be equal to the one on Docker Hub.
llamaCppPath := filepath.Join(l.updatedServerStoragePath, "com.docker.llama-server")
if err := ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath); err != nil {
llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin)
if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil {
l.log.Infof("failed to ensure latest llama.cpp: %v\n", err)
if !errors.Is(err, errLlamaCppUpToDate) {
l.status = fmt.Sprintf("failed to install llama.cpp: %v", err)
}
if errors.Is(err, context.Canceled) {
return err
}
@@ -108,7 +122,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
if l.updatedLlamaCpp {
binPath = l.updatedServerStoragePath
}
llamaCppArgs := []string{"--model", modelPath, "--jinja"}
llamaCppArgs := []string{"--model", modelPath, "--jinja", "-ngl", "100"}
if mode == inference.BackendModeEmbedding {
llamaCppArgs = append(llamaCppArgs, "--embeddings")
}
@@ -121,8 +135,9 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
"DD_INF_UDS="+socket,
)
llamaCppProcess.Cancel = func() error {
// TODO: Figure out the correct process to send on Windows if/when we
// port this backend there.
if runtime.GOOS == "windows" {
return llamaCppProcess.Process.Kill()
}
return llamaCppProcess.Process.Signal(os.Interrupt)
}
serverLogStream := l.serverLog.Writer()
@@ -159,3 +174,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
return fmt.Errorf("llama.cpp terminated unexpectedly: %w", llamaCppErr)
}
}
func (l *llamaCpp) Status() string {
return l.status
}
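
Two behavioral notes on this file: the added -ngl 100 flag asks llama.cpp to offload up to 100 model layers to the GPU (CPU-only builds simply ignore it), and cancellation now branches on the OS because Process.Signal(os.Interrupt) is not implemented for other processes on Windows, leaving Kill as the only option there. A standalone sketch of that cancel pattern:

package main

import (
	"context"
	"os"
	"os/exec"
	"runtime"
)

// startServer mirrors the Cancel hook used above: graceful interrupt where
// supported, hard kill on Windows. Binary name and args are placeholders.
func startServer(ctx context.Context, bin string, args ...string) (*exec.Cmd, error) {
	cmd := exec.CommandContext(ctx, bin, args...)
	cmd.Cancel = func() error {
		if runtime.GOOS == "windows" {
			return cmd.Process.Kill()
		}
		return cmd.Process.Signal(os.Interrupt)
	}
	return cmd, cmd.Start()
}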

----------------------------------------

@@ -54,3 +54,7 @@ func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.Back
m.log.Warn("MLX backend is not yet supported")
return errors.New("not implemented")
}
func (m *mlx) Status() string {
return "not running"
}

----------------------------------------

@@ -54,3 +54,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.Bac
v.log.Warn("vLLM backend is not yet supported")
return errors.New("not implemented")
}
func (v *vLLM) Status() string {
return "not running"
}

----------------------------------------

@@ -59,16 +59,16 @@ func NewScheduler(
http.Error(w, "not found", http.StatusNotFound)
})
for _, route := range s.GetRoutes() {
s.router.HandleFunc(route, s.handleOpenAIInference)
for route, handler := range s.routeHandlers() {
s.router.HandleFunc(route, handler)
}
// Scheduler successfully initialized.
return s
}
func (s *Scheduler) GetRoutes() []string {
return []string{
func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
openAIRoutes := []string{
"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
"POST " + inference.InferencePrefix + "/{backend}/v1/completions",
"POST " + inference.InferencePrefix + "/{backend}/v1/embeddings",
@@ -76,6 +76,21 @@ func (s *Scheduler) GetRoutes() []string {
"POST " + inference.InferencePrefix + "/v1/completions",
"POST " + inference.InferencePrefix + "/v1/embeddings",
}
m := make(map[string]http.HandlerFunc)
for _, route := range openAIRoutes {
m[route] = s.handleOpenAIInference
}
m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
return m
}
func (s *Scheduler) GetRoutes() []string {
routeHandlers := s.routeHandlers()
routes := make([]string, 0, len(routeHandlers))
for route := range routeHandlers {
routes = append(routes, route)
}
return routes
}
// Run is the scheduler's main run loop. By the time it returns, all inference
@@ -196,6 +211,19 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request
runner.ServeHTTP(w, upstreamRequest)
}
func (s *Scheduler) GetBackendStatus(w http.ResponseWriter, r *http.Request) {
status := make(map[string]string)
for backendName, backend := range s.backends {
status[backendName] = backend.Status()
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(status)
}
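
The new GET route serializes each backend's Status string into a single JSON object keyed by backend name. A client sketch, assuming the runner's default local address and an "/engines" value for inference.InferencePrefix (both assumptions for illustration):

package main

import (
	"fmt"
	"io"
	"net/http"
)

// Queries the new status endpoint; with the stubs above, the body would
// look something like {"llama.cpp":"installing","mlx":"not running",...}.
func main() {
	resp, err := http.Get("http://localhost:12434/engines/status")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body))
}
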
func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
s.installer = newInstaller(s.log, s.backends, httpClient)
}
// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.router.ServeHTTP(w, r)