Add `pt_load_map_location` to allow loading to cuda (#16869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
Jerry Zhang 2025-05-01 23:23:42 -07:00 committed by GitHub
parent f192ca90e6
commit 109e15a335
6 changed files with 74 additions and 3 deletions

View File

@@ -3,6 +3,7 @@ import importlib.metadata
import importlib.util
import pytest
import torch
DTYPE = ["bfloat16"]
@@ -21,5 +22,30 @@ def test_pre_quantized_model(vllm_runner):
print(output)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.parametrize(
"pt_load_map_location",
[
"cuda:0",
# {"": "cuda"},
])
def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
pt_load_map_location):
"""
Test loading roberta-base model with no lm_head.
"""
torch._dynamo.reset()
model_name = "jerryzh168/opt-125m-int4wo"
with vllm_runner(model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location=pt_load_map_location) as llm:
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)
assert output
print(output)
if __name__ == "__main__":
pytest.main([__file__])
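Outside the test harness, the same scenario can be sketched with the offline LLM entrypoint (a hedged sketch: the model name and arguments are taken from the test above, and LLM keyword arguments are assumed to be forwarded to the engine args):

from vllm import LLM, SamplingParams

# Load the torchao int4 weight-only checkpoint directly onto GPU 0 instead
# of staging it on CPU first.
llm = LLM(model="jerryzh168/opt-125m-int4wo",
          quantization="torchao",
          dtype="bfloat16",
          pt_load_map_location="cuda:0")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)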

View File

@@ -5,7 +5,8 @@ from typing import Literal, Union
import pytest
from vllm.config import ModelConfig, PoolerConfig, config, get_field
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@@ -410,3 +411,16 @@ def test_generation_config_loading():
override_generation_config=override_generation_config)
assert model_config.get_diff_sampling_param() == override_generation_config
@pytest.mark.parametrize("pt_load_map_location", [
"cuda",
{
"": "cuda"
},
])
def test_load_config_pt_load_map_location(pt_load_map_location):
load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
config = VllmConfig(load_config=load_config)
assert config.load_config.pt_load_map_location == pt_load_map_location

View File

@@ -1564,6 +1564,16 @@ class LoadConfig:
use_tqdm_on_load: bool = True
"""Whether to enable tqdm for showing progress bar when loading model
weights."""
pt_load_map_location: Union[str, dict[str, str]] = "cpu"
"""
pt_load_map_location: the map location for loading pytorch checkpoint, to
support loading checkpoints can only be loaded on certain devices like
"cuda", this is equivalent to {"": "cuda"}. Another supported format is
mapping from different devices like from GPU 1 to GPU 0:
{"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
in dictionary needs to be double quoted for json parsing. For more details,
see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
"""
def compute_hash(self) -> str:
"""

View File

@@ -64,6 +64,13 @@ def optional_type(
return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val):
return str(val)
else:
return optional_type(json.loads)(val)
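A standalone re-implementation of the behaviour above, useful for seeing how the same CLI value ends up as either a string or a dict (a simplified sketch, not the vLLM helper itself: `optional_type(json.loads)` is replaced by a bare `json.loads`):

import json
import re

def parse_dict_or_str(val: str):
    # Values that look like a JSON object become dicts; anything else stays
    # a plain string.
    if not re.match("^{.*}$", val):
        return str(val)
    return json.loads(val)

print(parse_dict_or_str("cuda:0"))                # 'cuda:0'
print(parse_dict_or_str('{"cuda:1": "cuda:0"}'))  # {'cuda:1': 'cuda:0'}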
@deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
@@ -187,6 +194,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = human_readable_int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints,
dict) and (contains_type(type_hints, str) or any(
is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
@@ -371,6 +382,7 @@ class EngineArgs:
reasoning_parser: str = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@@ -491,6 +503,8 @@
type=str,
default=None,
help='Name or path of the QLoRA adapter.')
load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"])
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
@@ -883,12 +897,14 @@
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
pt_load_map_location=self.pt_load_map_location,
)
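End to end, the new engine argument is expected to flow into LoadConfig roughly as follows (a hedged sketch assuming the public EngineArgs/create_engine_config API; not part of this diff):

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m",
                  pt_load_map_location={"cuda:1": "cuda:0"})
vllm_config = args.create_engine_config()
# The mapping should survive untouched on the resulting LoadConfig.
assert vllm_config.load_config.pt_load_map_location == {"cuda:1": "cuda:0"}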
def create_speculative_config(
@@ -1513,7 +1529,7 @@ def _warn_or_fallback(feature_name: str) -> bool:
def human_readable_int(value):
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024

View File

@@ -384,6 +384,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
if current_platform.is_tpu():
@@ -890,6 +891,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving

View File

@@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator(
def pt_weights_iterator(
hf_weights_files: List[str],
use_tqdm_on_load: bool,
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
for bin_file in tqdm(
@@ -510,7 +511,9 @@ def pt_weights_iterator(
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu", weights_only=True)
state = torch.load(bin_file,
map_location=pt_load_map_location,
weights_only=True)
yield from state.items()
del state
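Reduced to a standalone sketch (hypothetical file list, tqdm and the vLLM plumbing omitted), the iterator boils down to:

import torch

def iter_pt_weights(files, map_location="cpu"):
    # Stream (name, tensor) pairs from .bin/.pt checkpoints, loading each
    # file onto the requested map location instead of always going via CPU.
    for path in files:
        state = torch.load(path, map_location=map_location,
                           weights_only=True)
        yield from state.items()
        del state

# for name, tensor in iter_pt_weights(["pytorch_model.bin"], "cuda:0"):
#     print(name, tensor.device)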