Merge remote-tracking branch 'origin/main' into shards

Emily Casey 2025-08-22 09:27:07 -06:00
commit 8d5f251df7
14 changed files with 297 additions and 67 deletions

View File

@ -13,6 +13,9 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Verify vendor/ is not present
run: stat vendor && exit 1 || exit 0
- name: Set up Go
uses: actions/setup-go@v5
with:

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ model-runner.sock
models-store/
# Directory where we store the updated llama.cpp
updated-inference/
vendor/

25
main.go
View File

@ -14,6 +14,7 @@ import (
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
"github.com/docker/model-runner/pkg/inference/config"
"github.com/docker/model-runner/pkg/inference/memory"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/docker/model-runner/pkg/metrics"
@ -54,6 +55,20 @@ func main() {
llamacpp.SetDesiredServerVersion(desiredSeverVersion)
}
llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
if llamaServerPath == "" {
llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
}
gpuInfo := gpuinfo.New(llamaServerPath)
sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
if err != nil {
log.Fatalf("unable to initialize system memory info: %v", err)
}
memEstimator := memory.NewEstimator(sysMemInfo)
modelManager := models.NewManager(
log,
models.ClientConfig{
@ -61,13 +76,9 @@ func main() {
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
},
nil,
memEstimator,
)
llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
if llamaServerPath == "" {
llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
}
log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
// Create llama.cpp configuration from environment variables
@ -90,7 +101,7 @@ func main() {
log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
}
gpuInfo := gpuinfo.New(llamaServerPath)
memEstimator.SetDefaultBackend(llamaCppBackend)
scheduler := scheduling.NewScheduler(
log,
@ -105,7 +116,7 @@ func main() {
"",
false,
),
gpuInfo,
sysMemInfo,
)
router := routing.NewNormalizedServeMux()

View File

@ -2,7 +2,6 @@ package inference
import (
"context"
"errors"
"net/http"
)
@ -18,9 +17,13 @@ const (
BackendModeEmbedding
)
var (
ErrGGUFParse = errors.New("failed to parse GGUF file")
)
type ErrGGUFParse struct {
Err error
}
func (e *ErrGGUFParse) Error() string {
return "failed to parse GGUF: " + e.Err.Error()
}
// String implements Stringer.String for BackendMode.
func (m BackendMode) String() string {
@ -88,5 +91,5 @@ type Backend interface {
GetDiskUsage() (int64, error)
// GetRequiredMemoryForModel returns the required working memory for a given
// model.
GetRequiredMemoryForModel(model string, config *BackendConfiguration) (*RequiredMemory, error)
GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (*RequiredMemory, error)
}
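Because ErrGGUFParse is now a struct type rather than a sentinel value, callers match it with errors.As instead of errors.Is; the loader change further down does exactly this. A minimal sketch of the calling pattern, assuming backend, ctx, modelID, runnerConfig and log are in scope:

requiredMemory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
var parseErr *inference.ErrGGUFParse
switch {
case errors.As(err, &parseErr):
	// GGUF metadata could not be parsed (for example, the model is newer than
	// the parser understands); the loader chooses to bypass memory checks here.
	requiredMemory = &inference.RequiredMemory{RAM: 0, VRAM: 0}
case err != nil:
	return nil, err
}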

View File

@ -15,6 +15,8 @@ import (
"runtime"
"strings"
"github.com/docker/model-distribution/types"
v1 "github.com/google/go-containerregistry/pkg/v1"
parser "github.com/gpustack/gguf-parser-go"
"github.com/docker/model-runner/pkg/diskusage"
@ -223,23 +225,30 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
return size, nil
}
func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
bundle, err := l.modelManager.GetBundle(model)
func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
var mdlGguf *parser.GGUFFile
var mdlConfig types.Config
inStore, err := l.modelManager.IsModelInStore(model)
if err != nil {
return nil, fmt.Errorf("getting model(%s): %w", model, err)
return nil, fmt.Errorf("checking if model is in local store: %w", err)
}
if inStore {
mdlGguf, mdlConfig, err = l.parseLocalModel(model)
if err != nil {
return nil, &inference.ErrGGUFParse{Err: err}
}
} else {
mdlGguf, mdlConfig, err = l.parseRemoteModel(ctx, model)
if err != nil {
return nil, &inference.ErrGGUFParse{Err: err}
}
}
mdlGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
if err != nil {
l.log.Warnf("Failed to parse gguf(%s): %s", bundle.GGUFPath(), err)
return nil, inference.ErrGGUFParse
}
contextSize := GetContextSize(bundle.RuntimeConfig(), config)
contextSize := GetContextSize(mdlConfig, config)
ngl := uint64(0)
if l.gpuSupported {
if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && strings.TrimSpace(mdlGGUF.Metadata().FileType.String()) != "Q4_0" {
if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && mdlConfig.Quantization != "Q4_0" {
ngl = 0 // only Q4_0 models can be accelerated on Adreno
}
ngl = 100
@ -248,7 +257,7 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
// TODO(p1-0tr): for now assume we are running on GPU (single one) - Devices[1];
// sum up weights + kv cache + context for an estimate of total GPU memory needed
// while running inference with the given model
estimate := mdlGGUF.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
estimate := mdlGguf.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
// TODO(p1-0tr): add logic for resolving other param values, instead of hardcoding them
parser.WithLLaMACppLogicalBatchSize(2048),
parser.WithLLaMACppOffloadLayers(ngl))
@ -270,6 +279,63 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
}, nil
}
func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
bundle, err := l.modelManager.GetBundle(model)
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
}
modelGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
if err != nil {
return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", bundle.GGUFPath(), err)
}
return modelGGUF, bundle.RuntimeConfig(), nil
}
func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
mdl, err := l.modelManager.GetRemoteModel(ctx, model)
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting remote model(%s): %w", model, err)
}
layers, err := mdl.Layers()
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting layers of model(%s): %w", model, err)
}
var ggufDigest v1.Hash
for _, layer := range layers {
mt, err := layer.MediaType()
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting media type of model(%s) layer: %w", model, err)
}
if mt == types.MediaTypeGGUF {
ggufDigest, err = layer.Digest()
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting digest of GGUF layer for model(%s): %w", model, err)
}
break
}
}
if ggufDigest.String() == "" {
return nil, types.Config{}, fmt.Errorf("model(%s) has no GGUF layer", model)
}
blobURL, err := l.modelManager.GetRemoteModelBlobURL(model, ggufDigest)
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting GGUF blob URL for model(%s): %w", model, err)
}
tok, err := l.modelManager.BearerTokenForModel(ctx, model)
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting bearer token for model(%s): %w", model, err)
}
mdlGguf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(tok))
if err != nil {
return nil, types.Config{}, fmt.Errorf("parsing GGUF for model(%s): %w", model, err)
}
config, err := mdl.Config()
if err != nil {
return nil, types.Config{}, fmt.Errorf("getting config for model(%s): %w", model, err)
}
return mdlGguf, config, nil
}
func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool {
binPath := l.vendoredServerStoragePath
if l.updatedLlamaCpp {

View File

@ -63,6 +63,6 @@ func (m *mlx) GetDiskUsage() (int64, error) {
return 0, nil
}
func (m *mlx) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
return nil, errors.New("not implemented")
}

View File

@ -63,6 +63,6 @@ func (v *vLLM) GetDiskUsage() (int64, error) {
return 0, nil
}
func (v *vLLM) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
return nil, errors.New("not implemented")
}

View File

@ -0,0 +1,48 @@
package memory
import (
"context"
"errors"
"fmt"
"github.com/docker/model-runner/pkg/inference"
)
type MemoryEstimator interface {
SetDefaultBackend(MemoryEstimatorBackend)
GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error)
}
type MemoryEstimatorBackend interface {
GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
}
type memoryEstimator struct {
systemMemoryInfo SystemMemoryInfo
defaultBackend MemoryEstimatorBackend
}
func NewEstimator(systemMemoryInfo SystemMemoryInfo) MemoryEstimator {
return &memoryEstimator{systemMemoryInfo: systemMemoryInfo}
}
func (m *memoryEstimator) SetDefaultBackend(backend MemoryEstimatorBackend) {
m.defaultBackend = backend
}
func (m *memoryEstimator) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
if m.defaultBackend == nil {
return nil, errors.New("default backend not configured")
}
return m.defaultBackend.GetRequiredMemoryForModel(ctx, model, config)
}
func (m *memoryEstimator) HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error) {
req, err := m.GetRequiredMemoryForModel(ctx, model, config)
if err != nil {
return false, fmt.Errorf("estimating required memory for model: %w", err)
}
return m.systemMemoryInfo.HaveSufficientMemory(*req), nil
}
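Pulling the main.go changes above together, the intended wiring of the estimator looks roughly like the following sketch; clientConfig, llamaServerPath, llamaCppBackend and the model name stand in for values constructed elsewhere in main:

gpuInfo := gpuinfo.New(llamaServerPath)
sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
if err != nil {
	log.Fatalf("unable to initialize system memory info: %v", err)
}
memEstimator := memory.NewEstimator(sysMemInfo)
modelManager := models.NewManager(log, clientConfig, nil, memEstimator)
_ = modelManager // wired into the scheduler elsewhere
// Once the llama.cpp backend exists, register it so estimates are delegated to it.
memEstimator.SetDefaultBackend(llamaCppBackend)
ok, err := memEstimator.HaveSufficientMemoryForModel(ctx, "ai/example-model", nil)
if err == nil && !ok {
	// Refuse to pull or load the model.
}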

View File

@ -0,0 +1,55 @@
package memory
import (
"github.com/docker/model-runner/pkg/gpuinfo"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/logging"
"github.com/elastic/go-sysinfo"
)
type SystemMemoryInfo interface {
HaveSufficientMemory(inference.RequiredMemory) bool
GetTotalMemory() inference.RequiredMemory
}
type systemMemoryInfo struct {
log logging.Logger
totalMemory inference.RequiredMemory
}
func NewSystemMemoryInfo(log logging.Logger, gpuInfo *gpuinfo.GPUInfo) (SystemMemoryInfo, error) {
// Compute the amount of available memory.
// TODO(p1-0tr): improve error handling
vramSize, err := gpuInfo.GetVRAMSize()
if err != nil {
vramSize = 1
log.Warnf("Could not read VRAM size: %s", err)
} else {
log.Infof("Running on system with %d MB VRAM", vramSize/1024/1024)
}
ramSize := uint64(1)
hostInfo, err := sysinfo.Host()
if err != nil {
log.Warnf("Could not read host info: %s", err)
} else {
ram, err := hostInfo.Memory()
if err != nil {
log.Warnf("Could not read host RAM size: %s", err)
} else {
ramSize = ram.Total
log.Infof("Running on system with %d MB RAM", ramSize/1024/1024)
}
}
return &systemMemoryInfo{
log: log,
totalMemory: inference.RequiredMemory{RAM: ramSize, VRAM: vramSize},
}, nil
}
func (s *systemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) bool {
return req.RAM <= s.totalMemory.RAM && req.VRAM <= s.totalMemory.VRAM
}
func (s *systemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
return s.totalMemory
}

View File

@ -14,6 +14,9 @@ import (
type ModelCreateRequest struct {
// From is the name of the model to pull.
From string `json:"from"`
// IgnoreRuntimeMemoryCheck indicates that the server should skip checking whether it has
// sufficient memory to run the given model (assuming default configuration).
IgnoreRuntimeMemoryCheck bool `json:"ignore-runtime-memory-check,omitempty"`
}
// ToOpenAIList converts the model list to its OpenAI API representation. This function never

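The new field is consumed by handleCreateModel below. Mirroring the tests later in this commit, a request that skips the check could be built like this (the model name is illustrative):

body := `{"from": "ai/example-model", "ignore-runtime-memory-check": true}`
r := httptest.NewRequest(http.MethodPost, "/models/create", strings.NewReader(body))
// Without the flag (or with it set to false), the manager estimates the model's runtime
// memory first and rejects the pull with 507 Insufficient Storage if it exceeds total
// system memory.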
View File

@ -15,11 +15,12 @@ import (
"github.com/docker/model-distribution/distribution"
"github.com/docker/model-distribution/registry"
"github.com/docker/model-distribution/types"
"github.com/sirupsen/logrus"
"github.com/docker/model-runner/pkg/diskusage"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/memory"
"github.com/docker/model-runner/pkg/logging"
v1 "github.com/google/go-containerregistry/pkg/v1"
"github.com/sirupsen/logrus"
)
const (
@ -43,6 +44,8 @@ type Manager struct {
registryClient *registry.Client
// lock is used to synchronize access to the models manager's router.
lock sync.RWMutex
// memoryEstimator is used to calculate runtime memory requirements for models.
memoryEstimator memory.MemoryEstimator
}
type ClientConfig struct {
@ -57,7 +60,7 @@ type ClientConfig struct {
}
// NewManager creates a new model's manager.
func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Manager {
func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string, memoryEstimator memory.MemoryEstimator) *Manager {
// Create the model distribution client.
distributionClient, err := distribution.NewClient(
distribution.WithStoreRootPath(c.StoreRootPath),
@ -84,6 +87,7 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
router: http.NewServeMux(),
distributionClient: distributionClient,
registryClient: registryClient,
memoryEstimator: memoryEstimator,
}
// Register routes.
@ -164,6 +168,20 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
// Pull the model. In the future, we may support additional operations here
// besides pulling (such as model building).
if !request.IgnoreRuntimeMemoryCheck {
m.log.Infof("Will estimate memory required for %q", request.From)
proceed, err := m.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
if err != nil {
m.log.Warnf("Failed to calculate memory required for model %q: %s", request.From, err)
// Prefer staying functional in case of unexpected estimation errors.
proceed = true
}
if !proceed {
m.log.Warnf("Runtime memory requirement for model %q exceeds total system memory", request.From)
http.Error(w, "Runtime memory requirement for model exceeds total system memory", http.StatusInsufficientStorage)
return
}
}
if err := m.PullModel(request.From, r, w); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
m.log.Infof("Request canceled/timed out while pulling model %q", request.From)
@ -563,6 +581,11 @@ func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.router.ServeHTTP(w, r)
}
// IsModelInStore checks if a given model is in the local store.
func (m *Manager) IsModelInStore(ref string) (bool, error) {
return m.distributionClient.IsModelInStore(ref)
}
// GetModel returns a single model.
func (m *Manager) GetModel(ref string) (types.Model, error) {
model, err := m.distributionClient.GetModel(ref)
@ -572,6 +595,33 @@ func (m *Manager) GetModel(ref string) (types.Model, error) {
return model, err
}
// GetRemoteModel returns a single remote model.
func (m *Manager) GetRemoteModel(ctx context.Context, ref string) (types.ModelArtifact, error) {
model, err := m.registryClient.Model(ctx, ref)
if err != nil {
return nil, fmt.Errorf("error while getting remote model: %w", err)
}
return model, nil
}
// GetRemoteModelBlobURL returns the URL of a given model blob.
func (m *Manager) GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error) {
blobURL, err := m.registryClient.BlobURL(ref, digest)
if err != nil {
return "", fmt.Errorf("error while getting remote model blob URL: %w", err)
}
return blobURL, nil
}
// BearerTokenForModel returns the bearer token needed to pull a given model.
func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string, error) {
tok, err := m.registryClient.BearerToken(ctx, ref)
if err != nil {
return "", fmt.Errorf("error while getting bearer token for model: %w", err)
}
return tok, nil
}
// GetBundle returns a model bundle.
func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {
bundle, err := m.distributionClient.GetBundle(ref)

View File

@ -16,10 +16,23 @@ import (
"github.com/docker/model-distribution/builder"
reg "github.com/docker/model-distribution/registry"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/memory"
"github.com/sirupsen/logrus"
)
type mockMemoryEstimator struct{}
func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {}
func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
return &inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
}
func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, error) {
return true, nil
}
// getProjectRoot returns the absolute path to the project root directory
func getProjectRoot(t *testing.T) string {
// Start from the current test file's directory
@ -109,10 +122,11 @@ func TestPullModel(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
log := logrus.NewEntry(logrus.StandardLogger())
memEstimator := &mockMemoryEstimator{}
m := NewManager(log, ClientConfig{
StoreRootPath: tempDir,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
}, nil)
}, nil, memEstimator)
r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tag+`"}`))
if tt.acceptHeader != "" {
@ -219,12 +233,13 @@ func TestHandleGetModel(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
log := logrus.NewEntry(logrus.StandardLogger())
memEstimator := &mockMemoryEstimator{}
m := NewManager(log, ClientConfig{
StoreRootPath: tempDir,
Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
Transport: http.DefaultTransport,
UserAgent: "test-agent",
}, nil)
}, nil, memEstimator)
// First pull the model if we're testing local access
if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {

View File

@ -10,12 +10,11 @@ import (
"time"
"github.com/docker/model-runner/pkg/environment"
"github.com/docker/model-runner/pkg/gpuinfo"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/memory"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/metrics"
"github.com/elastic/go-sysinfo"
)
const (
@ -113,7 +112,7 @@ func newLoader(
backends map[string]inference.Backend,
modelManager *models.Manager,
openAIRecorder *metrics.OpenAIRecorder,
gpuInfo *gpuinfo.GPUInfo,
sysMemInfo memory.SystemMemoryInfo,
) *loader {
// Compute the number of runner slots to allocate. Because of RAM and VRAM
// limitations, it's unlikely that we'll ever be able to fully populate
@ -135,32 +134,7 @@ func newLoader(
}
// Compute the amount of available memory.
// TODO(p1-0tr): improve error handling
vramSize, err := gpuInfo.GetVRAMSize()
if err != nil {
vramSize = 1
log.Warnf("Could not read VRAM size: %s", err)
} else {
log.Infof("Running on system with %dMB VRAM", vramSize/1024/1024)
}
ramSize := uint64(1)
hostInfo, err := sysinfo.Host()
if err != nil {
log.Warnf("Could not read host info: %s", err)
} else {
ram, err := hostInfo.Memory()
if err != nil {
log.Warnf("Could not read host RAM size: %s", err)
} else {
ramSize = ram.Total
log.Infof("Running on system with %dMB RAM", ramSize/1024/1024)
}
}
totalMemory := inference.RequiredMemory{
RAM: ramSize,
VRAM: vramSize,
}
totalMemory := sysMemInfo.GetTotalMemory()
// Create the loader.
l := &loader{
@ -420,12 +394,13 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
if rc, ok := l.runnerConfigs[runnerKey{backendName, modelID, mode}]; ok {
runnerConfig = &rc
}
memory, err := backend.GetRequiredMemoryForModel(modelID, runnerConfig)
if errors.Is(err, inference.ErrGGUFParse) {
memory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
var parseErr *inference.ErrGGUFParse
if errors.As(err, &parseErr) {
// TODO(p1-0tr): For now override memory checks in case model can't be parsed
// e.g. model is too new for gguf-parser-go to know. We should provide a cleaner
// way to bypass these checks.
l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it.", modelID)
l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it. Error: %s", modelID, parseErr)
memory = &inference.RequiredMemory{
RAM: 0,
VRAM: 0,
@ -433,7 +408,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
} else if err != nil {
return nil, err
}
l.log.Infof("Loading %s, which will require %dMB RAM and %dMB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024)
l.log.Infof("Loading %s, which will require %d MB RAM and %d MB VRAM on a system with %d MB RAM and %d MB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024, l.totalMemory.RAM/1024/1024, l.totalMemory.VRAM/1024/1024)
if l.totalMemory.RAM == 1 {
l.log.Warnf("RAM size unknown. Assume model will fit, but only one.")
memory.RAM = 1

View File

@ -13,8 +13,8 @@ import (
"time"
"github.com/docker/model-distribution/distribution"
"github.com/docker/model-runner/pkg/gpuinfo"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/memory"
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/metrics"
@ -56,7 +56,7 @@ func NewScheduler(
httpClient *http.Client,
allowedOrigins []string,
tracker *metrics.Tracker,
gpuInfo *gpuinfo.GPUInfo,
sysMemInfo memory.SystemMemoryInfo,
) *Scheduler {
openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)
@ -67,7 +67,7 @@ func NewScheduler(
defaultBackend: defaultBackend,
modelManager: modelManager,
installer: newInstaller(log, backends, httpClient),
loader: newLoader(log, backends, modelManager, openAIRecorder, gpuInfo),
loader: newLoader(log, backends, modelManager, openAIRecorder, sysMemInfo),
router: http.NewServeMux(),
tracker: tracker,
openAIRecorder: openAIRecorder,