diff --git a/go.mod b/go.mod index c2eefb4..0846a0c 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/containerd/containerd/v2 v2.0.4 github.com/containerd/platforms v1.0.0-rc.1 github.com/docker/model-distribution v0.0.0-20250411163353-b33595b4e207 + github.com/jaypipes/ghw v0.16.0 github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.1.1 github.com/sirupsen/logrus v1.9.3 @@ -14,6 +15,7 @@ require ( require ( github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 // indirect + github.com/StackExchange/wmi v1.2.1 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect @@ -24,9 +26,11 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/google/go-containerregistry v0.20.3 // indirect github.com/gpustack/gguf-parser-go v0.13.20 // indirect github.com/henvic/httpretty v0.1.4 // indirect + github.com/jaypipes/pcidb v1.0.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.11 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect @@ -48,4 +52,6 @@ require ( golang.org/x/sys v0.31.0 // indirect golang.org/x/tools v0.29.0 // indirect gonum.org/v1/gonum v0.15.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + howett.net/plist v1.0.0 // indirect ) diff --git a/go.sum b/go.sum index ade394a..5b54b9d 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/hcsshim v0.12.9 h1:2zJy5KA+l0loz1HzEGqyNnjd3fyZA31ZBCGKacp6lLg= github.com/Microsoft/hcsshim v0.12.9/go.mod h1:fJ0gkFAna6ukt0bLdKB8djt4XIJhF/vEPuoIWYVvZ8Y= +github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= +github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/containerd/cgroups/v3 v3.0.3 h1:S5ByHZ/h9PMe5IOQoN7E+nMc2UcLEM/V48DGDJ9kip0= @@ -49,8 +51,6 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/docker/model-distribution v0.0.0-20250410151231-bf9b59b512f7 h1:li7LReF/bddqOXbX7IfAnzRmYYAuMjG4C7xCMQLZPlI= -github.com/docker/model-distribution v0.0.0-20250410151231-bf9b59b512f7/go.mod h1:/JWSwYc3pihCpHqFzDUyoiRKegA1srfYESxRh/vJE10= github.com/docker/model-distribution v0.0.0-20250411163353-b33595b4e207 h1:BZFQGpeo7H4JLeX1Gn+T9P7vSPwbhtH10QeSKdkzKKs= github.com/docker/model-distribution v0.0.0-20250411163353-b33595b4e207/go.mod h1:/JWSwYc3pihCpHqFzDUyoiRKegA1srfYESxRh/vJE10= github.com/ebitengine/purego v0.8.2 h1:jPPGWs2sZ1UgOSgD2bClL0MJIqu58nOmIcBuXr62z1I= @@ -62,6 +62,7 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod 
h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -79,10 +80,19 @@ github.com/gpustack/gguf-parser-go v0.13.20 h1:EmONF0H9WSUen16dCsKeqb0J+vAr3jbIv github.com/gpustack/gguf-parser-go v0.13.20/go.mod h1:GvHh1Kvvq5ojCOsJ5UpwiJJmIjFw3Qk5cW7R+CZ3IJo= github.com/henvic/httpretty v0.1.4 h1:Jo7uwIRWVFxkqOnErcoYfH90o3ddQyVrSANeS4cxYmU= github.com/henvic/httpretty v0.1.4/go.mod h1:Dn60sQTZfbt2dYsdUSNsCljyF4AfdqnuJFDLJA1I4AM= +github.com/jaypipes/ghw v0.16.0 h1:3HurCTS38VNpeQLo5fIdZsySuo/qAfpPSJ5t05QBFPM= +github.com/jaypipes/ghw v0.16.0/go.mod h1:In8SsaDqlb1oTyrbmTC14uy+fbBMvp+xdqX51MidlD8= +github.com/jaypipes/pcidb v1.0.1 h1:WB2zh27T3nwg8AE8ei81sNRb9yWBii3JGNJtT7K9Oic= +github.com/jaypipes/pcidb v1.0.1/go.mod h1:6xYUz/yYEyOkIkUt2t2J2folIuZ4Yg6uByCGFXMCeE4= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.9 h1:nWcCbLq1N2v/cpNsy5WvQ37Fb+YElfq20WJ/a8RkpQM= @@ -122,6 +132,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 h1:18kd+8ZUlt/ARXhljq+14TwAoKa61q6dX8jtwOf6DH8= github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA= github.com/shirou/gopsutil/v4 v4.25.1 h1:QSWkTc+fu9LTAWfkZwZ6j8MSUk4A2LV7rbH0ZqmLjXs= @@ -168,6 +180,7 @@ golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= 
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= @@ -184,8 +197,13 @@ google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe0 google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= +howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= +howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index c80e74f..b5eff25 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -67,4 +67,6 @@ type Backend interface { // instead load only the specified model. Backends should still respond to // OpenAI API requests for other models with a 421 error code. Run(ctx context.Context, socket, model string, mode BackendMode) error + // Status returns a description of the backend's state. 
+ Status() string } diff --git a/pkg/inference/backends/llamacpp/download.go b/pkg/inference/backends/llamacpp/download.go index ae8fa8c..5cf2df7 100644 --- a/pkg/inference/backends/llamacpp/download.go +++ b/pkg/inference/backends/llamacpp/download.go @@ -3,6 +3,7 @@ package llamacpp import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -12,6 +13,7 @@ import ( "regexp" "runtime" "strings" + "sync" "github.com/docker/model-runner/pkg/internal/dockerhub" "github.com/docker/model-runner/pkg/logging" @@ -22,7 +24,15 @@ const ( hubRepo = "docker-model-backend-llamacpp" ) -func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client, llamaCppPath string) error { +var ( + ShouldUseGPUVariant bool + ShouldUseGPUVariantLock sync.Mutex + errLlamaCppUpToDate = errors.New("bundled llama.cpp version is up to date, no need to update") +) + +func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client, + llamaCppPath, vendoredServerStoragePath, desiredVersion, desiredVariant string, +) error { url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags", hubNamespace, hubRepo) resp, err := httpClient.Get(url) if err != nil { @@ -47,25 +57,39 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *h return fmt.Errorf("failed to unmarshal response body: %w", err) } + desiredTag := desiredVersion + "-" + desiredVariant var latest string for _, tag := range response.Results { - if tag.Name == "latest-update" { + if tag.Name == desiredTag { latest = tag.Digest break } } if latest == "" { - return fmt.Errorf("could not find any latest-update tag") + return fmt.Errorf("could not find the %s tag", desiredTag) } + bundledVersionFile := filepath.Join(vendoredServerStoragePath, "com.docker.llama-server.digest") currentVersionFile := filepath.Join(filepath.Dir(llamaCppPath), ".llamacpp_version") - data, err := os.ReadFile(currentVersionFile) + + data, err := os.ReadFile(bundledVersionFile) + if err != nil { + return fmt.Errorf("failed to read bundled llama.cpp version: %w", err) + } else if strings.TrimSpace(string(data)) == latest { + l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", + desiredTag, latest, getLlamaCppVersion(log, filepath.Join(vendoredServerStoragePath, "com.docker.llama-server"))) + return errLlamaCppUpToDate + } + + data, err = os.ReadFile(currentVersionFile) if err != nil { log.Warnf("failed to read current llama.cpp version: %v", err) log.Warnf("proceeding to update llama.cpp binary") } else if strings.TrimSpace(string(data)) == latest { log.Infoln("current llama.cpp version is already up to date") if _, err := os.Stat(llamaCppPath); err == nil { + l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", + desiredTag, latest, getLlamaCppVersion(log, llamaCppPath)) return nil } log.Infoln("llama.cpp binary must be updated, proceeding to update it") @@ -80,27 +104,44 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *h } defer os.RemoveAll(downloadDir) + l.status = fmt.Sprintf("downloading %s (%s) variant of llama.cpp", desiredTag, latest) if err := extractFromImage(ctx, log, image, runtime.GOOS, runtime.GOARCH, downloadDir); err != nil { return fmt.Errorf("could not extract image: %w", err) } - if err := os.MkdirAll(filepath.Dir(llamaCppPath), 0o755); err != nil { - return fmt.Errorf("could not create directory for llama.cpp binary: %w", err) + if err := os.RemoveAll(filepath.Dir(llamaCppPath)); err != nil 
&& !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to clear inference binary dir: %w", err) } - rootDir := "com.docker.llama-server.native.darwin.metal.arm64" - if err := os.Rename(fmt.Sprintf("%s/%s/%s", downloadDir, rootDir, "bin/com.docker.llama-server"), llamaCppPath); err != nil { + if err := os.RemoveAll(filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to clear inference library dir: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(filepath.Dir(llamaCppPath)), 0o755); err != nil { + return fmt.Errorf("could not create directory for llama.cpp artifacts: %w", err) + } + + rootDir := fmt.Sprintf("com.docker.llama-server.native.%s.%s.%s", runtime.GOOS, desiredVariant, runtime.GOARCH) + if err := os.Rename(filepath.Join(downloadDir, rootDir, "bin"), filepath.Dir(llamaCppPath)); err != nil { return fmt.Errorf("could not move llama.cpp binary: %w", err) } if err := os.Chmod(llamaCppPath, 0o755); err != nil { return fmt.Errorf("could not chmod llama.cpp binary: %w", err) } - if err := os.Rename(fmt.Sprintf("%s/%s/%s", downloadDir, rootDir, "lib"), - filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil { - return fmt.Errorf("could not move llama.cpp libs: %w", err) + + libDir := filepath.Join(downloadDir, rootDir, "lib") + fi, err := os.Stat(libDir) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to stat llama.cpp lib dir: %w", err) + } + if err == nil && fi.IsDir() { + if err := os.Rename(libDir, filepath.Join(filepath.Dir(filepath.Dir(llamaCppPath)), "lib")); err != nil { + return fmt.Errorf("could not move llama.cpp libs: %w", err) + } } log.Infoln("successfully updated llama.cpp binary") - log.Infoln("running llama.cpp version:", getLlamaCppVersion(log, llamaCppPath)) + l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", desiredTag, latest, getLlamaCppVersion(log, llamaCppPath)) + log.Infoln(l.status) if err := os.WriteFile(currentVersionFile, []byte(latest), 0o644); err != nil { log.Warnf("failed to save llama.cpp version: %v", err) diff --git a/pkg/inference/backends/llamacpp/download_darwin.go b/pkg/inference/backends/llamacpp/download_darwin.go new file mode 100644 index 0000000..9925dee --- /dev/null +++ b/pkg/inference/backends/llamacpp/download_darwin.go @@ -0,0 +1,17 @@ +package llamacpp + +import ( + "context" + "net/http" + + "github.com/docker/model-runner/pkg/logging" +) + +func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client, + llamaCppPath, vendoredServerStoragePath string, +) error { + desiredVersion := "latest" + desiredVariant := "metal" + return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion, + desiredVariant) +} diff --git a/pkg/inference/backends/llamacpp/download_linux.go b/pkg/inference/backends/llamacpp/download_linux.go new file mode 100644 index 0000000..7d659ca --- /dev/null +++ b/pkg/inference/backends/llamacpp/download_linux.go @@ -0,0 +1,15 @@ +package llamacpp + +import ( + "context" + "errors" + "net/http" + + "github.com/docker/model-runner/pkg/logging" +) + +func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client, + llamaCppPath, vendoredServerStoragePath string, +) error { + return errors.New("platform is not supported") +} diff --git a/pkg/inference/backends/llamacpp/download_windows.go 
b/pkg/inference/backends/llamacpp/download_windows.go new file mode 100644 index 0000000..1a34206 --- /dev/null +++ b/pkg/inference/backends/llamacpp/download_windows.go @@ -0,0 +1,35 @@ +package llamacpp + +import ( + "context" + "fmt" + "net/http" + "path/filepath" + + "github.com/docker/model-runner/pkg/logging" +) + +func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client, + llamaCppPath, vendoredServerStoragePath string, +) error { + nvGPUInfoBin := filepath.Join(vendoredServerStoragePath, "com.docker.nv-gpu-info.exe") + var canUseCUDA11 bool + var err error + ShouldUseGPUVariantLock.Lock() + defer ShouldUseGPUVariantLock.Unlock() + if ShouldUseGPUVariant { + canUseCUDA11, err = hasCUDA11CapableGPU(ctx, nvGPUInfoBin) + if err != nil { + l.status = fmt.Sprintf("failed to check CUDA 11 capability: %v", err) + return fmt.Errorf("failed to check CUDA 11 capability: %w", err) + } + } + desiredVersion := "latest" + desiredVariant := "cpu" + if canUseCUDA11 { + desiredVariant = "cuda" + } + l.status = fmt.Sprintf("looking for updates for %s variant", desiredVariant) + return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion, + desiredVariant) +} diff --git a/pkg/inference/backends/llamacpp/gpuinfo_notwindows.go b/pkg/inference/backends/llamacpp/gpuinfo_notwindows.go new file mode 100644 index 0000000..d20748e --- /dev/null +++ b/pkg/inference/backends/llamacpp/gpuinfo_notwindows.go @@ -0,0 +1,7 @@ +//go:build !windows + +package llamacpp + +import "context" + +func CanUseGPU(context.Context, string) (bool, error) { return false, nil } diff --git a/pkg/inference/backends/llamacpp/gpuinfo_windows.go b/pkg/inference/backends/llamacpp/gpuinfo_windows.go new file mode 100644 index 0000000..b19fc41 --- /dev/null +++ b/pkg/inference/backends/llamacpp/gpuinfo_windows.go @@ -0,0 +1,61 @@ +package llamacpp + +import ( + "bufio" + "context" + "fmt" + "os/exec" + "strconv" + "strings" + + "github.com/jaypipes/ghw" +) + +func hasNVIDIAGPU() (bool, error) { + gpus, err := ghw.GPU() + if err != nil { + return false, err + } + for _, gpu := range gpus.GraphicsCards { + if strings.ToLower(gpu.DeviceInfo.Vendor.Name) == "nvidia" { + return true, nil + } + } + return false, nil +} + +func hasCUDA11CapableGPU(ctx context.Context, nvGPUInfoBin string) (bool, error) { + nvGPU, err := hasNVIDIAGPU() + if !nvGPU || err != nil { + return false, err + } + cmd := exec.CommandContext(ctx, nvGPUInfoBin) + out, err := cmd.CombinedOutput() + if err != nil { + return false, err + } + sc := bufio.NewScanner(strings.NewReader(string(out))) + for sc.Scan() { + version, found := strings.CutPrefix(sc.Text(), "driver version:") + if found { + version = strings.TrimSpace(version) + if len(version) != 5 { + return false, fmt.Errorf("unexpected NVIDIA driver version format: %s", version) + } + major, err := strconv.Atoi(version[:3]) + if err != nil { + return false, fmt.Errorf("unexpected NVIDIA driver version format: %s", version) + } + minor, err := strconv.Atoi(version[3:5]) + if err != nil { + return false, fmt.Errorf("unexpected NVIDIA driver version format: %s", version) + } + return major > 452 || (major == 452 && minor >= 39), nil + } + } + return false, nil +} + +func CanUseGPU(ctx context.Context, nvGPUInfoBin string) (bool, error) { + return hasCUDA11CapableGPU(ctx, nvGPUInfoBin) +} diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go index dd7fa25..18999c3 100644 --- 
a/pkg/inference/backends/llamacpp/llamacpp.go +++ b/pkg/inference/backends/llamacpp/llamacpp.go @@ -34,6 +34,8 @@ type llamaCpp struct { // updatedServerStoragePath is the parent path of the updated version of com.docker.llama-server. // It is also where updates will be stored when downloaded. updatedServerStoragePath string + // status describes the current state of the llama.cpp backend. + status string } // New creates a new llama.cpp-based backend. @@ -66,21 +68,33 @@ func (l *llamaCpp) UsesExternalModelManagement() bool { // Install implements inference.Backend.Install. func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error { + l.updatedLlamaCpp = false + - // We don't currently support this backend on Windows or Linux. We'll likely - // never support it on Intel Macs. + // We don't currently support this backend on Linux, and we'll likely never + // support it on Intel Macs or Windows on Arm. - if runtime.GOOS == "windows" || runtime.GOOS == "linux" { + if runtime.GOOS == "linux" { return errors.New("not implemented") - } else if runtime.GOOS == "darwin" && runtime.GOARCH == "amd64" { + } else if (runtime.GOOS == "darwin" && runtime.GOARCH == "amd64") || (runtime.GOOS == "windows" && runtime.GOARCH == "arm64") { return errors.New("platform not supported") } + llamaServerBin := "com.docker.llama-server" + if runtime.GOOS == "windows" { + llamaServerBin = "com.docker.llama-server.exe" + } + + l.status = "installing" + // Temporary workaround for dynamically downloading llama.cpp from Docker Hub. - // Internet access and an available docker/docker-model-backend-llamacpp:latest-update on Docker Hub are required. - // Even if docker/docker-model-backend-llamacpp:latest-update has been downloaded before, we still require its + // Internet access and an available docker/docker-model-backend-llamacpp:latest on Docker Hub are required. + // Even if docker/docker-model-backend-llamacpp:latest has been downloaded before, we still require its // digest to be equal to the one on Docker Hub. - llamaCppPath := filepath.Join(l.updatedServerStoragePath, "com.docker.llama-server") - if err := ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath); err != nil { + llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin) + if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil { l.log.Infof("failed to ensure latest llama.cpp: %v\n", err) + if !errors.Is(err, errLlamaCppUpToDate) { + l.status = fmt.Sprintf("failed to install llama.cpp: %v", err) + } if errors.Is(err, context.Canceled) { return err } @@ -108,7 +122,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference if l.updatedLlamaCpp { binPath = l.updatedServerStoragePath } - llamaCppArgs := []string{"--model", modelPath, "--jinja"} + llamaCppArgs := []string{"--model", modelPath, "--jinja", "-ngl", "100"} if mode == inference.BackendModeEmbedding { llamaCppArgs = append(llamaCppArgs, "--embeddings") } @@ -121,8 +135,9 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference "DD_INF_UDS="+socket, ) llamaCppProcess.Cancel = func() error { - // TODO: Figure out the correct process to send on Windows if/when we - // port this backend there.
+ if runtime.GOOS == "windows" { + return llamaCppProcess.Process.Kill() + } return llamaCppProcess.Process.Signal(os.Interrupt) } serverLogStream := l.serverLog.Writer() @@ -159,3 +174,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference return fmt.Errorf("llama.cpp terminated unexpectedly: %w", llamaCppErr) } } + +func (l *llamaCpp) Status() string { + return l.status +} diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go index 5338c00..7f30251 100644 --- a/pkg/inference/backends/mlx/mlx.go +++ b/pkg/inference/backends/mlx/mlx.go @@ -54,3 +54,7 @@ func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.Back m.log.Warn("MLX backend is not yet supported") return errors.New("not implemented") } + +func (m *mlx) Status() string { + return "not running" +} diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go index a52fbe5..ccf86df 100644 --- a/pkg/inference/backends/vllm/vllm.go +++ b/pkg/inference/backends/vllm/vllm.go @@ -54,3 +54,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.Bac v.log.Warn("vLLM backend is not yet supported") return errors.New("not implemented") } + +func (v *vLLM) Status() string { + return "not running" +} diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 2e11f11..f0ae129 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -59,16 +59,16 @@ func NewScheduler( http.Error(w, "not found", http.StatusNotFound) }) - for _, route := range s.GetRoutes() { - s.router.HandleFunc(route, s.handleOpenAIInference) + for route, handler := range s.routeHandlers() { + s.router.HandleFunc(route, handler) } // Scheduler successfully initialized. return s } -func (s *Scheduler) GetRoutes() []string { - return []string{ +func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc { + openAIRoutes := []string{ "POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions", "POST " + inference.InferencePrefix + "/{backend}/v1/completions", "POST " + inference.InferencePrefix + "/{backend}/v1/embeddings", @@ -76,6 +76,21 @@ func (s *Scheduler) GetRoutes() []string { "POST " + inference.InferencePrefix + "/v1/completions", "POST " + inference.InferencePrefix + "/v1/embeddings", } + m := make(map[string]http.HandlerFunc) + for _, route := range openAIRoutes { + m[route] = s.handleOpenAIInference + } + m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus + return m +} + +func (s *Scheduler) GetRoutes() []string { + routeHandlers := s.routeHandlers() + routes := make([]string, 0, len(routeHandlers)) + for route := range routeHandlers { + routes = append(routes, route) + } + return routes } // Run is the scheduler's main run loop. By the time it returns, all inference @@ -196,6 +211,19 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request runner.ServeHTTP(w, upstreamRequest) } +func (s *Scheduler) GetBackendStatus(w http.ResponseWriter, r *http.Request) { + status := make(map[string]string) + for backendName, backend := range s.backends { + status[backendName] = backend.Status() + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(status) +} + +func (s *Scheduler) ResetInstaller(httpClient *http.Client) { + s.installer = newInstaller(s.log, s.backends, httpClient) +} + // ServeHTTP implements net/http.Handler.ServeHTTP. 
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.router.ServeHTTP(w, r)
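
Note: the scheduler changes above register a "GET <InferencePrefix>/status" route whose handler encodes each backend's Status() string into a single JSON object. The sketch below shows how a client might read that endpoint. It is a minimal sketch, not part of the diff: the base URL, port, and the assumption that inference.InferencePrefix resolves to "/engines" are all placeholders, since the diff does not show where the scheduler's router is mounted.

// statusclient.go - hypothetical client for the backend-status endpoint added above.
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Assumed address and prefix; substitute whatever host, port, and prefix the
	// scheduler's router is actually served on in your deployment.
	resp, err := http.Get("http://localhost:12434/engines/status")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// GetBackendStatus writes a map of backend name -> Status() text, for example
	// {"llama.cpp": "running llama.cpp latest-metal (sha256:...) version: ..."}.
	status := make(map[string]string)
	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
		panic(err)
	}
	for backend, state := range status {
		fmt.Printf("%s: %s\n", backend, state)
	}
}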