Merge remote-tracking branch 'origin/main' into shards
commit 8d5f251df7
@@ -13,6 +13,9 @@ jobs:
      - name: Checkout code
        uses: actions/checkout@v4

+     - name: Verify vendor/ is not present
+       run: stat vendor && exit 1 || exit 0
+
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
@@ -5,3 +5,4 @@ model-runner.sock
models-store/
# Directory where we store the updated llama.cpp
updated-inference/
+vendor/
main.go (25 lines changed)
@@ -14,6 +14,7 @@ import (
    "github.com/docker/model-runner/pkg/inference"
    "github.com/docker/model-runner/pkg/inference/backends/llamacpp"
    "github.com/docker/model-runner/pkg/inference/config"
+   "github.com/docker/model-runner/pkg/inference/memory"
    "github.com/docker/model-runner/pkg/inference/models"
    "github.com/docker/model-runner/pkg/inference/scheduling"
    "github.com/docker/model-runner/pkg/metrics"
@@ -54,6 +55,20 @@ func main() {
        llamacpp.SetDesiredServerVersion(desiredSeverVersion)
    }

+   llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
+   if llamaServerPath == "" {
+       llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
+   }
+
+   gpuInfo := gpuinfo.New(llamaServerPath)
+
+   sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
+   if err != nil {
+       log.Fatalf("unable to initialize system memory info: %v", err)
+   }
+
+   memEstimator := memory.NewEstimator(sysMemInfo)
+
    modelManager := models.NewManager(
        log,
        models.ClientConfig{
@@ -61,13 +76,9 @@ func main() {
            Logger: log.WithFields(logrus.Fields{"component": "model-manager"}),
        },
        nil,
+       memEstimator,
    )
-
-   llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
-   if llamaServerPath == "" {
-       llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
-   }

    log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)

    // Create llama.cpp configuration from environment variables
@@ -90,7 +101,7 @@ func main() {
        log.Fatalf("unable to initialize %s backend: %v", llamacpp.Name, err)
    }

-   gpuInfo := gpuinfo.New(llamaServerPath)
+   memEstimator.SetDefaultBackend(llamaCppBackend)

    scheduler := scheduling.NewScheduler(
        log,
@@ -105,7 +116,7 @@ func main() {
        "",
        false,
        ),
-       gpuInfo,
+       sysMemInfo,
    )

    router := routing.NewNormalizedServeMux()
@@ -2,7 +2,6 @@ package inference

import (
    "context"
-   "errors"
    "net/http"
)

@@ -18,9 +17,13 @@ const (
    BackendModeEmbedding
)

-var (
-   ErrGGUFParse = errors.New("failed to parse GGUF file")
-)
+type ErrGGUFParse struct {
+   Err error
+}
+
+func (e *ErrGGUFParse) Error() string {
+   return "failed to parse GGUF: " + e.Err.Error()
+}

// String implements Stringer.String for BackendMode.
func (m BackendMode) String() string {
@@ -88,5 +91,5 @@ type Backend interface {
    GetDiskUsage() (int64, error)
    // GetRequiredMemoryForModel returns the required working memory for a given
    // model.
-   GetRequiredMemoryForModel(model string, config *BackendConfiguration) (*RequiredMemory, error)
+   GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (*RequiredMemory, error)
}
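Note (illustration, not part of the commit): with ErrGGUFParse now a wrapper type rather than a sentinel value, callers are expected to match it with errors.As and can inspect the wrapped cause, which is what the loader change later in this diff does. A minimal sketch:

package main

import (
    "errors"
    "fmt"

    "github.com/docker/model-runner/pkg/inference"
)

// classify shows the errors.As pattern the loader adopts for the new error type.
func classify(err error) {
    var parseErr *inference.ErrGGUFParse
    switch {
    case err == nil:
        fmt.Println("estimate succeeded")
    case errors.As(err, &parseErr):
        // Unlike the old sentinel, the wrapped cause is preserved.
        fmt.Println("GGUF could not be parsed:", parseErr.Err)
    default:
        fmt.Println("other failure:", err)
    }
}

func main() {
    classify(&inference.ErrGGUFParse{Err: errors.New("unknown file magic")})
}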
@@ -15,6 +15,8 @@ import (
    "runtime"
    "strings"

+   "github.com/docker/model-distribution/types"
+   v1 "github.com/google/go-containerregistry/pkg/v1"
    parser "github.com/gpustack/gguf-parser-go"

    "github.com/docker/model-runner/pkg/diskusage"
@@ -223,23 +225,30 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
    return size, nil
}

-func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-   bundle, err := l.modelManager.GetBundle(model)
+func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+   var mdlGguf *parser.GGUFFile
+   var mdlConfig types.Config
+   inStore, err := l.modelManager.IsModelInStore(model)
    if err != nil {
-       return nil, fmt.Errorf("getting model(%s): %w", model, err)
+       return nil, fmt.Errorf("checking if model is in local store: %w", err)
    }
+   if inStore {
+       mdlGguf, mdlConfig, err = l.parseLocalModel(model)
+       if err != nil {
+           return nil, &inference.ErrGGUFParse{Err: err}
+       }
+   } else {
+       mdlGguf, mdlConfig, err = l.parseRemoteModel(ctx, model)
+       if err != nil {
+           return nil, &inference.ErrGGUFParse{Err: err}
+       }
+   }

-   mdlGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
-   if err != nil {
-       l.log.Warnf("Failed to parse gguf(%s): %s", bundle.GGUFPath(), err)
-       return nil, inference.ErrGGUFParse
-   }
-
-   contextSize := GetContextSize(bundle.RuntimeConfig(), config)
+   contextSize := GetContextSize(mdlConfig, config)

    ngl := uint64(0)
    if l.gpuSupported {
-       if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && strings.TrimSpace(mdlGGUF.Metadata().FileType.String()) != "Q4_0" {
+       if runtime.GOOS == "windows" && runtime.GOARCH == "arm64" && mdlConfig.Quantization != "Q4_0" {
            ngl = 0 // only Q4_0 models can be accelerated on Adreno
        }
        ngl = 100
@@ -248,7 +257,7 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
    // TODO(p1-0tr): for now assume we are running on GPU (single one) - Devices[1];
    // sum up weights + kv cache + context for an estimate of total GPU memory needed
    // while running inference with the given model
-   estimate := mdlGGUF.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
+   estimate := mdlGguf.EstimateLLaMACppRun(parser.WithLLaMACppContextSize(int32(contextSize)),
        // TODO(p1-0tr): add logic for resolving other param values, instead of hardcoding them
        parser.WithLLaMACppLogicalBatchSize(2048),
        parser.WithLLaMACppOffloadLayers(ngl))
@@ -270,6 +279,63 @@ func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.Bac
    }, nil
}

+func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
+   bundle, err := l.modelManager.GetBundle(model)
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
+   }
+   modelGGUF, err := parser.ParseGGUFFile(bundle.GGUFPath())
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", bundle.GGUFPath(), err)
+   }
+   return modelGGUF, bundle.RuntimeConfig(), nil
+}
+
+func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
+   mdl, err := l.modelManager.GetRemoteModel(ctx, model)
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("getting remote model(%s): %w", model, err)
+   }
+   layers, err := mdl.Layers()
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("getting layers of model(%s): %w", model, err)
+   }
+   var ggufDigest v1.Hash
+   for _, layer := range layers {
+       mt, err := layer.MediaType()
+       if err != nil {
+           return nil, types.Config{}, fmt.Errorf("getting media type of model(%s) layer: %w", model, err)
+       }
+       if mt == types.MediaTypeGGUF {
+           ggufDigest, err = layer.Digest()
+           if err != nil {
+               return nil, types.Config{}, fmt.Errorf("getting digest of GGUF layer for model(%s): %w", model, err)
+           }
+           break
+       }
+   }
+   if ggufDigest.String() == "" {
+       return nil, types.Config{}, fmt.Errorf("model(%s) has no GGUF layer", model)
+   }
+   blobURL, err := l.modelManager.GetRemoteModelBlobURL(model, ggufDigest)
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("getting GGUF blob URL for model(%s): %w", model, err)
+   }
+   tok, err := l.modelManager.BearerTokenForModel(ctx, model)
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("getting bearer token for model(%s): %w", model, err)
+   }
+   mdlGguf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(tok))
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("parsing GGUF for model(%s): %w", model, err)
+   }
+   config, err := mdl.Config()
+   if err != nil {
+       return nil, types.Config{}, fmt.Errorf("getting config for model(%s): %w", model, err)
+   }
+   return mdlGguf, config, nil
+}
+
func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool {
    binPath := l.vendoredServerStoragePath
    if l.updatedLlamaCpp {
@@ -63,6 +63,6 @@ func (m *mlx) GetDiskUsage() (int64, error) {
    return 0, nil
}

-func (m *mlx) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
    return nil, errors.New("not implemented")
}
@@ -63,6 +63,6 @@ func (v *vLLM) GetDiskUsage() (int64, error) {
    return 0, nil
}

-func (v *vLLM) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
    return nil, errors.New("not implemented")
}
@@ -0,0 +1,48 @@
+package memory
+
+import (
+   "context"
+   "errors"
+   "fmt"
+
+   "github.com/docker/model-runner/pkg/inference"
+)
+
+type MemoryEstimator interface {
+   SetDefaultBackend(MemoryEstimatorBackend)
+   GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
+   HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error)
+}
+
+type MemoryEstimatorBackend interface {
+   GetRequiredMemoryForModel(context.Context, string, *inference.BackendConfiguration) (*inference.RequiredMemory, error)
+}
+
+type memoryEstimator struct {
+   systemMemoryInfo SystemMemoryInfo
+   defaultBackend   MemoryEstimatorBackend
+}
+
+func NewEstimator(systemMemoryInfo SystemMemoryInfo) MemoryEstimator {
+   return &memoryEstimator{systemMemoryInfo: systemMemoryInfo}
+}
+
+func (m *memoryEstimator) SetDefaultBackend(backend MemoryEstimatorBackend) {
+   m.defaultBackend = backend
+}
+
+func (m *memoryEstimator) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+   if m.defaultBackend == nil {
+       return nil, errors.New("default backend not configured")
+   }
+
+   return m.defaultBackend.GetRequiredMemoryForModel(ctx, model, config)
+}
+
+func (m *memoryEstimator) HaveSufficientMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (bool, error) {
+   req, err := m.GetRequiredMemoryForModel(ctx, model, config)
+   if err != nil {
+       return false, fmt.Errorf("estimating required memory for model: %w", err)
+   }
+   return m.systemMemoryInfo.HaveSufficientMemory(*req), nil
+}
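Note (illustration, not part of the commit): a minimal sketch of how the new estimator is wired, pieced together from the calls this commit adds in main.go (memory.NewSystemMemoryInfo, memory.NewEstimator, SetDefaultBackend, HaveSufficientMemoryForModel). The stub backend, its memory figures, and the model name are placeholders; in the real wiring the default backend is the llama.cpp backend.

package main

import (
    "context"
    "fmt"

    "github.com/docker/model-runner/pkg/gpuinfo"
    "github.com/docker/model-runner/pkg/inference"
    "github.com/docker/model-runner/pkg/inference/memory"
    "github.com/sirupsen/logrus"
)

// stubBackend stands in for the llama.cpp backend; the figures are invented.
type stubBackend struct{}

func (stubBackend) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
    return &inference.RequiredMemory{RAM: 4 << 30, VRAM: 2 << 30}, nil
}

func main() {
    log := logrus.NewEntry(logrus.StandardLogger())

    // GPU info feeds the VRAM total; the path mirrors the default used in main.go.
    gpuInfo := gpuinfo.New("/Applications/Docker.app/Contents/Resources/model-runner/bin")
    sysMemInfo, err := memory.NewSystemMemoryInfo(log, gpuInfo)
    if err != nil {
        log.Fatalf("system memory info: %v", err)
    }

    // The estimator delegates per-model estimates to its default backend.
    est := memory.NewEstimator(sysMemInfo)
    est.SetDefaultBackend(stubBackend{})

    ok, err := est.HaveSufficientMemoryForModel(context.Background(), "ai/example-model", nil)
    if err != nil {
        log.Warnf("estimate failed: %s", err)
        return
    }
    fmt.Println("model fits in total system memory:", ok)
}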
@@ -0,0 +1,55 @@
+package memory
+
+import (
+   "github.com/docker/model-runner/pkg/gpuinfo"
+   "github.com/docker/model-runner/pkg/inference"
+   "github.com/docker/model-runner/pkg/logging"
+   "github.com/elastic/go-sysinfo"
+)
+
+type SystemMemoryInfo interface {
+   HaveSufficientMemory(inference.RequiredMemory) bool
+   GetTotalMemory() inference.RequiredMemory
+}
+
+type systemMemoryInfo struct {
+   log         logging.Logger
+   totalMemory inference.RequiredMemory
+}
+
+func NewSystemMemoryInfo(log logging.Logger, gpuInfo *gpuinfo.GPUInfo) (SystemMemoryInfo, error) {
+   // Compute the amount of available memory.
+   // TODO(p1-0tr): improve error handling
+   vramSize, err := gpuInfo.GetVRAMSize()
+   if err != nil {
+       vramSize = 1
+       log.Warnf("Could not read VRAM size: %s", err)
+   } else {
+       log.Infof("Running on system with %d MB VRAM", vramSize/1024/1024)
+   }
+   ramSize := uint64(1)
+   hostInfo, err := sysinfo.Host()
+   if err != nil {
+       log.Warnf("Could not read host info: %s", err)
+   } else {
+       ram, err := hostInfo.Memory()
+       if err != nil {
+           log.Warnf("Could not read host RAM size: %s", err)
+       } else {
+           ramSize = ram.Total
+           log.Infof("Running on system with %d MB RAM", ramSize/1024/1024)
+       }
+   }
+   return &systemMemoryInfo{
+       log:         log,
+       totalMemory: inference.RequiredMemory{RAM: ramSize, VRAM: vramSize},
+   }, nil
+}
+
+func (s *systemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) bool {
+   return req.RAM <= s.totalMemory.RAM && req.VRAM <= s.totalMemory.VRAM
+}
+
+func (s *systemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
+   return s.totalMemory
+}
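Note (illustration, not part of the commit): based only on the code in this diff, an unreadable RAM or VRAM size falls back to a 1-byte sentinel, and HaveSufficientMemory is a plain component-wise comparison, so an unknown total fails any non-trivial requirement (the loader later special-cases totalMemory.RAM == 1 for this reason). The figures below are made up.

package main

import (
    "fmt"

    "github.com/docker/model-runner/pkg/inference"
)

func main() {
    // Totals as NewSystemMemoryInfo reports them when the GPU query fails:
    // VRAM falls back to the 1-byte sentinel while RAM is read normally.
    total := inference.RequiredMemory{RAM: 16 << 30, VRAM: 1}

    // A made-up estimate for a model wanting 4 GiB RAM and 2 GiB VRAM.
    need := inference.RequiredMemory{RAM: 4 << 30, VRAM: 2 << 30}

    // This is exactly the comparison HaveSufficientMemory performs.
    fits := need.RAM <= total.RAM && need.VRAM <= total.VRAM
    fmt.Println("fits:", fits) // false: the unknown (1-byte) VRAM total cannot satisfy 2 GiB
}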
@@ -14,6 +14,9 @@ import (
type ModelCreateRequest struct {
    // From is the name of the model to pull.
    From string `json:"from"`
+   // IgnoreRuntimeMemoryCheck indicates whether the server should check if it has sufficient
+   // memory to run the given model (assuming default configuration).
+   IgnoreRuntimeMemoryCheck bool `json:"ignore-runtime-memory-check,omitempty"`
}

// ToOpenAIList converts the model list to its OpenAI API representation. This function never
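Note (illustration, not part of the commit): a client-side sketch of the new field, posting to the /models/create route exercised in the tests below and handling the 507 Insufficient Storage response added in the manager. The base URL and model name are placeholders.

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
)

// modelCreateRequest mirrors the JSON shape of ModelCreateRequest above.
type modelCreateRequest struct {
    From                     string `json:"from"`
    IgnoreRuntimeMemoryCheck bool   `json:"ignore-runtime-memory-check,omitempty"`
}

func main() {
    // Leave IgnoreRuntimeMemoryCheck false so the server runs its memory estimate.
    body, _ := json.Marshal(modelCreateRequest{From: "ai/example-model"})

    resp, err := http.Post("http://localhost:12434/models/create", "application/json", bytes.NewReader(body))
    if err != nil {
        fmt.Println("request failed:", err)
        return
    }
    defer resp.Body.Close()

    if resp.StatusCode == http.StatusInsufficientStorage {
        // 507: the estimated runtime memory exceeds total system memory.
        msg, _ := io.ReadAll(resp.Body)
        fmt.Println("rejected:", string(msg))
        return
    }
    fmt.Println("status:", resp.Status)
}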
@@ -15,11 +15,12 @@ import (
    "github.com/docker/model-distribution/distribution"
    "github.com/docker/model-distribution/registry"
    "github.com/docker/model-distribution/types"
-   "github.com/sirupsen/logrus"

    "github.com/docker/model-runner/pkg/diskusage"
    "github.com/docker/model-runner/pkg/inference"
+   "github.com/docker/model-runner/pkg/inference/memory"
    "github.com/docker/model-runner/pkg/logging"
+   v1 "github.com/google/go-containerregistry/pkg/v1"
    "github.com/sirupsen/logrus"
)

const (
@@ -43,6 +44,8 @@ type Manager struct {
    registryClient *registry.Client
    // lock is used to synchronize access to the models manager's router.
    lock sync.RWMutex
+   // memoryEstimator is used to calculate runtime memory requirements for models.
+   memoryEstimator memory.MemoryEstimator
}

type ClientConfig struct {
@@ -57,7 +60,7 @@ type ClientConfig struct {
}

// NewManager creates a new model's manager.
-func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Manager {
+func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string, memoryEstimator memory.MemoryEstimator) *Manager {
    // Create the model distribution client.
    distributionClient, err := distribution.NewClient(
        distribution.WithStoreRootPath(c.StoreRootPath),
@@ -84,6 +87,7 @@ func NewManager(log logging.Logger, c ClientConfig, allowedOrigins []string) *Ma
        router:             http.NewServeMux(),
        distributionClient: distributionClient,
        registryClient:     registryClient,
+       memoryEstimator:    memoryEstimator,
    }

    // Register routes.
@@ -164,6 +168,20 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {

    // Pull the model. In the future, we may support additional operations here
    // besides pulling (such as model building).
+   if !request.IgnoreRuntimeMemoryCheck {
+       m.log.Infof("Will estimate memory required for %q", request.From)
+       proceed, err := m.memoryEstimator.HaveSufficientMemoryForModel(r.Context(), request.From, nil)
+       if err != nil {
+           m.log.Warnf("Failed to calculate memory required for model %q: %s", request.From, err)
+           // Prefer staying functional in case of unexpected estimation errors.
+           proceed = true
+       }
+       if !proceed {
+           m.log.Warnf("Runtime memory requirement for model %q exceeds total system memory", request.From)
+           http.Error(w, "Runtime memory requirement for model exceeds total system memory", http.StatusInsufficientStorage)
+           return
+       }
+   }
    if err := m.PullModel(request.From, r, w); err != nil {
        if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
            m.log.Infof("Request canceled/timed out while pulling model %q", request.From)
@@ -563,6 +581,11 @@ func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    m.router.ServeHTTP(w, r)
}

+// IsModelInStore checks if a given model is in the local store.
+func (m *Manager) IsModelInStore(ref string) (bool, error) {
+   return m.distributionClient.IsModelInStore(ref)
+}
+
// GetModel returns a single model.
func (m *Manager) GetModel(ref string) (types.Model, error) {
    model, err := m.distributionClient.GetModel(ref)
@@ -572,6 +595,33 @@ func (m *Manager) GetModel(ref string) (types.Model, error) {
    return model, err
}

+// GetRemoteModel returns a single remote model.
+func (m *Manager) GetRemoteModel(ctx context.Context, ref string) (types.ModelArtifact, error) {
+   model, err := m.registryClient.Model(ctx, ref)
+   if err != nil {
+       return nil, fmt.Errorf("error while getting remote model: %w", err)
+   }
+   return model, nil
+}
+
+// GetRemoteModelBlobURL returns the URL of a given model blob.
+func (m *Manager) GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error) {
+   blobURL, err := m.registryClient.BlobURL(ref, digest)
+   if err != nil {
+       return "", fmt.Errorf("error while getting remote model blob URL: %w", err)
+   }
+   return blobURL, nil
+}
+
+// BearerTokenForModel returns the bearer token needed to pull a given model.
+func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string, error) {
+   tok, err := m.registryClient.BearerToken(ctx, ref)
+   if err != nil {
+       return "", fmt.Errorf("error while getting bearer token for model: %w", err)
+   }
+   return tok, nil
+}
+
// GetBundle returns model bundle.
func (m *Manager) GetBundle(ref string) (types.ModelBundle, error) {
    bundle, err := m.distributionClient.GetBundle(ref)
@@ -16,10 +16,23 @@ import (
    "github.com/docker/model-distribution/builder"
    reg "github.com/docker/model-distribution/registry"
    "github.com/docker/model-runner/pkg/inference"
+   "github.com/docker/model-runner/pkg/inference/memory"

    "github.com/sirupsen/logrus"
)

+type mockMemoryEstimator struct{}
+
+func (me *mockMemoryEstimator) SetDefaultBackend(_ memory.MemoryEstimatorBackend) {}
+
+func (me *mockMemoryEstimator) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+   return &inference.RequiredMemory{RAM: 0, VRAM: 0}, nil
+}
+
+func (me *mockMemoryEstimator) HaveSufficientMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (bool, error) {
+   return true, nil
+}
+
// getProjectRoot returns the absolute path to the project root directory
func getProjectRoot(t *testing.T) string {
    // Start from the current test file's directory
@@ -109,10 +122,11 @@ func TestPullModel(t *testing.T) {
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            log := logrus.NewEntry(logrus.StandardLogger())
+           memEstimator := &mockMemoryEstimator{}
            m := NewManager(log, ClientConfig{
                StoreRootPath: tempDir,
                Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
-           }, nil)
+           }, nil, memEstimator)

            r := httptest.NewRequest("POST", "/models/create", strings.NewReader(`{"from": "`+tag+`"}`))
            if tt.acceptHeader != "" {
@@ -219,12 +233,13 @@ func TestHandleGetModel(t *testing.T) {
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            log := logrus.NewEntry(logrus.StandardLogger())
+           memEstimator := &mockMemoryEstimator{}
            m := NewManager(log, ClientConfig{
                StoreRootPath: tempDir,
                Logger:        log.WithFields(logrus.Fields{"component": "model-manager"}),
                Transport:     http.DefaultTransport,
                UserAgent:     "test-agent",
-           }, nil)
+           }, nil, memEstimator)

            // First pull the model if we're testing local access
            if !tt.remote && !strings.Contains(tt.modelName, "nonexistent") {
@@ -10,12 +10,11 @@ import (
    "time"

    "github.com/docker/model-runner/pkg/environment"
-   "github.com/docker/model-runner/pkg/gpuinfo"
    "github.com/docker/model-runner/pkg/inference"
+   "github.com/docker/model-runner/pkg/inference/memory"
    "github.com/docker/model-runner/pkg/inference/models"
    "github.com/docker/model-runner/pkg/logging"
    "github.com/docker/model-runner/pkg/metrics"
-   "github.com/elastic/go-sysinfo"
)

const (
@@ -113,7 +112,7 @@ func newLoader(
    backends map[string]inference.Backend,
    modelManager *models.Manager,
    openAIRecorder *metrics.OpenAIRecorder,
-   gpuInfo *gpuinfo.GPUInfo,
+   sysMemInfo memory.SystemMemoryInfo,
) *loader {
    // Compute the number of runner slots to allocate. Because of RAM and VRAM
    // limitations, it's unlikely that we'll ever be able to fully populate
@@ -135,32 +134,7 @@ func newLoader(
    }

-   // Compute the amount of available memory.
-   // TODO(p1-0tr): improve error handling
-   vramSize, err := gpuInfo.GetVRAMSize()
-   if err != nil {
-       vramSize = 1
-       log.Warnf("Could not read VRAM size: %s", err)
-   } else {
-       log.Infof("Running on system with %dMB VRAM", vramSize/1024/1024)
-   }
-   ramSize := uint64(1)
-   hostInfo, err := sysinfo.Host()
-   if err != nil {
-       log.Warnf("Could not read host info: %s", err)
-   } else {
-       ram, err := hostInfo.Memory()
-       if err != nil {
-           log.Warnf("Could not read host RAM size: %s", err)
-       } else {
-           ramSize = ram.Total
-           log.Infof("Running on system with %dMB RAM", ramSize/1024/1024)
-       }
-   }
-
-   totalMemory := inference.RequiredMemory{
-       RAM:  ramSize,
-       VRAM: vramSize,
-   }
+   totalMemory := sysMemInfo.GetTotalMemory()

    // Create the loader.
    l := &loader{
@@ -420,12 +394,13 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
    if rc, ok := l.runnerConfigs[runnerKey{backendName, modelID, mode}]; ok {
        runnerConfig = &rc
    }
-   memory, err := backend.GetRequiredMemoryForModel(modelID, runnerConfig)
-   if errors.Is(err, inference.ErrGGUFParse) {
+   memory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
+   var parseErr *inference.ErrGGUFParse
+   if errors.As(err, &parseErr) {
        // TODO(p1-0tr): For now override memory checks in case model can't be parsed
        // e.g. model is too new for gguf-parser-go to know. We should provide a cleaner
        // way to bypass these checks.
-       l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it.", modelID)
+       l.log.Warnf("Could not parse model(%s), memory checks will be ignored for it. Error: %s", modelID, parseErr)
        memory = &inference.RequiredMemory{
            RAM:  0,
            VRAM: 0,
@@ -433,7 +408,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
    } else if err != nil {
        return nil, err
    }
-   l.log.Infof("Loading %s, which will require %dMB RAM and %dMB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024)
+   l.log.Infof("Loading %s, which will require %d MB RAM and %d MB VRAM on a system with %d MB RAM and %d MB VRAM", modelID, memory.RAM/1024/1024, memory.VRAM/1024/1024, l.totalMemory.RAM/1024/1024, l.totalMemory.VRAM/1024/1024)
    if l.totalMemory.RAM == 1 {
        l.log.Warnf("RAM size unknown. Assume model will fit, but only one.")
        memory.RAM = 1
@@ -13,8 +13,8 @@ import (
    "time"

    "github.com/docker/model-distribution/distribution"
-   "github.com/docker/model-runner/pkg/gpuinfo"
    "github.com/docker/model-runner/pkg/inference"
+   "github.com/docker/model-runner/pkg/inference/memory"
    "github.com/docker/model-runner/pkg/inference/models"
    "github.com/docker/model-runner/pkg/logging"
    "github.com/docker/model-runner/pkg/metrics"
@@ -56,7 +56,7 @@ func NewScheduler(
    httpClient *http.Client,
    allowedOrigins []string,
    tracker *metrics.Tracker,
-   gpuInfo *gpuinfo.GPUInfo,
+   sysMemInfo memory.SystemMemoryInfo,
) *Scheduler {
    openAIRecorder := metrics.NewOpenAIRecorder(log.WithField("component", "openai-recorder"), modelManager)

@@ -67,7 +67,7 @@ func NewScheduler(
        defaultBackend: defaultBackend,
        modelManager:   modelManager,
        installer:      newInstaller(log, backends, httpClient),
-       loader:         newLoader(log, backends, modelManager, openAIRecorder, gpuInfo),
+       loader:         newLoader(log, backends, modelManager, openAIRecorder, sysMemInfo),
        router:         http.NewServeMux(),
        tracker:        tracker,
        openAIRecorder: openAIRecorder,