Add Status to Backend interface
Signed-off-by: Dorin Geman <dorin.geman@docker.com>
This commit is contained in:
parent 5e4719501a
commit e5d5ccf2dd
@@ -67,4 +67,6 @@ type Backend interface {
     // instead load only the specified model. Backends should still respond to
     // OpenAI API requests for other models with a 421 error code.
     Run(ctx context.Context, socket, model string, mode BackendMode) error
+    // Status returns in which state the backend is in.
+    Status() string
 }
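To make the new contract concrete, here is a minimal, self-contained sketch of a type satisfying just the added method; statusReporter, stubBackend, and its status field are illustrative only and not part of this commit.

package main

import "fmt"

// statusReporter mirrors only the piece of the Backend interface added here:
// a human-readable description of the backend's current state.
type statusReporter interface {
	Status() string
}

// stubBackend is a hypothetical implementation used purely for illustration.
type stubBackend struct {
	status string
}

// Status returns the last recorded state, e.g. "installing" or "running".
func (b *stubBackend) Status() string {
	return b.status
}

func main() {
	var r statusReporter = &stubBackend{status: "installing"}
	fmt.Println(r.Status()) // prints: installing
}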
@@ -23,9 +23,12 @@ const (
     hubRepo = "docker-model-backend-llamacpp"
 )
 
-var ShouldUseGPUVariant bool
+var (
+    ShouldUseGPUVariant bool
+    errLlamaCppUpToDate = errors.New("bundled llama.cpp version is up to date, no need to update")
+)
 
-func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
     llamaCppPath, vendoredServerStoragePath, desiredVersion, desiredVariant string,
 ) error {
     url := fmt.Sprintf("https://hub.docker.com/v2/namespaces/%s/repositories/%s/tags", hubNamespace, hubRepo)
@@ -71,7 +74,9 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
     if err != nil {
         return fmt.Errorf("failed to read bundled llama.cpp version: %w", err)
     } else if strings.TrimSpace(string(data)) == latest {
-        return errors.New("bundled llama.cpp version is up to date, no need to update")
+        l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
+            desiredTag, latest, getLlamaCppVersion(log, filepath.Join(vendoredServerStoragePath, "com.docker.llama-server")))
+        return errLlamaCppUpToDate
     }
 
     data, err = os.ReadFile(currentVersionFile)
@@ -81,6 +86,8 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
     } else if strings.TrimSpace(string(data)) == latest {
         log.Infoln("current llama.cpp version is already up to date")
         if _, err := os.Stat(llamaCppPath); err == nil {
+            l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s",
+                desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
             return nil
         }
         log.Infoln("llama.cpp binary must be updated, proceeding to update it")
@@ -95,6 +102,7 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
     }
     defer os.RemoveAll(downloadDir)
 
+    l.status = fmt.Sprintf("downloading %s (%s) variant of llama.cpp", desiredTag, latest)
     if err := extractFromImage(ctx, log, image, runtime.GOOS, runtime.GOARCH, downloadDir); err != nil {
         return fmt.Errorf("could not extract image: %w", err)
     }
@@ -130,7 +138,8 @@ func downloadLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient
     }
 
     log.Infoln("successfully updated llama.cpp binary")
-    log.Infoln("running llama.cpp version:", getLlamaCppVersion(log, llamaCppPath))
+    l.status = fmt.Sprintf("running llama.cpp %s (%s) version: %s", desiredTag, latest, getLlamaCppVersion(log, llamaCppPath))
+    log.Infoln(l.status)
 
     if err := os.WriteFile(currentVersionFile, []byte(latest), 0o644); err != nil {
         log.Warnf("failed to save llama.cpp version: %v", err)
@@ -7,11 +7,11 @@ import (
     "github.com/docker/model-runner/pkg/logging"
 )
 
-func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
     llamaCppPath, vendoredServerStoragePath string,
 ) error {
     desiredVersion := "latest"
     desiredVariant := "metal"
-    return downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
+    return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
         desiredVariant)
 }
@@ -8,7 +8,7 @@ import (
     "github.com/docker/model-runner/pkg/logging"
 )
 
-func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
     llamaCppPath, vendoredServerStoragePath string,
 ) error {
     return errors.New("platform is not supported")
@@ -9,7 +9,7 @@ import (
     "github.com/docker/model-runner/pkg/logging"
 )
 
-func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
     llamaCppPath, vendoredServerStoragePath string,
 ) error {
     nvGPUInfoBin := filepath.Join(vendoredServerStoragePath, "com.docker.nv-gpu-info.exe")
@@ -18,6 +18,7 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *h
     if ShouldUseGPUVariant {
         canUseCUDA11, err = hasCUDA11CapableGPU(ctx, nvGPUInfoBin)
         if err != nil {
+            l.status = fmt.Sprintf("failed to check CUDA 11 capability: %v", err)
             return fmt.Errorf("failed to check CUDA 11 capability: %w", err)
         }
     }
@@ -26,6 +27,7 @@ func ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *h
     if canUseCUDA11 {
         desiredVariant = "cuda"
     }
-    return downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
+    l.status = fmt.Sprintf("looking for updates for %s variant", desiredVariant)
+    return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
         desiredVariant)
 }
@@ -34,6 +34,8 @@ type llamaCpp struct {
     // updatedServerStoragePath is the parent path of the updated version of com.docker.llama-server.
     // It is also where updates will be stored when downloaded.
     updatedServerStoragePath string
+    // status is the state in which the llama.cpp backend is in.
+    status string
 }
 
 // New creates a new llama.cpp-based backend.
@@ -81,13 +83,18 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
         llamaServerBin = "com.docker.llama-server.exe"
     }
 
+    l.status = "installing"
+
     // Temporary workaround for dynamically downloading llama.cpp from Docker Hub.
     // Internet access and an available docker/docker-model-backend-llamacpp:latest on Docker Hub are required.
     // Even if docker/docker-model-backend-llamacpp:latest has been downloaded before, we still require its
     // digest to be equal to the one on Docker Hub.
     llamaCppPath := filepath.Join(l.updatedServerStoragePath, llamaServerBin)
-    if err := ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil {
+    if err := l.ensureLatestLlamaCpp(ctx, l.log, httpClient, llamaCppPath, l.vendoredServerStoragePath); err != nil {
         l.log.Infof("failed to ensure latest llama.cpp: %v\n", err)
+        if !errors.Is(err, errLlamaCppUpToDate) {
+            l.status = fmt.Sprintf("failed to install llama.cpp: %v", err)
+        }
         if errors.Is(err, context.Canceled) {
             return err
         }
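The errLlamaCppUpToDate sentinel lets Install tell "nothing to do" apart from a genuine failure without matching error strings. Below is a standalone sketch of the same errors.Is pattern, with illustrative names (errUpToDate, update) rather than the repository's.

package main

import (
	"errors"
	"fmt"
)

// errUpToDate plays the role of errLlamaCppUpToDate: a package-level sentinel
// that callers can detect with errors.Is even after wrapping.
var errUpToDate = errors.New("already up to date")

// update pretends the version check found nothing to do and wraps the sentinel.
func update() error {
	return fmt.Errorf("checking for updates: %w", errUpToDate)
}

func main() {
	if err := update(); err != nil {
		if errors.Is(err, errUpToDate) {
			fmt.Println("benign:", err) // don't mark the install as failed
			return
		}
		fmt.Println("real failure:", err)
	}
}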
@@ -167,3 +174,7 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
         return fmt.Errorf("llama.cpp terminated unexpectedly: %w", llamaCppErr)
     }
 }
+
+func (l *llamaCpp) Status() string {
+    return l.status
+}
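One design note: status is written while Install and the download path run, and Status() can be called from the scheduler's HTTP handler at any time. If those calls may overlap, a lock-guarded variant along these lines would avoid a data race; this is a sketch under that assumption, not code from the commit.

package main

import (
	"fmt"
	"sync"
)

// guardedStatus shows one way to make a status string safe for concurrent
// readers and writers; the llamaCpp struct in this commit stores a bare string.
type guardedStatus struct {
	mu     sync.RWMutex
	status string
}

// set records a new state under the write lock.
func (g *guardedStatus) set(s string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.status = s
}

// Status returns the current state under the read lock.
func (g *guardedStatus) Status() string {
	g.mu.RLock()
	defer g.mu.RUnlock()
	return g.status
}

func main() {
	var g guardedStatus
	g.set("installing")
	fmt.Println(g.Status()) // prints: installing
}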
@@ -54,3 +54,7 @@ func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.Back
     m.log.Warn("MLX backend is not yet supported")
     return errors.New("not implemented")
 }
+
+func (m *mlx) Status() string {
+    return "not running"
+}
@@ -54,3 +54,7 @@ func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.Bac
     v.log.Warn("vLLM backend is not yet supported")
     return errors.New("not implemented")
 }
+
+func (v *vLLM) Status() string {
+    return "not running"
+}
@@ -59,16 +59,16 @@ func NewScheduler(
         http.Error(w, "not found", http.StatusNotFound)
     })
 
-    for _, route := range s.GetRoutes() {
-        s.router.HandleFunc(route, s.handleOpenAIInference)
+    for route, handler := range s.routeHandlers() {
+        s.router.HandleFunc(route, handler)
     }
 
     // Scheduler successfully initialized.
     return s
 }
 
-func (s *Scheduler) GetRoutes() []string {
-    return []string{
+func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
+    openAIRoutes := []string{
         "POST " + inference.InferencePrefix + "/{backend}/v1/chat/completions",
         "POST " + inference.InferencePrefix + "/{backend}/v1/completions",
         "POST " + inference.InferencePrefix + "/{backend}/v1/embeddings",
@@ -76,6 +76,21 @@ func (s *Scheduler) GetRoutes() []string {
         "POST " + inference.InferencePrefix + "/v1/completions",
         "POST " + inference.InferencePrefix + "/v1/embeddings",
     }
+    m := make(map[string]http.HandlerFunc)
+    for _, route := range openAIRoutes {
+        m[route] = s.handleOpenAIInference
+    }
+    m["GET "+inference.InferencePrefix+"/status"] = s.GetBackendStatus
+    return m
+}
+
+func (s *Scheduler) GetRoutes() []string {
+    routeHandlers := s.routeHandlers()
+    routes := make([]string, 0, len(routeHandlers))
+    for route := range routeHandlers {
+        routes = append(routes, route)
+    }
+    return routes
 }
 
 // Run is the scheduler's main run loop. By the time it returns, all inference
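The map keys combine an HTTP method, inference.InferencePrefix, and a path, e.g. "GET <prefix>/status". Assuming s.router is Go's standard *http.ServeMux (whose patterns since Go 1.22 accept a method prefix and {name} wildcards), registering such a map looks roughly like the sketch below; the prefix value and the handlers are placeholders, not the scheduler's real ones.

package main

import (
	"fmt"
	"net/http"
)

// inferencePrefix stands in for inference.InferencePrefix; the value is illustrative.
const inferencePrefix = "/engines"

func main() {
	handlers := map[string]http.HandlerFunc{
		// Method-prefixed patterns and {backend} wildcards are understood by
		// Go 1.22+'s http.ServeMux.
		"POST " + inferencePrefix + "/{backend}/v1/chat/completions": func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "backend=%s\n", r.PathValue("backend"))
		},
		"GET " + inferencePrefix + "/status": func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintln(w, "{}")
		},
	}

	// Register every route exactly as NewScheduler does with its own map.
	mux := http.NewServeMux()
	for pattern, handler := range handlers {
		mux.HandleFunc(pattern, handler)
	}

	_ = http.ListenAndServe("localhost:8080", mux)
}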
@@ -196,6 +211,15 @@ func (s *Scheduler) handleOpenAIInference(w http.ResponseWriter, r *http.Request
     runner.ServeHTTP(w, upstreamRequest)
 }
 
+func (s *Scheduler) GetBackendStatus(w http.ResponseWriter, r *http.Request) {
+    status := make(map[string]string)
+    for backendName, backend := range s.backends {
+        status[backendName] = backend.Status()
+    }
+    w.Header().Set("Content-Type", "application/json")
+    json.NewEncoder(w).Encode(status)
+}
+
 func (s *Scheduler) ResetInstaller(httpClient *http.Client) {
     s.installer = newInstaller(s.log, s.backends, httpClient)
 }
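GetBackendStatus encodes a map of backend name to status string as JSON. A small client sketch that consumes it is below; the address, path, and example backend names are assumptions, since the real route is "GET " + inference.InferencePrefix + "/status" on whatever address the model runner listens on.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical address and prefix for illustration only.
	resp, err := http.Get("http://localhost:8080/engines/status")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The handler writes a JSON object mapping backend name to status,
	// e.g. {"llama.cpp":"installing","mlx":"not running"} (names illustrative).
	status := make(map[string]string)
	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
		panic(err)
	}
	for backend, state := range status {
		fmt.Printf("%s: %s\n", backend, state)
	}
}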