Add Status to Backend interface
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
parent 5e4719501a
commit e5d5ccf2dd
@@ -67,4 +67,6 @@ type Backend interface {
 	// instead load only the specified model. Backends should still respond to
 	// OpenAI API requests for other models with a 421 error code.
 	Run(ctx context.Context, socket, model string, mode BackendMode) error
+	// Status returns the current state of the backend.
+	Status() string
 }
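For illustration, a minimal sketch of a type satisfying the new method. The type name is hypothetical and not part of this change; a real backend also implements Install, Run, and the rest of the Backend interface.

	package backends

	// stubBackend is a hypothetical example type used only to illustrate the
	// Status method added above; it is not part of this commit.
	type stubBackend struct {
		status string
	}

	// Status returns the current state of the backend.
	func (b *stubBackend) Status() string {
		return b.status
	}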
@@ -23,9 +23,12 @@ const (
 	hubRepo = "docker-model-backend-llamacpp"
 )
 
-var ShouldUseGPUVariant bool
+var (
+	ShouldUseGPUVariant bool
+	errLlamaCppUpToDate = errors.New("bundled llama.cpp version is up to date, no need to update")
+)
 
-func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
 	llamaCppPath, vendoredServerStoragePath, desiredVersion, desiredVariant string,
 ) error {
 	url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags", hubNamespace, hubRepo)
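Promoting the up-to-date message to the errLlamaCppUpToDate sentinel lets callers distinguish "nothing to do" from a real failure with errors.Is, even through wrapping; the Install hunk later in this diff applies exactly that check. A minimal, self-contained sketch of the pattern, with illustrative names not taken from the codebase:

	package main

	import (
		"errors"
		"fmt"
	)

	// errUpToDate stands in for errLlamaCppUpToDate in this sketch.
	var errUpToDate = errors.New("bundled llama.cpp version is up to date, no need to update")

	// checkForUpdate pretends the version check found nothing to do and wraps
	// the sentinel; errors.Is still matches it through the wrapping.
	func checkForUpdate() error {
		return fmt.Errorf("checking llama.cpp version: %w", errUpToDate)
	}

	func main() {
		if err := checkForUpdate(); err != nil && !errors.Is(err, errUpToDate) {
			fmt.Println("install failed:", err)
			return
		}
		fmt.Println("llama.cpp binary is current")
	}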
@@ -71,7 +74,9 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
 	if err != nil {
 		return fmt.Errorf("failed to read bundled llama.cpp version: %w", err)
 	} else if strings.TrimSpace(string(data)) == latest {
-		return errors.New("bundled llama.cpp version is up to date, no need to update")
+		l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
+			desiredTag, latest, getLlamaCppVersion(log, filepath.Join(vendoredServerStoragePath, "com.docker.llama-server")))
+		return errLlamaCppUpToDate
 	}
 
 	data, err = os.ReadFile(currentVersionFile)
@@ -81,6 +86,8 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
 	} else if strings.TrimSpace(string(data)) == latest {
 		log.Infoln("current llama.cpp version is already up to date")
 		if _, err := os.Stat(llamaCppPath); err == nil {
+			l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
+				desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
 			return nil
 		}
 		log.Infoln("llama.cpp binary must be updated, proceeding to update it")
@@ -95,6 +102,7 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
 	}
 	defer os.RemoveAll(downloadDir)
 
+	l.status = fmt.Sprintf("downloading %s (%s) variant of llama.cpp", desiredTag, latest)
 	if err := extractFromImage(ctx, log, image, runtime.GOOS, runtime.GOARCH, downloadDir); err != nil {
 		return fmt.Errorf("could not extract image: %w", err)
 	}
@@ -130,7 +138,8 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
 	}
 
 	log.Infoln("successfully updated llama.cpp binary")
-	log.Infoln("running llama.cpp version:", getLlamaCppVersion(log, llamaCppPath))
+	l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
+	log.Infoln(l.status)
 
 	if err := os.WriteFile(currentVersionFile, []byte(latest), 0o644); err != nil {
 		log.Warnf("failed to save llama.cpp version: %v", err)
@@ -7,11 +7,11 @@ import (
 	"github.com/docker/model-runner/pkg/logging"
 )
 
-func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
 	llamaCppPath, vendoredServerStoragePath string,
 ) error {
 	desiredVersion := "latest"
 	desiredVariant := "metal"
-	return downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
+	return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
 		desiredVariant)
 }
@@ -8,7 +8,7 @@ import (
 	"github.com/docker/model-runner/pkg/logging"
 )
 
-func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
 	llamaCppPath, vendoredServerStoragePath string,
 ) error {
 	return errors.New("platform is not supported")
@@ -9,7 +9,7 @@ import (
 	"github.com/docker/model-runner/pkg/logging"
 )
 
-func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
 	llamaCppPath, vendoredServerStoragePath string,
 ) error {
 	nvGPUInfoBin := filepath.Join(vendoredServerStoragePath, "com.docker.nv-gpu-info.exe")
@@ -18,6 +18,7 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
 	if ShouldUseGPUVariant {
 		canUseCUDA11, err = hasCUDA11CapableGPU(ctx, nvGPUInfoBin)
 		if err != nil {
+			l.status = fmt.Sprintf("failed to check CUDA 11 capability: %v", err)
 			return fmt.Errorf("failed to check CUDA 11 capability: %w", err)
 		}
 	}
@@ -26,6 +27,7 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
 	if canUseCUDA11 {
 		desiredVariant = "cuda"
 	}
-	return downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
+	l.status = fmt.Sprintf("looking for updates for %s variant", desiredVariant)
+	return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
 		desiredVariant)
 }
@@ -34,6 +34,8 @@ type llamaCpp struct {
 	// updatedServerStoragePath is the parent path of the updated version of com.docker.llama-server.
 	// It is also where updates will be stored when downloaded.
 	updatedServerStoragePath string
+	// status is the current state of the llama.cpp backend.
+	status string
 }
 
 // New creates a new llama.cpp-based backend.
@@ -81,13 +83,18 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 		llamaServerBin = "com.docker.llama-server.exe"
 	}
 
+	l.status = "installing"
+
 	// Temporary workaround for dynamically downloading llama.cpp from Docker Hub.
 	// Internet access and an available docker/docker-model-backend-llamacpp:latest on Docker Hub are required.
 	// Even if docker/docker-model-backend-llamacpp:latest has been downloaded before, we still require its
 	// digest to be equal to the one on Docker Hub.
 	llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin)
-	if err := ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil {
+	if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil {
 		l.log.Infof("failed to ensure latest llama.cpp: %v\n", err)
+		if !errors.Is(err, errLlamaCppUpToDate) {
+			l.status = fmt.Sprintf("failed to install llama.cpp: %v", err)
+		}
 		if errors.Is(err, context.Canceled) {
 			return err
 		}
@@ -167,3 +174,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
 			return fmt.Errorf("llama.cpp terminated unexpectedly: %w", llamaCppErr)
 		}
 	}
+
+func (l *llamaCpp) Status() string {
+	return l.status
+}
@@ -54,3 +54,7 @@ func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
 	m.log.Warn("MLX backend is not yet supported")
 	return errors.New("not implemented")
 }
+
+func (m *mlx) Status() string {
+	return "not running"
+}
@@ -54,3 +54,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
 	v.log.Warn("vLLM backend is not yet supported")
 	return errors.New("not implemented")
 }
+
+func (v *vLLM) Status() string {
+	return "not running"
+}
@@ -59,16 +59,16 @@ func NewScheduler(
 		http.Error(w, "not found", http.StatusNotFound)
 	})
 
-	for _, route := range s.GetRoutes() {
-		s.router.HandleFunc(route, s.handleOpenAIInference)
+	for route, handler := range s.routeHandlers() {
+		s.router.HandleFunc(route, handler)
 	}
 
 	// Scheduler successfully initialized.
 	return s
 }
 
-func (s *Scheduler) GetRoutes() []string {
-	return []string{
+func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
+	openAIRoutes := []string{
 		"POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
 		"POST " + inference.InferencePrefix + "/{backend}/v1/completions",
 		"POST " + inference.InferencePrefix + "/{backend}/v1/embeddings",
@@ -76,6 +76,21 @@ func (s *Scheduler) GetRoutes() []string {
 		"POST " + inference.InferencePrefix + "/v1/completions",
 		"POST " + inference.InferencePrefix + "/v1/embeddings",
 	}
+	m := make(map[string]http.HandlerFunc)
+	for _, route := range openAIRoutes {
+		m[route] = s.handleOpenAIInference
+	}
+	m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
+	return m
+}
+
+func (s *Scheduler) GetRoutes() []string {
+	routeHandlers := s.routeHandlers()
+	routes := make([]string, 0, len(routeHandlers))
+	for route := range routeHandlers {
+		routes = append(routes, route)
+	}
+	return routes
 }
 
 // Run is the scheduler's main run loop. By the time it returns, all inference
@@ -196,6 +211,15 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request) {
 	runner.ServeHTTP(w, upstreamRequest)
 }
 
+func (s *Scheduler) GetBackendStatus(w http.ResponseWriter, r *http.Request) {
+	status := make(map[string]string)
+	for backendName, backend := range s.backends {
+		status[backendName] = backend.Status()
+	}
+	w.Header().Set("Content-Type", "application/json")
+	json.NewEncoder(w).Encode(status)
+}
+
 func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
 	s.installer = newInstaller(s.log, s.backends, httpClient)
 }
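With the "GET " + inference.InferencePrefix + "/status" route registered above, the aggregated backend states can be read as a JSON object mapping backend name to its Status() string. A hedged sketch of a client call; the listen address and the concrete value of inference.InferencePrefix are assumptions, not defined in this diff:

	package main

	import (
		"encoding/json"
		"fmt"
		"net/http"
	)

	func main() {
		// Assumed for illustration: the server address and prefix are
		// placeholders; the real values come from the model-runner setup.
		const baseURL = "http://localhost:8080"
		const inferencePrefix = "/inference"

		resp, err := http.Get(baseURL + inferencePrefix + "/status")
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()

		// GetBackendStatus encodes a map[string]string: backend name -> Status().
		status := make(map[string]string)
		if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
			panic(err)
		}
		for name, state := range status {
			fmt.Printf("%s: %s\n", name, state)
		}
	}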