vllm/tests/models/multimodal/test_mapping.py

87 lines
3.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import pytest
import torch
import transformers
from transformers import AutoConfig, PreTrainedModel
from vllm.config import ModelConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.config import try_get_safetensors_metadata
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]:
    """Create placeholder weights named after a repo's safetensors checkpoint.

    The tensor names are read from the checkpoint's safetensors metadata
    (``weight_map``). Each name is paired with an empty tensor created on the
    ``meta`` device, so no real storage is allocated.

    Args:
        repo: Hugging Face repo id whose safetensors metadata is read.

    Returns:
        A list of ``(name, tensor)`` pairs, one per key of the checkpoint's
        ``weight_map``.
    """
    metadata = try_get_safetensors_metadata(repo)
    # NOTE(review): the ``try_`` helper presumably returns None when the
    # metadata cannot be retrieved; fail with a readable message rather than
    # an AttributeError on ``None.weight_map``.
    assert metadata is not None, (
        f"Could not retrieve safetensors metadata for {repo!r}")
    weight_names = list(metadata.weight_map.keys())
    # Build the pairs eagerly. Returning a generator from inside the ``with``
    # block would defer ``torch.empty`` until after the context has exited,
    # so the tensors would not be created on the meta device.
    with torch.device("meta"):
        return [(name, torch.empty(0)) for name in weight_names]
def create_model_dummy_weights(
    repo: str,
    model_arch: str,
) -> Iterable[tuple[str, torch.Tensor]]:
    """Create weights named after a Transformers model's own parameters.

    ``transformers.<model_arch>`` is instantiated from the repo's config on the
    ``meta`` device. The resulting parameter names therefore follow the
    Transformers in-memory layout (i.e. after its checkpoint name conversion),
    and no real storage is allocated.

    Args:
        repo: Hugging Face repo id whose config is loaded.
        model_arch: Name of a model class exported by ``transformers``.

    Returns:
        The ``named_parameters()`` iterator of the meta-device model.
    """
    # ``getattr`` yields the class object itself, not an instance.
    model_cls: type[PreTrainedModel] = getattr(transformers, model_arch)
    config = AutoConfig.from_pretrained(repo)
    with torch.device("meta"):
        model: PreTrainedModel = model_cls._from_config(config)
    return model.named_parameters()
def model_architectures_for_test() -> list[str]:
    """Select the multimodal example architectures to run the mapping test on.

    An architecture is selected when all of the following hold:
    - its registry entry does not need ``trust_remote_code``;
    - ``transformers`` exports a class of that name;
    - that class defines a non-empty ``_checkpoint_conversion_mapping``.

    Returns:
        The selected architecture names, in registry iteration order.
    """
    arch_to_test: list[str] = []
    for model_arch, info in _MULTIMODAL_EXAMPLE_MODELS.items():
        # The test instantiates ``getattr(transformers, model_arch)``, so
        # skip archs unknown to transformers; remote-code models are
        # presumably not loadable this way either.
        if info.trust_remote_code or not hasattr(transformers, model_arch):
            continue
        model_cls: type[PreTrainedModel] = getattr(transformers, model_arch)
        # Only classes that rename checkpoint keys need the mapping check.
        if getattr(model_cls, "_checkpoint_conversion_mapping", None):
            arch_to_test.append(model_arch)
    return arch_to_test
@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", model_architectures_for_test())
def test_hf_model_weights_mapper(model_arch: str):
    """Check ``hf_to_vllm_mapper`` against both naming schemes of a model.

    The vLLM model's ``hf_to_vllm_mapper`` is applied to two sets of weights:
    - the names stored in the repo's safetensors checkpoint;
    - the names of the Transformers model's own parameters.

    Both sets must map to exactly the same vLLM names.
    """
    info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    info.check_available_online(on_fail="skip")
    info.check_transformers_version(on_fail="skip")

    repo_id = info.default
    config = ModelConfig(
        repo_id,
        task="auto",
        tokenizer=info.tokenizer or repo_id,
        tokenizer_mode=info.tokenizer_mode,
        trust_remote_code=info.trust_remote_code,
        seed=0,
        dtype="auto",
        revision=None,
        hf_overrides=info.hf_overrides,
    )
    vllm_model_cls = MULTIMODAL_REGISTRY._get_model_cls(config)

    # Names as stored on disk vs. names as Transformers exposes them.
    checkpoint_weights = create_repo_dummy_weights(repo_id)
    converted_weights = create_model_dummy_weights(repo_id, model_arch)

    mapper: WeightsMapper = vllm_model_cls.hf_to_vllm_mapper
    mapped_checkpoint = mapper.apply(checkpoint_weights)
    mapped_converted = mapper.apply(converted_weights)

    expected_names = {pair[0] for pair in mapped_checkpoint}
    actual_names = {pair[0] for pair in mapped_converted}

    weights_missing = expected_names - actual_names
    weights_unmapped = actual_names - expected_names
    assert not (weights_missing or weights_unmapped), (
        f"Following weights are not mapped correctly: {weights_unmapped}, "
        f"Missing expected weights: {weights_missing}.")