# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional
from unittest.mock import patch

import pytest
import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform

from .utils import DummyLoRAManager

TOLERANCES = {
    torch.float16: (5e-3, 5e-3),
    torch.float32: (5e-3, 5e-3),
    torch.bfloat16: (3e-2, 2e-2),
}

pytestmark = pytest.mark.skipif(
    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
    reason="Backend not supported")

DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])

# prefill stage(True) or decode stage(False)
STAGES = [True, False]

NUM_RANDOM_SEEDS = 6

VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


@pytest.fixture(autouse=True)
def clean_cache_reset_device(reset_default_device):
    # Release any memory we might be holding on to. CI runs OOMs otherwise.
    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
                                                _LORA_B_PTR_DICT)
    _LORA_B_PTR_DICT.clear()
    _LORA_A_PTR_DICT.clear()
    yield


@pytest.fixture(autouse=True)
def skip_cuda_with_stage_false(request):
    """
    On CUDA-like platforms the same kernels are used for the prefill and
    decode stages, so 'stage' is generally ignored and we only need to test
    once.
    """
    if current_platform.is_cuda_alike():
        try:
            if hasattr(request.node, "callspec") and hasattr(
                    request.node.callspec, "params"):
                params = request.node.callspec.params
                if "stage" in params and params["stage"] is False:
                    pytest.skip("Skip test when stage=False")
        except Exception:
            pass
    yield


def get_random_id_to_index(num_loras: int,
                           num_slots: int,
                           log: bool = True) -> list[Optional[int]]:
    """Creates a random lora_id_to_index mapping.

    Args:
        num_loras: The number of active loras in the mapping.
        num_slots: The number of slots in the mapping. Must be at least
            num_loras.
        log: Whether to log the output.
    """
    if num_loras > num_slots:
        raise ValueError(
            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
            "num_loras must be less than or equal to num_slots.")

    slots: list[Optional[int]] = [None] * num_slots
    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
        slots[slot_idx] = lora_id

    if log:
        print(f"Created lora_id_to_index mapping: {slots}.")

    return slots


def populate_loras(
    id_to_index: list[Optional[int]],
    layer: BaseLayerWithLoRA,
    layer_weights: torch.Tensor,
    generate_embeddings_tensor: int = 0,
    repeats: int = 1,
) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
    """This method populates the lora layers with lora weights.

    Args:
        id_to_index: a list of lora ids. The index of the lora id
            represents which memory slot the lora matrices are
            stored in. A None value indicates a free slot.
        layer: the LoRA layer to populate.
        layer_weights: the PyTorch tensor containing the layer's
            weights.
        generate_embeddings_tensor: whether to generate an
            embeddings tensor for each LoRA.
        repeats: must only be set for column parallel packed
            layers. Indicates the number of loras to compose
            together to create a single lora layer.
    """

    # Dictionary that maps the lora ID to the
    # corresponding lora weights.
    lora_dict: dict[int, LoRALayerWeights] = dict()

    # Dictionary that maps the lora ID to the
    # corresponding subloras.
    sublora_dict: dict[int, list[LoRALayerWeights]] = dict()

    for slot_idx, lora_id in enumerate(id_to_index):
        if lora_id is not None:
            subloras: list[LoRALayerWeights] = []
            sublora_len = layer_weights.shape[0] // repeats
            for i in range(repeats):
                sublora = DummyLoRAManager(
                    layer_weights.device).init_random_lora(
                        module_name=f"fake_{i}",
                        weight=layer_weights,
                        generate_embeddings_tensor=generate_embeddings_tensor,
                    )
                sublora.lora_b = sublora.lora_b[:, (sublora_len *
                                                    i):(sublora_len * (i + 1))]
                sublora.optimize()
                subloras.append(sublora)

            lora = PackedLoRALayerWeights.pack(
                subloras) if repeats > 1 else subloras[0]

            layer.set_lora(
                slot_idx,
                lora_a=lora.lora_a,
                lora_b=lora.lora_b,
                embeddings_tensor=lora.embeddings_tensor,
            )

            lora_dict[lora_id] = lora
            sublora_dict[lora_id] = subloras

    return lora_dict, sublora_dict


def create_random_inputs(
    active_lora_ids: list[int],
    num_inputs: int,
    input_size: tuple[int, ...],
    input_range: tuple[float, float],
    input_type: torch.dtype = torch.int,
    device: torch.device = "cuda"
) -> tuple[list[torch.Tensor], list[int], list[int]]:
    """Creates random inputs.

    Args:
        active_lora_ids: lora IDs of active lora weights.
        num_inputs: the number of inputs to create.
        input_size: the size of each individual input.
        input_range: the range of values to include in the input.
            input_range[0] <= possible input values < input_range[1]
        input_type: the type of values in the input.
        device: the device on which to allocate the inputs.
    """
    low, high = input_range

    inputs: list[torch.Tensor] = []
    index_mapping: list[int] = []
    prompt_mapping: list[int] = []

    for _ in range(num_inputs):
        if input_type == torch.int:
            inputs.append(
                torch.randint(low=int(low),
                              high=int(high),
                              size=input_size,
                              device=device))
        else:
            inputs.append(
                torch.rand(size=input_size, dtype=input_type, device=device) *
                high + low)

        lora_id = random.choice(active_lora_ids)
        index_mapping += [lora_id] * input_size[0]
        prompt_mapping += [lora_id]

    return inputs, index_mapping, prompt_mapping


def check_punica_wrapper(punica_wrapper) -> bool:
    if current_platform.is_cuda_alike():
        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

        return type(punica_wrapper) is PunicaWrapperGPU
    elif current_platform.is_cpu():
        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

        return type(punica_wrapper) is PunicaWrapperCPU
    else:
        return False


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
    # device, see: https://github.com/triton-lang/triton/issues/2925
    # Same below.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding.weight.data = torch.rand_like(embedding.weight.data)
        embedding.weight.data[vocab_size:, :] = 0
        lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        embedding, lora_embedding = create_random_embedding_layer()
        lora_embedding.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=embedding.weight.T,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = embedding(input_)
            after_a = F.embedding(
                input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(inputs))
        expected_result = embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
# @pytest.mark.skip(
#     reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                        vocab_size, stage) -> None:
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_embedding_layer():
        embedding = VocabParallelEmbedding(vocab_size, 256)
        embedding_data = torch.rand_like(embedding.weight.data)
        embedding.weight.data = embedding_data
        embedding.weight.data[vocab_size:, :] = 0
        expanded_embedding = VocabParallelEmbedding(
            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
            256,
            org_num_embeddings=vocab_size)
        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
        # We need to deepcopy the embedding as it will be modified
        # in place
        lora_embedding = VocabParallelEmbeddingWithLoRA(
            deepcopy(expanded_embedding))
        lora_embedding.create_lora_weights(max_loras, lora_config)

        return expanded_embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        expanded_embedding, lora_embedding = create_random_embedding_layer()
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_embedding,
            layer_weights=torch.zeros(
                (256, vocab_size + lora_config.lora_extra_vocab_size)),
            generate_embeddings_tensor=256,
        )
        lora_embedding.set_mapping(punica_wrapper)
        # All embeddings tensors have the same shape.
        embeddings_tensors = [
            lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
        ]
        embeddings_tensor_len = embeddings_tensors[0].shape[0]

        # Add empty embeddings_tensors for unoccupied lora slots.
        for _ in range(max_loras - len(embeddings_tensors)):
            embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        original_inputs = deepcopy(inputs)
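
        # Note (added for clarity, derived from the assignments below): in the
        # expanded embedding, the added rows owned by lora id `i` occupy
        # [vocab_size + (i - 1) * embeddings_tensor_len,
        #  vocab_size + i * embeddings_tensor_len), whereas each lora addresses
        # its own embeddings_tensor starting at vocab_size.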

        # Force some of the inputs to be in the extended embeddings range
        # to guarantee that their behavior is tested.
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            embedding_id = lora_id - 1
            input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
            original_input_[-1] = vocab_size
            input_[-2] = vocab_size + (
                (embedding_id + 1) * embeddings_tensor_len - 1)
            original_input_[-2] = vocab_size + embeddings_tensor_len - 1

        expanded_embedding.weight[vocab_size:vocab_size +
                                  (embeddings_tensor_len *
                                   max_loras)] = torch.cat(embeddings_tensors)

        lora_result = lora_embedding(torch.cat(original_inputs))

        expected_results: list[torch.Tensor] = []
        for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                    prompt_mapping):
            lora = lora_dict[lora_id]
            result = expanded_embedding(input_)
            after_a = F.embedding(
                original_input_,
                lora.lora_a,
            )
            result += (after_a @ lora.lora_b)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_embedding.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=num_loras * 3,
            input_size=(200, ),
            input_range=(1, vocab_size),
            device=device)
        original_inputs = deepcopy(inputs)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       vocab_size,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_embedding(torch.cat(original_inputs))
        expected_result = expanded_embedding(torch.cat(inputs))

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                  stage) -> None:
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def _pretest():
        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
                                1024,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        linear.weight.data[:, vocab_size:] = 0
        logits_processor = LogitsProcessor(
            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, 1024, linear.weight.dtype, linear.weight.device,
            None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)
        # NOTE: all the generated loras share the same embeddings tensor.
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=1024,
        )

        embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor
        embeddings_tensor_len = embeddings_tensor.shape[0]

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,  # * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )
        input_ = torch.rand(20, 1024)

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        original_lm_head = deepcopy(linear)

        linear.weight[logits_processor.
                      org_vocab_size:logits_processor.org_vocab_size +
                      embeddings_tensor_len] = embeddings_tensor

        logits_processor.org_vocab_size = (vocab_size +
                                           lora_config.lora_extra_vocab_size)
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
            # Reference LoRA contribution: x @ A @ B, scaled.
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)
        logits_processor.org_vocab_size = vocab_size

        # The LoRA-aware logits processor must match the reference path.
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras * 3,
            input_size=(1, 1024),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)[:, :vocab_size]
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=original_lm_head,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_replicated(dist_init, num_loras, device, stage,
                           bias_enabled) -> None:
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_random_linear_replicated_layer():
        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device, stage, bias_enabled) -> None:
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_random_linear_parallel_layer():
        if orientation == "row":
            linear = RowParallelLinear(4096,
                                       4096,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == 1)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_parallel_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
@pytest.mark.parametrize("bias_enabled", [True, False])
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device, stage, bias_enabled) -> None:
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16,
                             bias_enabled=bias_enabled)

    def create_column_parallel_packed_layer():
        if repeats == 2:
            linear = MergedColumnParallelLinear(4096, [4096] * repeats,
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = (MergedQKVParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLoRA(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
                                       32,
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = QKVParallelLinearWithLoRA(
                linear
            ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear)

        @dataclass
        class FakeConfig:
            hidden_size = 4096
            num_key_value_heads = 32
            num_attention_heads = 32

        n_slices = repeats
        lora_linear.create_lora_weights(max_loras,
                                        lora_config,
                                        model_config=FakeConfig())
        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
            lora_linear.lora_b_stacked) == n_slices)
        if bias_enabled:
            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
        else:
            assert lora_linear.lora_bias_stacked is None

        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_column_parallel_packed_layer()
        assert torch.equal(linear.weight, lora_linear.weight)
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, sublora_dict = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
            repeats=repeats,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            result = linear(input_)[0]
            subloras = sublora_dict[lora_id]
            for i, sublora in enumerate(subloras):
                result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
                       (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
                                    sublora.scaling)
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)
    added_vocab_size = random.randint(0, 1024)
    org_vocab_size = vocab_size - added_vocab_size
    last_org_vocab_end_index = 0
    last_added_vocab_end_index = org_vocab_size
    computed_vocab_size = 0
    computed_org_vocab_size = 0
    computed_added_vocab_size = 0
    vocab_size_padded = -1

    all_org_tokens: list[int] = []
    all_added_tokens: list[int] = []
    token_ids: list[int] = []

    for tp_rank in range(tp_size):
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
                return_value=tp_rank
        ), patch(
                "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
                return_value=tp_size):
            vocab_embedding = VocabParallelEmbedding(
                vocab_size, 1, org_num_embeddings=org_vocab_size)
            vocab_size_padded = vocab_embedding.num_embeddings_padded
            shard_indices = vocab_embedding.shard_indices
            # Assert that the ranges are contiguous
            assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
            assert (shard_indices.added_vocab_start_index ==
                    last_added_vocab_end_index)

            # Ensure that we are not exceeding the vocab size
            computed_vocab_size += shard_indices.num_elements_padded
            computed_org_vocab_size += shard_indices.num_org_elements
            computed_added_vocab_size += shard_indices.num_added_elements

            # Ensure that the ranges are not overlapping
            all_org_tokens.extend(
                range(shard_indices.org_vocab_start_index,
                      shard_indices.org_vocab_end_index))
            all_added_tokens.extend(
                range(shard_indices.added_vocab_start_index,
                      shard_indices.added_vocab_end_index))

            token_ids.extend(
                range(shard_indices.org_vocab_start_index,
                      shard_indices.org_vocab_end_index))
            token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
                                     shard_indices.num_org_elements))
            token_ids.extend(
                range(shard_indices.added_vocab_start_index,
                      shard_indices.added_vocab_end_index))
            token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
                                     shard_indices.num_added_elements))

            last_org_vocab_end_index = shard_indices.org_vocab_end_index
            last_added_vocab_end_index = shard_indices.added_vocab_end_index

    assert computed_vocab_size == vocab_size_padded
    assert computed_org_vocab_size == org_vocab_size
    assert computed_added_vocab_size == added_vocab_size

    # Ensure that the ranges are not overlapping
    assert len(all_org_tokens) == len(set(all_org_tokens))
    assert len(all_added_tokens) == len(set(all_added_tokens))
    assert not set(all_org_tokens).intersection(set(all_added_tokens))

    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
    assert reindex_mapping is not None or tp_size == 1
    if reindex_mapping is not None:
        reindexed_token_ids = token_ids_tensor[reindex_mapping]
        expected = torch.tensor(list(range(0, vocab_size)))
        assert reindexed_token_ids[:vocab_size].equal(expected)
        assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

    # base tp 1 case, no padding
    modified_x, _ = get_masked_input_and_mask(x,
                                              org_vocab_start_index=0,
                                              org_vocab_end_index=8,
                                              added_vocab_start_index=8,
                                              added_vocab_end_index=12,
                                              num_org_vocab_padding=0)
    assert torch.equal(x, modified_x)

    # tp 2 case, no padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=4,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=10,
                                                     num_org_vocab_padding=0)
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=8,
        added_vocab_start_index=10,
        added_vocab_end_index=12,
        num_org_vocab_padding=0)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

    # tp 4 case, no padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=2,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=9,
                                                     num_org_vocab_padding=0)
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=2,
        org_vocab_end_index=4,
        added_vocab_start_index=9,
        added_vocab_end_index=10,
        num_org_vocab_padding=0)
    modified_x_rank_2, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=6,
        added_vocab_start_index=10,
        added_vocab_end_index=11,
        num_org_vocab_padding=0)
    modified_x_rank_3, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=6,
        org_vocab_end_index=8,
        added_vocab_start_index=11,
        added_vocab_end_index=12,
        num_org_vocab_padding=0)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
    assert torch.equal(modified_x_rank_2,
                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
    assert torch.equal(modified_x_rank_3,
                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

    # base tp 1 case, with padding
    modified_x, _ = get_masked_input_and_mask(x,
                                              org_vocab_start_index=0,
                                              org_vocab_end_index=8,
                                              added_vocab_start_index=8,
                                              added_vocab_end_index=12,
                                              num_org_vocab_padding=2)
    assert torch.equal(modified_x,
                       torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

    # tp 2 case, with padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=4,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=10,
                                                     num_org_vocab_padding=2)
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=8,
        added_vocab_start_index=10,
        added_vocab_end_index=12,
        num_org_vocab_padding=2)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

    # tp 4 case, with padding
    modified_x_rank_0, _ = get_masked_input_and_mask(x,
                                                     org_vocab_start_index=0,
                                                     org_vocab_end_index=2,
                                                     added_vocab_start_index=8,
                                                     added_vocab_end_index=9,
                                                     num_org_vocab_padding=2)
    modified_x_rank_1, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=2,
        org_vocab_end_index=4,
        added_vocab_start_index=9,
        added_vocab_end_index=10,
        num_org_vocab_padding=2)
    modified_x_rank_2, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=4,
        org_vocab_end_index=6,
        added_vocab_start_index=10,
        added_vocab_end_index=11,
        num_org_vocab_padding=2)
    modified_x_rank_3, _ = get_masked_input_and_mask(
        x,
        org_vocab_start_index=6,
        org_vocab_end_index=8,
        added_vocab_start_index=11,
        added_vocab_end_index=12,
        num_org_vocab_padding=2)
    assert torch.equal(modified_x_rank_0,
                       torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
    assert torch.equal(modified_x_rank_1,
                       torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
    assert torch.equal(modified_x_rank_2,
                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
    assert torch.equal(modified_x_rank_3,
                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
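

# Reader note (illustrative arithmetic, mirroring the asserts above): for a
# shard owning org vocab [org_start, org_end) and added vocab
# [added_start, added_end), get_masked_input_and_mask remaps in-shard ids as
#   org token   t -> t - org_start
#   added token t -> (org_end - org_start) + num_org_vocab_padding
#                    + (t - added_start)
# and maps out-of-shard ids to 0 (they are masked out). For example, in the
# "tp 2 case, with padding" above, rank 0 maps token 8 -> (4 - 0) + 2 + 0 = 6.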