inference: Use common system memory size getter in the loader

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Piotr Stankiewicz authored on 2025-08-22 15:05:51 +02:00
parent 03f7adc077
commit d8ed374455
4 changed files with 13 additions and 34 deletions

View File

@@ -116,7 +116,7 @@ func main() {
 			"",
 			false,
 		),
-		gpuInfo,
+		sysMemInfo,
 	)
 	router := routing.NewNormalizedServeMux()

View File

@@ -9,6 +9,7 @@ import (
 
 type SystemMemoryInfo interface {
 	HaveSufficientMemory(inference.RequiredMemory) bool
+	GetTotalMemory() inference.RequiredMemory
 }
 
 type systemMemoryInfo struct {
@@ -48,3 +49,7 @@ func NewSystemMemoryInfo(log logging.Logger, gpuInfo *gpuinfo.GPUInfo) (SystemMe
 func (s *systemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) bool {
 	return req.RAM <= s.totalMemory.RAM && req.VRAM <= s.totalMemory.VRAM
 }
+
+func (s *systemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
+	return s.totalMemory
+}
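
Because SystemMemoryInfo is an interface, callers such as the loader below can now be exercised without probing real hardware. A minimal sketch of a hypothetical test double (fakeSystemMemoryInfo is not part of this change; it only assumes, as the hunks here show, that inference.RequiredMemory carries RAM and VRAM byte counts):

// fakeSystemMemoryInfo is a hypothetical stub that reports a fixed memory
// configuration instead of querying the GPU and host.
type fakeSystemMemoryInfo struct {
	total inference.RequiredMemory
}

func (f *fakeSystemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) bool {
	return req.RAM <= f.total.RAM && req.VRAM <= f.total.VRAM
}

func (f *fakeSystemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
	return f.total
}

A loader test could then pass &fakeSystemMemoryInfo{total: inference.RequiredMemory{RAM: 8 << 30, VRAM: 4 << 30}} into newLoader and assert its decisions deterministically.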

View File

@@ -10,12 +10,11 @@ import (
 	"time"
 
 	"github.com/docker/model-runner/pkg/environment"
-	"github.com/docker/model-runner/pkg/gpuinfo"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/metrics"
-	"github.com/elastic/go-sysinfo"
 )
 
 const (
@@ -113,7 +112,7 @@ func newLoader(
 	backends map[string]inference.Backend,
 	modelManager *models.Manager,
 	openAIRecorder *metrics.OpenAIRecorder,
-	gpuInfo *gpuinfo.GPUInfo,
+	sysMemInfo memory.SystemMemoryInfo,
 ) *loader {
 	// Compute the number of runner slots to allocate. Because of RAM and VRAM
 	// limitations, it's unlikely that we'll ever be able to fully populate
@@ -135,32 +134,7 @@
 	}
 	// Compute the amount of available memory.
 	// TODO(p1-0tr): improve error handling
-	vramSize, err := gpuInfo.GetVRAMSize()
-	if err != nil {
-		vramSize = 1
-		log.Warnf("Could not read VRAM size: %s", err)
-	} else {
-		log.Infof("Running on system with %dMB VRAM", vramSize/1024/1024)
-	}
-	ramSize := uint64(1)
-	hostInfo, err := sysinfo.Host()
-	if err != nil {
-		log.Warnf("Could not read host info: %s", err)
-	} else {
-		ram, err := hostInfo.Memory()
-		if err != nil {
-			log.Warnf("Could not read host RAM size: %s", err)
-		} else {
-			ramSize = ram.Total
-			log.Infof("Running on system with %dMB RAM", ramSize/1024/1024)
-		}
-	}
-	totalMemory := inference.RequiredMemory{
-		RAM: ramSize,
-		VRAM: vramSize,
-	}
+	totalMemory := sysMemInfo.GetTotalMemory()
 
 	// Create the loader.
 	l := &loader{
@@ -434,7 +408,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 	} else if err != nil {
 		return nil, err
 	}
-	l.log.Infof("Loading %s, which will require %dMB RAM and %dMB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024)
+	l.log.Infof("Loading %s, which will require %d MB RAM and %d MB VRAM on a system with %d MB RAM and %d MB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024, l.totalMemory.RAM/1024/1024, l.totalMemory.VRAM/1024/1024)
 	if l.totalMemory.RAM == 1 {
 		l.log.Warnf("RAM size unknown. Assume model will fit, but only one.")
 		memory.RAM = 1
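
The block deleted above duplicates what the memory package's existing NewSystemMemoryInfo constructor presumably already does; its body is not part of this diff. A rough sketch of such a probe, reusing only the calls visible in the deleted code and assuming the constructor returns (SystemMemoryInfo, error):

func NewSystemMemoryInfo(log logging.Logger, gpuInfo *gpuinfo.GPUInfo) (SystemMemoryInfo, error) {
	// Fall back to 1 byte when a size cannot be determined, matching the
	// sentinel value the loader checks for (totalMemory.RAM == 1).
	vramSize, err := gpuInfo.GetVRAMSize()
	if err != nil {
		vramSize = 1
		log.Warnf("Could not read VRAM size: %s", err)
	}
	ramSize := uint64(1)
	if hostInfo, err := sysinfo.Host(); err != nil {
		log.Warnf("Could not read host info: %s", err)
	} else if ram, err := hostInfo.Memory(); err != nil {
		log.Warnf("Could not read host RAM size: %s", err)
	} else {
		ramSize = ram.Total
	}
	return &systemMemoryInfo{
		totalMemory: inference.RequiredMemory{RAM: ramSize, VRAM: vramSize},
	}, nil
}

Whatever the real constructor looks like, the point of the commit is that the probing happens in one place and the loader consumes only the interface.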

View File

@@ -13,8 +13,8 @@ import (
 	"time"
 
 	"github.com/docker/model-distribution/distribution"
-	"github.com/docker/model-runner/pkg/gpuinfo"
 	"github.com/docker/model-runner/pkg/inference"
+	"github.com/docker/model-runner/pkg/inference/memory"
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/metrics"
@@ -56,7 +56,7 @@ func NewScheduler(
 	httpClient *http.Client,
 	allowedOrigins []string,
 	tracker *metrics.Tracker,
-	gpuInfo *gpuinfo.GPUInfo,
+	sysMemInfo memory.SystemMemoryInfo,
 ) *Scheduler {
 	openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)
@@ -67,7 +67,7 @@
 		defaultBackend: defaultBackend,
 		modelManager: modelManager,
 		installer: newInstaller(log, backends, httpClient),
-		loader: newLoader(log, backends, modelManager, openAIRecorder, gpuInfo),
+		loader: newLoader(log, backends, modelManager, openAIRecorder, sysMemInfo),
 		router: http.NewServeMux(),
 		tracker: tracker,
 		openAIRecorder: openAIRecorder,
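
On the caller side, the change is just what gets threaded through: main builds the shared memory view once from the existing GPU handle and passes that, presumably into the NewScheduler call shown in the first hunk, in place of *gpuinfo.GPUInfo. A minimal sketch, assuming NewSystemMemoryInfo returns (SystemMemoryInfo, error) as its truncated signature above suggests; the error handling here is illustrative and not taken from this diff:

// Probe RAM and VRAM once, up front, instead of inside newLoader.
sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
if err != nil {
	log.Warnf("Could not initialize system memory info: %s", err)
}

That sysMemInfo value is what replaces gpuInfo in the constructor call in the first file, and NewScheduler simply forwards it to newLoader.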