diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0d69b20..bbb1dfa 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -13,6 +13,9 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4
 
+      - name: Verify vendor/ is not present
+        run: stat vendor && exit 1 || exit 0
+
       - name: Set up Go
         uses: actions/setup-go@v5
         with:
diff --git a/.gitignore b/.gitignore
index 15ddf27..1c7c4bc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ model-runner.sock
 models-store/
 # Directory where we store the updated llama.cpp
 updated-inference/
+vendor/
diff --git a/main.go b/main.go
index 287b64b..f024039 100644
--- a/main.go
+++ b/main.go
@@ -14,6 +14,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
 	"github.com/docker/model-runner/pkg/inference/config"
+	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/inference/scheduling"
 	"github.com/docker/model-runner/pkg/metrics"
@@ -54,6 +55,20 @@ func main() {
 		llamacpp.SetDesiredServerVersion(desiredSeverVersion)
 	}
 
+	llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
+	if llamaServerPath == "" {
+		llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
+	}
+
+	gpuInfo := gpuinfo.New(llamaServerPath)
+
+	sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
+	if err != nil {
+		log.Fatalf("unable to initialize system memory info: %v", err)
+	}
+
+	memEstimator := memory.NewEstimator(sysMemInfo)
+
 	modelManager := models.NewManager(
 		log,
 		models.ClientConfig{
@@ -61,13 +76,9 @@ func main() {
 			Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
 		},
 		nil,
+		memEstimator,
 	)
 
-	llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
-	if llamaServerPath == "" {
-		llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
-	}
-
 	log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
 
 	// Create llama.cpp configuration from environment variables
@@ -90,7 +101,7 @@ func main() {
 		log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
 	}
 
-	gpuInfo := gpuinfo.New(llamaServerPath)
+	memEstimator.SetDefaultBackend(llamaCppBackend)
 
 	scheduler := scheduling.NewScheduler(
 		log,
@@ -105,7 +116,7 @@ func main() {
 			"",
 			false,
 		),
-		gpuInfo,
+		sysMemInfo,
 	)
 
 	router := routing.NewNormalizedServeMux()
diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go
index 3d857f8..26bd3fd 100644
--- a/pkg/inference/backend.go
+++ b/pkg/inference/backend.go
@@ -2,7 +2,6 @@ package inference
 
 import (
 	"context"
-	"errors"
 	"net/http"
 )
 
@@ -18,9 +17,13 @@ const (
 	BackendModeEmbedding
 )
 
-var (
-	ErrGGUFParse = errors.New("failed to parse GGUF file")
-)
+type ErrGGUFParse struct {
+	Err error
+}
+
+func (e *ErrGGUFParse) Error() string {
+	return "failed to parse GGUF: " + e.Err.Error()
+}
 
 // String implements Stringer.String for BackendMode.
 func (m BackendMode) String() string {
@@ -88,5 +91,5 @@ type Backend interface {
 	GetDiskUsage() (int64, error)
 	// GetRequiredMemoryForModel returns the required working memory for a given
 	// model.
-	GetRequiredMemoryForModel(model string, config *BackendConfiguration) (*RequiredMemory, error)
+	GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (*RequiredMemory, error)
 }
diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index 15c6625..6f05750 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -15,6 +15,8 @@ import (
 	"runtime"
 	"strings"
 
+	"github.com/docker/model-distribution/types"
+	v1 "github.com/google/go-containerregistry/pkg/v1"
 	parser "github.com/gpustack/gguf-parser-go"
 
 	"github.com/docker/model-runner/pkg/diskusage"
@@ -223,23 +225,30 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
 	return size, nil
 }
 
-func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-	bundle, err := l.modelManager.GetBundle(model)
+func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+	var mdlGguf *parser.GGUFFile
+	var mdlConfig types.Config
+	inStore, err := l.modelManager.IsModelInStore(model)
 	if err != nil {
-		return nil, fmt.Errorf("getting model(%s): %w", model, err)
+		return nil, fmt.Errorf("checking if model is in local store: %w", err)
+	}
+	if inStore {
+		mdlGguf, mdlConfig, err = l.parseLocalModel(model)
+		if err != nil {
+			return nil, &inference.ErrGGUFParse{Err: err}
+		}
+	} else {
+		mdlGguf, mdlConfig, err = l.parseRemoteModel(ctx, model)
+		if err != nil {
+			return nil, &inference.ErrGGUFParse{Err: err}
+		}
 	}
 
-	mdlGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
-	if err != nil {
-		l.log.Warnf("Failed to parse gguf(%s): %s", bundle.GGUFPath(), err)
-		return nil, inference.ErrGGUFParse
-	}
-
-	contextSize := GetContextSize(bundle.RuntimeConfig(), config)
+	contextSize := GetContextSize(mdlConfig, config)
 
 	ngl := uint64(0)
 	if l.gpuSupported {
-		if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && strings.TrimSpace(mdlGGUF.Metadata().FileType.String()) != "Q4_0" {
+		if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && mdlConfig.Quantization != "Q4_0" {
 			ngl = 0 // only Q4_0 models can be accelerated on Adreno
 		}
 		ngl = 100
@@ -248,7 +257,7 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
 	// TODO(p1-0tr): for now assume we are running on GPU (single one) - Devices[1];
 	// sum up weights + kv cache + context for an estimate of total GPU memory needed
 	// while running inference with the given model
-	estimate := mdlGGUF.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
+	estimate := mdlGguf.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
 		// TODO(p1-0tr): add logic for resolving other param values, instead of hardcoding them
 		parser.WithLLaMACppLogicalBatchSize(2048), parser.WithLLaMACppOffloadLayers(ngl))
 
@@ -270,6 +279,63 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
 	}, nil
 }
 
+func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
+	bundle, err := l.modelManager.GetBundle(model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
+	}
+	modelGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", bundle.GGUFPath(), err)
+	}
+	return modelGGUF, bundle.RuntimeConfig(), nil
+}
+
+func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
+	mdl, err := l.modelManager.GetRemoteModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting remote model(%s): %w", model, err)
+	}
+	layers, err := mdl.Layers()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting layers of model(%s): %w", model, err)
+	}
+	var ggufDigest v1.Hash
+	for _, layer := range layers {
+		mt, err := layer.MediaType()
+		if err != nil {
+			return nil, types.Config{}, fmt.Errorf("getting media type of model(%s) layer: %w", model, err)
+		}
+		if mt == types.MediaTypeGGUF {
+			ggufDigest, err = layer.Digest()
+			if err != nil {
+				return nil, types.Config{}, fmt.Errorf("getting digest of GGUF layer for model(%s): %w", model, err)
+			}
+			break
+		}
+	}
+	if ggufDigest.String() == "" {
+		return nil, types.Config{}, fmt.Errorf("model(%s) has no GGUF layer", model)
+	}
+	blobURL, err := l.modelManager.GetRemoteModelBlobURL(model, ggufDigest)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting GGUF blob URL for model(%s): %w", model, err)
+	}
+	tok, err := l.modelManager.BearerTokenForModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting bearer token for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(tok))
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing GGUF for model(%s): %w", model, err)
+	}
+	config, err := mdl.Config()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting config for model(%s): %w", model, err)
+	}
+	return mdlGguf, config, nil
+}
+
 func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool {
 	binPath := l.vendoredServerStoragePath
 	if l.updatedLlamaCpp {
diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go
index 2bae367..267a9c8 100644
--- a/pkg/inference/backends/mlx/mlx.go
+++ b/pkg/inference/backends/mlx/mlx.go
@@ -63,6 +63,6 @@ func (m *mlx) GetDiskUsage() (int64, error) {
 	return 0, nil
 }
 
-func (m *mlx) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }
diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go
index 86334d4..8bafb8e 100644
--- a/pkg/inference/backends/vllm/vllm.go
+++ b/pkg/inference/backends/vllm/vllm.go
@@ -63,6 +63,6 @@ func (v *vLLM) GetDiskUsage() (int64, error) {
 	return 0, nil
 }
 
-func (v *vLLM) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }
diff --git a/pkg/inference/memory/estimator.go b/pkg/inference/memory/estimator.go
new file mode 100644
index 0000000..6d66e7f
--- /dev/null
+++ b/pkg/inference/memory/estimator.go
@@ -0,0 +1,48 @@
+package memory
+
+import (
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/docker/model-runner/pkg/inference"
+)
+
+type MemoryEstimator interface {
+	SetDefaultBackend(MemoryEstimatorBackend)
+	GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
+	HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error)
+}
+
+type MemoryEstimatorBackend interface {
+	GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
+}
+
+type memoryEstimator struct {
+	systemMemoryInfo SystemMemoryInfo
+	defaultBackend   MemoryEstimatorBackend
+}
+
+func NewEstimator(systemMemoryInfo SystemMemoryInfo) MemoryEstimator {
+	return &memoryEstimator{systemMemoryInfo: systemMemoryInfo}
+}
+
+func (m *memoryEstimator) SetDefaultBackend(backend MemoryEstimatorBackend) {
+	m.defaultBackend = backend
+}
+
+func (m *memoryEstimator) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+	if m.defaultBackend == nil {
+		return nil, errors.New("default backend not configured")
+	}
+
+	return m.defaultBackend.GetRequiredMemoryForModel(ctx, model, config)
+}
+
+func (m *memoryEstimator) HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error) {
+	req, err := m.GetRequiredMemoryForModel(ctx, model, config)
+	if err != nil {
+		return false, fmt.Errorf("estimating required memory for model: %w", err)
+	}
+	return m.systemMemoryInfo.HaveSufficientMemory(*req), nil
+}
diff --git a/pkg/inference/memory/system.go b/pkg/inference/memory/system.go
new file mode 100644
index 0000000..fb0a384
--- /dev/null
+++ b/pkg/inference/memory/system.go
@@ -0,0 +1,55 @@
+package memory
+
+import (
+	"github.com/docker/model-runner/pkg/gpuinfo"
+	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/logging"
+	"github.com/elastic/go-sysinfo"
+)
+
+type SystemMemoryInfo interface {
+	HaveSufficientMemory(inference.RequiredMemory) bool
+	GetTotalMemory() inference.RequiredMemory
+}
+
+type systemMemoryInfo struct {
+	log         logging.Logger
+	totalMemory inference.RequiredMemory
+}
+
+func NewSystemMemoryInfo(log logging.Logger, gpuInfo *gpuinfo.GPUInfo) (SystemMemoryInfo, error) {
+	// Compute the amount of available memory.
+	// TODO(p1-0tr): improve error handling
+	vramSize, err := gpuInfo.GetVRAMSize()
+	if err != nil {
+		vramSize = 1
+		log.Warnf("Could not read VRAM size: %s", err)
+	} else {
+		log.Infof("Running on system with %d MB VRAM", vramSize/1024/1024)
+	}
+	ramSize := uint64(1)
+	hostInfo, err := sysinfo.Host()
+	if err != nil {
+		log.Warnf("Could not read host info: %s", err)
+	} else {
+		ram, err := hostInfo.Memory()
+		if err != nil {
+			log.Warnf("Could not read host RAM size: %s", err)
+		} else {
+			ramSize = ram.Total
+			log.Infof("Running on system with %d MB RAM", ramSize/1024/1024)
+		}
+	}
+	return &systemMemoryInfo{
+		log:         log,
+		totalMemory: inference.RequiredMemory{RAM: ramSize, VRAM: vramSize},
+	}, nil
+}
+
+func (s *systemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) bool {
+	return req.RAM <= s.totalMemory.RAM && req.VRAM <= s.totalMemory.VRAM
+}
+
+func (s *systemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
+	return s.totalMemory
+}
diff --git a/pkg/inference/models/api.go b/pkg/inference/models/api.go
index a21864d..196d64f 100644
--- a/pkg/inference/models/api.go
+++ b/pkg/inference/models/api.go
@@ -14,6 +14,9 @@ import (
 type ModelCreateRequest struct {
 	// From is the name of the model to pull.
 	From string `json:"from"`
+	// IgnoreRuntimeMemoryCheck indicates that the server should skip checking whether it has
+	// sufficient memory to run the given model (assuming default configuration).
+	IgnoreRuntimeMemoryCheck bool `json:"ignore-runtime-memory-check,omitempty"`
 }
 
 // ToOpenAIList converts the model list to its OpenAI API representation. This function never
diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go
index 9ea8011..7f84c34 100644
--- a/pkg/inference/models/manager.go
+++ b/pkg/inference/models/manager.go
@@ -15,11 +15,12 @@ import (
 	"github.com/docker/model-distribution/distribution"
 	"github.com/docker/model-distribution/registry"
 	"github.com/docker/model-distribution/types"
-	"github.com/sirupsen/logrus"
-
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/logging"
+	v1 "github.com/google/go-containerregistry/pkg/v1"
+	"github.com/sirupsen/logrus"
 )
 
 const (
@@ -43,6 +44,8 @@ type Manager struct {
 	registryClient *registry.Client
 	// lock is used to synchronize access to the models manager's router.
 	lock sync.RWMutex
+	// memoryEstimator is used to calculate runtime memory requirements for models.
+	memoryEstimator memory.MemoryEstimator
 }
 
 type ClientConfig struct {
@@ -57,7 +60,7 @@ type ClientConfig struct {
 }
 
 // NewManager creates a new model's manager.
-func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Manager {
+func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string, memoryEstimator memory.MemoryEstimator) *Manager {
 	// Create the model distribution client.
 	distributionClient, err := distribution.NewClient(
 		distribution.WithStoreRootPath(c.StoreRootPath),
@@ -84,6 +87,7 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
 		router:             http.NewServeMux(),
 		distributionClient: distributionClient,
 		registryClient:     registryClient,
+		memoryEstimator:    memoryEstimator,
 	}
 
 	// Register routes.
@@ -164,6 +168,20 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
 
 	// Pull the model. In the future, we may support additional operations here
 	// besides pulling (such as model building).
+	if !request.IgnoreRuntimeMemoryCheck {
+		m.log.Infof("Will estimate memory required for %q", request.From)
+		proceed, err := m.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
+		if err != nil {
+			m.log.Warnf("Failed to calculate memory required for model %q: %s", request.From, err)
+			// Prefer staying functional in case of unexpected estimation errors.
+			proceed = true
+		}
+		if !proceed {
+			m.log.Warnf("Runtime memory requirement for model %q exceeds total system memory", request.From)
+			http.Error(w, "Runtime memory requirement for model exceeds total system memory", http.StatusInsufficientStorage)
+			return
+		}
+	}
 	if err := m.PullModel(request.From, r, w); err != nil {
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			m.log.Infof("Request canceled/timed out while pulling model %q", request.From)
@@ -563,6 +581,11 @@ func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	m.router.ServeHTTP(w, r)
 }
 
+// IsModelInStore checks if a given model is in the local store.
+func (m *Manager) IsModelInStore(ref string) (bool, error) {
+	return m.distributionClient.IsModelInStore(ref)
+}
+
 // GetModel returns a single model.
 func (m *Manager) GetModel(ref string) (types.Model, error) {
 	model, err := m.distributionClient.GetModel(ref)
@@ -572,6 +595,33 @@ func (m *Manager) GetModel(ref string) (types.Model, error) {
 	return model, err
 }
 
+// GetRemoteModel returns a single remote model.
+func (m *Manager) GetRemoteModel(ctx context.Context, ref string) (types.ModelArtifact, error) {
+	model, err := m.registryClient.Model(ctx, ref)
+	if err != nil {
+		return nil, fmt.Errorf("error while getting remote model: %w", err)
+	}
+	return model, nil
+}
+
+// GetRemoteModelBlobURL returns the URL of a given model blob.
+func (m *Manager) GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error) {
+	blobURL, err := m.registryClient.BlobURL(ref, digest)
+	if err != nil {
+		return "", fmt.Errorf("error while getting remote model blob URL: %w", err)
+	}
+	return blobURL, nil
+}
+
+// BearerTokenForModel returns the bearer token needed to pull a given model.
+func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string, error) {
+	tok, err := m.registryClient.BearerToken(ctx, ref)
+	if err != nil {
+		return "", fmt.Errorf("error while getting bearer token for model: %w", err)
+	}
+	return tok, nil
+}
+
 // GetBundle returns model bundle.
 func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {
 	bundle, err := m.distributionClient.GetBundle(ref)
diff --git a/pkg/inference/models/manager_test.go b/pkg/inference/models/manager_test.go
index 7edd357..8115b3a 100644
--- a/pkg/inference/models/manager_test.go
+++ b/pkg/inference/models/manager_test.go
@@ -16,10 +16,23 @@ import (
 	"github.com/docker/model-distribution/builder"
 	reg "github.com/docker/model-distribution/registry"
 
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/sirupsen/logrus"
 )
 
+type mockMemoryEstimator struct{}
+
+func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {}
+
+func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+	return &inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
+}
+
+func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, error) {
+	return true, nil
+}
+
 // getProjectRoot returns the absolute path to the project root directory
 func getProjectRoot(t *testing.T) string {
 	// Start from the current test file's directory
@@ -109,10 +122,11 @@ func TestPullModel(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			log := logrus.NewEntry(logrus.StandardLogger())
+			memEstimator := &mockMemoryEstimator{}
 			m := NewManager(log, ClientConfig{
 				StoreRootPath: tempDir,
 				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
-			}, nil)
+			}, nil, memEstimator)
 
 			r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tag+`"}`))
 			if tt.acceptHeader != "" {
@@ -219,12 +233,13 @@ func TestHandleGetModel(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			log := logrus.NewEntry(logrus.StandardLogger())
+			memEstimator := &mockMemoryEstimator{}
 			m := NewManager(log, ClientConfig{
 				StoreRootPath: tempDir,
 				Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
 				Transport:     http.DefaultTransport,
 				UserAgent:     "test-agent",
-			}, nil)
+			}, nil, memEstimator)
 
 			// First pull the model if we're testing local access
 			if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
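
Aside, not part of the patch: a minimal client-side sketch of how the new pre-pull memory check surfaces to callers of POST /models/create. The endpoint path, the "from" and "ignore-runtime-memory-check" fields, and the 507 status come from manager_test.go, api.go, and handleCreateModel above; the base URL and the model reference are assumptions that depend on how the runner is exposed.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Body mirrors models.ModelCreateRequest: "from" names the model to pull,
	// "ignore-runtime-memory-check" bypasses the new pre-pull estimate.
	body, _ := json.Marshal(map[string]any{
		"from":                        "ai/example-model", // hypothetical model reference
		"ignore-runtime-memory-check": false,
	})
	// The base URL is an assumption; the runner may instead be reached over a
	// unix socket depending on the deployment.
	resp, err := http.Post("http://localhost:12434/models/create", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	// handleCreateModel replies with 507 Insufficient Storage when the estimated
	// RAM/VRAM requirement exceeds total system memory, instead of starting the pull.
	if resp.StatusCode == http.StatusInsufficientStorage {
		fmt.Println("model exceeds system memory; set ignore-runtime-memory-check to true to pull anyway")
	}
}

Setting ignore-runtime-memory-check to true restores the previous behaviour of pulling unconditionally.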
diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go
index ebbdd33..69166e0 100644
--- a/pkg/inference/scheduling/loader.go
+++ b/pkg/inference/scheduling/loader.go
@@ -10,12 +10,11 @@ import (
 	"time"
 
 	"github.com/docker/model-runner/pkg/environment"
-	"github.com/docker/model-runner/pkg/gpuinfo"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/metrics"
-	"github.com/elastic/go-sysinfo"
 )
 
 const (
@@ -113,7 +112,7 @@ func newLoader(
 	backends map[string]inference.Backend,
 	modelManager *models.Manager,
 	openAIRecorder *metrics.OpenAIRecorder,
-	gpuInfo *gpuinfo.GPUInfo,
+	sysMemInfo memory.SystemMemoryInfo,
 ) *loader {
 	// Compute the number of runner slots to allocate. Because of RAM and VRAM
 	// limitations, it's unlikely that we'll ever be able to fully populate
@@ -135,32 +134,7 @@ func newLoader(
 	}
 
 	// Compute the amount of available memory.
-	// TODO(p1-0tr): improve error handling
-	vramSize, err := gpuInfo.GetVRAMSize()
-	if err != nil {
-		vramSize = 1
-		log.Warnf("Could not read VRAM size: %s", err)
-	} else {
-		log.Infof("Running on system with %dMB VRAM", vramSize/1024/1024)
-	}
-	ramSize := uint64(1)
-	hostInfo, err := sysinfo.Host()
-	if err != nil {
-		log.Warnf("Could not read host info: %s", err)
-	} else {
-		ram, err := hostInfo.Memory()
-		if err != nil {
-			log.Warnf("Could not read host RAM size: %s", err)
-		} else {
-			ramSize = ram.Total
-			log.Infof("Running on system with %dMB RAM", ramSize/1024/1024)
-		}
-	}
-
-	totalMemory := inference.RequiredMemory{
-		RAM:  ramSize,
-		VRAM: vramSize,
-	}
+	totalMemory := sysMemInfo.GetTotalMemory()
 
 	// Create the loader.
 	l := &loader{
@@ -420,12 +394,13 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 	if rc, ok := l.runnerConfigs[runnerKey{backendName, modelID, mode}]; ok {
 		runnerConfig = &rc
 	}
-	memory, err := backend.GetRequiredMemoryForModel(modelID, runnerConfig)
-	if errors.Is(err, inference.ErrGGUFParse) {
+	memory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
+	var parseErr *inference.ErrGGUFParse
+	if errors.As(err, &parseErr) {
 		// TODO(p1-0tr): For now override memory checks in case model can't be parsed
 		// e.g. model is too new for gguf-parser-go to know. We should provide a cleaner
 		// way to bypass these checks.
-		l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it.", modelID)
+		l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it. Error: %s", modelID, parseErr)
 		memory = &inference.RequiredMemory{
 			RAM:  0,
 			VRAM: 0,
@@ -433,7 +408,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 	} else if err != nil {
 		return nil, err
 	}
-	l.log.Infof("Loading %s, which will require %dMB RAM and %dMB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024)
+	l.log.Infof("Loading %s, which will require %d MB RAM and %d MB VRAM on a system with %d MB RAM and %d MB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024, l.totalMemory.RAM/1024/1024, l.totalMemory.VRAM/1024/1024)
 	if l.totalMemory.RAM == 1 {
 		l.log.Warnf("RAM size unknown. Assume model will fit, but only one.")
 		memory.RAM = 1
diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go
index 3f716e9..b99db11 100644
--- a/pkg/inference/scheduling/scheduler.go
+++ b/pkg/inference/scheduling/scheduler.go
@@ -13,8 +13,8 @@ import (
 	"time"
 
 	"github.com/docker/model-distribution/distribution"
-	"github.com/docker/model-runner/pkg/gpuinfo"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/metrics"
@@ -56,7 +56,7 @@ func NewScheduler(
 	httpClient *http.Client,
 	allowedOrigins []string,
 	tracker *metrics.Tracker,
-	gpuInfo *gpuinfo.GPUInfo,
+	sysMemInfo memory.SystemMemoryInfo,
 ) *Scheduler {
 	openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)
 
@@ -67,7 +67,7 @@ func NewScheduler(
 		defaultBackend: defaultBackend,
 		modelManager:   modelManager,
 		installer:      newInstaller(log, backends, httpClient),
-		loader:         newLoader(log, backends, modelManager, openAIRecorder, gpuInfo),
+		loader:         newLoader(log, backends, modelManager, openAIRecorder, sysMemInfo),
 		router:         http.NewServeMux(),
 		tracker:        tracker,
 		openAIRecorder: openAIRecorder,
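
Aside, not part of the patch: a minimal sketch of how the new memory package composes. fixedBackend and checkModelFits are hypothetical stand-ins for the llama.cpp backend and a caller; the interfaces, constructors, and RequiredMemory fields are the ones introduced in pkg/inference/memory above. The estimator is deliberately wired in two steps, mirroring main.go: it is created from SystemMemoryInfo first and only later bound to a backend via SetDefaultBackend, because the backend is built on top of the model manager, which itself needs the estimator for its pre-pull check.

package example

import (
	"context"
	"fmt"

	"github.com/docker/model-runner/pkg/inference"
	"github.com/docker/model-runner/pkg/inference/memory"
)

// fixedBackend is a hypothetical MemoryEstimatorBackend that returns a constant
// estimate; in the runner this role is played by the llama.cpp backend.
type fixedBackend struct {
	ram, vram uint64
}

func (b *fixedBackend) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
	return &inference.RequiredMemory{RAM: b.ram, VRAM: b.vram}, nil
}

// checkModelFits builds an estimator over the detected system memory, attaches a
// backend, and asks whether the model fits, the same call sequence used by
// handleCreateModel and the loader.
func checkModelFits(ctx context.Context, sysMemInfo memory.SystemMemoryInfo, model string) error {
	est := memory.NewEstimator(sysMemInfo)
	est.SetDefaultBackend(&fixedBackend{ram: 8 << 30, vram: 4 << 30}) // 8 GiB RAM, 4 GiB VRAM

	ok, err := est.HaveSufficientMemoryForModel(ctx, model, nil)
	if err != nil {
		return fmt.Errorf("estimating memory for %s: %w", model, err)
	}
	if !ok {
		total := sysMemInfo.GetTotalMemory()
		return fmt.Errorf("%s needs more than the available %d MB RAM / %d MB VRAM",
			model, total.RAM/1024/1024, total.VRAM/1024/1024)
	}
	return nil
}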