inference: Support memory estimation for remote models
Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
parent 59da65a365
commit 880818f741
@@ -88,5 +88,5 @@ type Backend interface {
 	GetDiskUsage() (int64, error)
 	// GetRequiredMemoryForModel returns the required working memory for a given
 	// model.
-	GetRequiredMemoryForModel(model string, config *BackendConfiguration) (*RequiredMemory, error)
+	GetRequiredMemoryForModel(ctx context.Context, model string, config *BackendConfiguration) (*RequiredMemory, error)
 }
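The signature change above threads a context.Context through memory estimation, so implementations that touch the network (see the llama.cpp backend below) can be bounded or cancelled by the caller. A minimal caller sketch, assuming a 30-second budget and a nil BackendConfiguration (both our choices for illustration, not part of this commit):

    package main

    import (
        "context"
        "time"

        "github.com/docker/model-runner/pkg/inference"
    )

    // estimate bounds the (possibly remote) metadata fetch with a timeout.
    func estimate(b inference.Backend, model string) (*inference.RequiredMemory, error) {
        ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
        defer cancel()
        return b.GetRequiredMemoryForModel(ctx, model, nil)
    }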
@@ -15,8 +15,10 @@ import (
 	"runtime"
 	"strings"
 
+	v1 "github.com/google/go-containerregistry/pkg/v1"
 	parser "github.com/gpustack/gguf-parser-go"
 
+	"github.com/docker/model-distribution/types"
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/config"
@@ -223,23 +225,23 @@ func (l *llamaCpp) GetDiskUsage() (int64, error) {
 	return size, nil
 }
 
-func (l *llamaCpp) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
-	mdl, err := l.modelManager.GetModel(model)
+func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+	var mdlGguf *parser.GGUFFile
+	var mdlConfig types.Config
+	inStore, err := l.modelManager.IsModelInStore(model)
 	if err != nil {
-		return nil, fmt.Errorf("getting model(%s): %w", model, err)
+		return nil, fmt.Errorf("checking if model is in local store: %w", err)
 	}
-	mdlPath, err := mdl.GGUFPath()
-	if err != nil {
-		return nil, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
-	}
-	mdlGguf, err := parser.ParseGGUFFile(mdlPath)
-	if err != nil {
-		l.log.Warnf("Failed to parse gguf(%s): %s", mdlPath, err)
-		return nil, inference.ErrGGUFParse
-	}
-	mdlConfig, err := mdl.Config()
-	if err != nil {
-		return nil, fmt.Errorf("accessing model(%s) config: %w", model, err)
+	if inStore {
+		mdlGguf, mdlConfig, err = l.parseLocalModel(model)
+		if err != nil {
+			return nil, fmt.Errorf("parsing local gguf: %w", err)
+		}
+	} else {
+		mdlGguf, mdlConfig, err = l.parseRemoteModel(ctx, model)
+		if err != nil {
+			return nil, fmt.Errorf("parsing remote model: %w", err)
+		}
 	}
 
 	contextSize := GetContextSize(&mdlConfig, config)
@@ -277,6 +279,71 @@ func (l *llamaCpp) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	}, nil
 }
 
+func (l *llamaCpp) parseLocalModel(model string) (*parser.GGUFFile, types.Config, error) {
+	mdl, err := l.modelManager.GetModel(model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting model(%s): %w", model, err)
+	}
+	mdlPath, err := mdl.GGUFPath()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting gguf path for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFile(mdlPath)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing gguf(%s): %w", mdlPath, err)
+	}
+	mdlConfig, err := mdl.Config()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("accessing model(%s) config: %w", model, err)
+	}
+	return mdlGguf, mdlConfig, nil
+}
+
+func (l *llamaCpp) parseRemoteModel(ctx context.Context, model string) (*parser.GGUFFile, types.Config, error) {
+	mdl, err := l.modelManager.GetRemoteModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting remote model(%s): %w", model, err)
+	}
+	layers, err := mdl.Layers()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting layers of model(%s): %w", model, err)
+	}
+	var ggufDigest v1.Hash
+	for _, layer := range layers {
+		mt, err := layer.MediaType()
+		if err != nil {
+			return nil, types.Config{}, fmt.Errorf("getting media type of model(%s) layer: %w", model, err)
+		}
+		if mt == types.MediaTypeGGUF {
+			ggufDigest, err = layer.Digest()
+			if err != nil {
+				return nil, types.Config{}, fmt.Errorf("getting digest of GGUF layer for model(%s): %w", model, err)
+			}
+			break
+		}
+	}
+	if ggufDigest.String() == "" {
+		return nil, types.Config{}, fmt.Errorf("model(%s) has no GGUF layer", model)
+	}
+	blobURL, err := l.modelManager.GetRemoteModelBlobURL(model, ggufDigest)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting GGUF blob URL for model(%s): %w", model, err)
+	}
+	tok, err := l.modelManager.BearerTokenForModel(ctx, model)
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting bearer token for model(%s): %w", model, err)
+	}
+	mdlGguf, err := parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(tok))
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("parsing GGUF for model(%s): %w", model, err)
+	}
+	config, err := mdl.Config()
+	if err != nil {
+		return nil, types.Config{}, fmt.Errorf("getting config for model(%s): %w", model, err)
+	}
+	return mdlGguf, config, nil
+}
+
 func (l *llamaCpp) checkGPUSupport(ctx context.Context) bool {
 	binPath := l.vendoredServerStoragePath
 	if l.updatedLlamaCpp {
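parseRemoteModel relies on gguf-parser-go's remote mode, which is designed to fetch only the GGUF metadata it needs from the blob URL rather than downloading the whole model. A standalone sketch of that call, assuming blobURL and token are obtained out of band, e.g. from the registry client as in the code above:

    package main

    import (
        "context"

        parser "github.com/gpustack/gguf-parser-go"
    )

    // remoteGGUF parses GGUF metadata straight from a registry blob URL;
    // UseBearerAuth attaches the registry token to the underlying requests.
    func remoteGGUF(ctx context.Context, blobURL, token string) (*parser.GGUFFile, error) {
        return parser.ParseGGUFFileRemote(ctx, blobURL, parser.UseBearerAuth(token))
    }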
@@ -63,6 +63,6 @@ func (m *mlx) GetDiskUsage() (int64, error) {
 	return 0, nil
 }
 
-func (m *mlx) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (m *mlx) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }
@@ -63,6 +63,6 @@ func (v *vLLM) GetDiskUsage() (int64, error) {
 	return 0, nil
 }
 
-func (v *vLLM) GetRequiredMemoryForModel(model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
+func (v *vLLM) GetRequiredMemoryForModel(ctx context.Context, model string, config *inference.BackendConfiguration) (*inference.RequiredMemory, error) {
 	return nil, errors.New("not implemented")
 }
@@ -18,6 +18,7 @@ import (
 	"github.com/docker/model-runner/pkg/diskusage"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/logging"
+	v1 "github.com/google/go-containerregistry/pkg/v1"
 	"github.com/sirupsen/logrus"
 )
 
@@ -562,6 +563,11 @@ func (m *Manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	m.router.ServeHTTP(w, r)
 }
 
+// IsModelInStore checks if a given model is in the local store.
+func (m *Manager) IsModelInStore(ref string) (bool, error) {
+	return m.distributionClient.IsModelInStore(ref)
+}
+
 // GetModel returns a single model.
 func (m *Manager) GetModel(ref string) (types.Model, error) {
 	model, err := m.distributionClient.GetModel(ref)
@@ -571,6 +577,33 @@ func (m *Manager) GetModel(ref string) (types.Model, error) {
 	return model, err
 }
 
+// GetRemoteModel returns a single remote model.
+func (m *Manager) GetRemoteModel(ctx context.Context, ref string) (types.ModelArtifact, error) {
+	model, err := m.registryClient.Model(ctx, ref)
+	if err != nil {
+		return nil, fmt.Errorf("error while getting remote model: %w", err)
+	}
+	return model, nil
+}
+
+// GetRemoteModelBlobURL returns the URL of a given model blob.
+func (m *Manager) GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error) {
+	blobURL, err := m.registryClient.BlobURL(ref, digest)
+	if err != nil {
+		return "", fmt.Errorf("error while getting remote model blob URL: %w", err)
+	}
+	return blobURL, nil
+}
+
+// BearerTokenForModel returns the bearer token needed to pull a given model.
+func (m *Manager) BearerTokenForModel(ctx context.Context, ref string) (string, error) {
+	tok, err := m.registryClient.BearerToken(ctx, ref)
+	if err != nil {
+		return "", fmt.Errorf("error while getting bearer token for model: %w", err)
+	}
+	return tok, nil
+}
+
 // GetModelPath returns the path to a model's files.
 func (m *Manager) GetModelPath(ref string) (string, error) {
 	model, err := m.GetModel(ref)
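Together with IsModelInStore above, these helpers give a backend everything it needs to estimate memory without pulling the model: a store check, the GGUF layer's blob URL, and a bearer token for the registry. A sketch of the local-versus-remote decision they enable (the modelSource interface is our grouping for illustration, not part of the commit):

    package main

    import (
        "context"
        "fmt"

        v1 "github.com/google/go-containerregistry/pkg/v1"
    )

    // modelSource is the subset of the Manager API used for estimation.
    type modelSource interface {
        IsModelInStore(ref string) (bool, error)
        GetRemoteModelBlobURL(ref string, digest v1.Hash) (string, error)
        BearerTokenForModel(ctx context.Context, ref string) (string, error)
    }

    // source reports where GGUF metadata for ref would be read from,
    // mirroring the branch in the llama.cpp GetRequiredMemoryForModel.
    func source(m modelSource, ref string) (string, error) {
        inStore, err := m.IsModelInStore(ref)
        if err != nil {
            return "", fmt.Errorf("checking local store: %w", err)
        }
        if inStore {
            return "local store", nil
        }
        return "remote registry", nil
    }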
@@ -420,7 +420,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 	if rc, ok := l.runnerConfigs[runnerKey{backendName, modelID, mode}]; ok {
 		runnerConfig = &rc
 	}
-	memory, err := backend.GetRequiredMemoryForModel(modelID, runnerConfig)
+	memory, err := backend.GetRequiredMemoryForModel(ctx, modelID, runnerConfig)
 	if errors.Is(err, inference.ErrGGUFParse) {
 		// TODO(p1-0tr): For now override memory checks in case model can't be parsed
 		// e.g. model is too new for gguf-parser-go to know. We should provide a cleaner