# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections import OrderedDict
from typing import NamedTuple, Optional
from unittest.mock import patch

import pytest
from huggingface_hub.utils import HfHubHTTPError
from torch import nn

from vllm.lora.utils import (get_adapter_absolute_path,
                             parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.utils import WeightsMapper


class LoRANameParserTestConfig(NamedTuple):
    name: str
    module_name: str
    is_lora_a: bool
    is_bias: bool
    weights_mapper: Optional[WeightsMapper] = None


def test_parse_fine_tuned_lora_name_valid():
    fixture = [
        LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight",
                                 "lm_head", True, False),
        LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight",
                                 "lm_head", False, False),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_A",
            "model.embed_tokens",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.layers.9.mlp.down_proj",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.layers.9.mlp.down_proj",
            False,
            False,
        ),
        # Test with WeightsMapper
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
    ]
    for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
        assert (module_name, is_lora_a,
                is_bias) == parse_fine_tuned_lora_name(name, weights_mapper)


def test_parse_fine_tuned_lora_name_invalid():
    fixture = {
        "base_model.weight",
        "base_model.model.weight",
    }
    for name in fixture:
        with pytest.raises(ValueError, match="unsupported LoRA weight"):
            parse_fine_tuned_lora_name(name)


def test_replace_submodule():
    model = nn.Sequential(
        OrderedDict([
            ("dense1", nn.Linear(764, 100)),
            ("act1", nn.ReLU()),
            ("dense2", nn.Linear(100, 50)),
            (
                "seq1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", nn.Linear(100, 10)),
                        ("dense2", nn.Linear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", nn.Linear(50, 10)),
            ("outact", nn.Sigmoid()),
        ]))

    sigmoid = nn.Sigmoid()
    replace_submodule(model, "act1", sigmoid)
    assert dict(model.named_modules())["act1"] == sigmoid
    dense2 = nn.Linear(1, 5)
    replace_submodule(model, "seq1.dense2", dense2)
    assert dict(model.named_modules())["seq1.dense2"] == dense2


# Unit tests for get_adapter_absolute_path
@patch('os.path.isabs')
def test_get_adapter_absolute_path_absolute(mock_isabs):
    path = '/absolute/path/to/lora'
    mock_isabs.return_value = True
    assert get_adapter_absolute_path(path) == path


@patch('os.path.expanduser')
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
    # Path with ~ that needs to be expanded
    path = '~/relative/path/to/lora'
    absolute_path = '/home/user/relative/path/to/lora'
    mock_expanduser.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch('os.path.exists')
@patch('os.path.abspath')
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
    # Relative path that exists locally
    path = 'relative/path/to/lora'
    absolute_path = '/absolute/path/to/lora'
    mock_exist.return_value = True
    mock_abspath.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface(mock_exist,
                                               mock_snapshot_download):
    # Hugging Face model identifier
    path = 'org/repo'
    absolute_path = '/mock/snapshot/path'
    mock_exist.return_value = False
    mock_snapshot_download.return_value = absolute_path
    assert get_adapter_absolute_path(path) == absolute_path


@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
                                                     mock_snapshot_download):
    # Hugging Face model identifier with download error
    path = 'org/repo'
    mock_exist.return_value = False
    mock_snapshot_download.side_effect = HfHubHTTPError(
        "failed to query model info")
    assert get_adapter_absolute_path(path) == path