# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import multiprocessing import os import random import pytest import torch import torch.distributed from vllm.distributed.eplb.rebalance_execute import ( rearrange_expert_weights_inplace) from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, get_tp_group, init_distributed_environment) from vllm.utils import update_environment_variables def distributed_run(fn, world_size): number_of_processes = world_size processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): env: dict[str, str] = {} env['RANK'] = str(i) env['LOCAL_RANK'] = str(i) env['WORLD_SIZE'] = str(number_of_processes) env['LOCAL_WORLD_SIZE'] = str(number_of_processes) env['MASTER_ADDR'] = 'localhost' env['MASTER_PORT'] = '12345' p = multiprocessing.Process(target=fn, args=(env, )) processes.append(p) p.start() for p in processes: p.join() for p in processes: assert p.exitcode == 0 def worker_fn_wrapper(fn): # `multiprocessing.Process` cannot accept environment variables directly # so we need to pass the environment variables as arguments # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) local_rank = os.environ['LOCAL_RANK'] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) init_distributed_environment() # Ensure each worker process has the same random seed random.seed(42) torch.manual_seed(42) fn() return wrapped_fn def create_expert_indices_with_redundancy( num_layers: int, num_logical_experts: int, total_physical_experts: int, redundancy_config: list[int], # redundancy for each logical expert ) -> torch.Tensor: """ Create expert indices with redundancy. Args: num_layers: number of layers num_logical_experts: number of logical experts total_physical_experts: total number of physical experts redundancy_config: redundancy for each logical expert Returns: indices: Shape (num_layers, total_physical_experts) """ assert sum(redundancy_config) == total_physical_experts assert len(redundancy_config) == num_logical_experts indices = torch.zeros(num_layers, total_physical_experts, dtype=torch.long) for layer in range(num_layers): physical_pos = 0 for logical_expert_id, redundancy in enumerate(redundancy_config): for _ in range(redundancy): indices[layer, physical_pos] = logical_expert_id physical_pos += 1 # Shuffle the indices at dim 1 for layer in range(num_layers): indices[layer] = indices[layer][torch.randperm(indices.shape[1])] return indices def create_expert_weights( num_layers: int, num_local_experts: int, hidden_sizes: list[int], rank: int, device: torch.device, physical_to_logical_mapping: torch.Tensor, ) -> list[list[torch.Tensor]]: """ Create fake expert weights tensor for testing. Use `arange` to generate predictable weights values, based on logical expert ID. All replicas of the same logical expert should have the same weights. Args: physical_to_logical_mapping: Shape (num_layers, num_local_experts) mapping[layer, physical_pos] = logical_expert_id """ expert_weights = [] for layer in range(num_layers): layer_weights = [] for weight_idx, hidden_size in enumerate(hidden_sizes): weight_tensor = torch.zeros(num_local_experts, hidden_size, device=device, dtype=torch.float32) for local_expert in range(num_local_experts): # Get the logical expert ID for this physical expert global_pos = rank * num_local_experts + local_expert logical_expert_id = physical_to_logical_mapping[ layer, global_pos].item() # Generate weights based on logical expert ID # (so that all replicas of the same logical expert have the # same weights) base_value = (logical_expert_id * 1000 + layer * 100 + weight_idx * 10) weight_tensor[local_expert] = torch.arange(base_value, base_value + hidden_size, device=device, dtype=torch.float32) layer_weights.append(weight_tensor) expert_weights.append(layer_weights) return expert_weights def create_redundancy_config( num_logical_experts: int, num_physical_experts: int, ) -> list[int]: """Create a redundancy configuration.""" redundancy_config = [1] * num_logical_experts remaining = num_physical_experts - num_logical_experts # Randomly assign the remaining physical experts to the logical experts for _ in range(remaining): redundancy_config[random.choice(range(num_logical_experts))] += 1 return redundancy_config def verify_expert_weights_after_shuffle( expert_weights: list[list[torch.Tensor]], new_indices: torch.Tensor, hidden_sizes: list[int], ep_rank: int, num_local_experts: int, ): """Verify the weights after shuffling are correct.""" num_layers = len(expert_weights) for layer in range(num_layers): for weight_idx, hidden_size in enumerate(hidden_sizes): weight_tensor = expert_weights[layer][weight_idx] for local_expert in range(num_local_experts): # Calculate the global expert ID for this local expert global_pos = ep_rank * num_local_experts + local_expert expected_logical_expert = new_indices[layer, global_pos].item() # Check if the weights are correct actual_weights = weight_tensor[local_expert] expected_base = (expected_logical_expert * 1000 + layer * 100 + weight_idx * 10) expected_weights = torch.arange(expected_base, expected_base + hidden_size, device=actual_weights.device, dtype=actual_weights.dtype) torch.testing.assert_close( actual_weights, expected_weights, msg=f"Layer {layer}, weight {weight_idx}," f"local expert {local_expert}: " f"weights do not match. " f"Expected logical expert {expected_logical_expert}") def verify_redundant_experts_have_same_weights( expert_weights: list[list[torch.Tensor]], indices: torch.Tensor, hidden_sizes: list[int], world_size: int, num_local_experts: int, ): """ Verify that all replicas of the same logical expert have the same weights. """ num_layers = len(expert_weights) total_physical_experts = world_size * num_local_experts for layer in range(num_layers): # Collect weights for all physical experts for each weight matrix all_weights: list[torch.Tensor] = [] for weight_idx, hidden_size in enumerate(hidden_sizes): # Create tensor to store all expert weights # Shape: [total_physical_experts, hidden_size] gathered_weights = torch.zeros( total_physical_experts, hidden_size, device=expert_weights[layer][weight_idx].device, dtype=expert_weights[layer][weight_idx].dtype) # Use all_gather to collect expert weights from current node # expert_weights[layer][weight_idx] shape: # [num_local_experts, hidden_size] local_weights = expert_weights[layer][ weight_idx] # [num_local_experts, hidden_size] # Split tensor along dim 0 into a list for all_gather gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0) torch.distributed.all_gather( # Output list: each element corresponds to one rank's weights list(gathered_weights_list), local_weights # Input: current rank's local weights ) all_weights.append(gathered_weights) # Verify that all replicas of the same logical expert have the same # weights logical_expert_weights: dict[int, dict[int, torch.Tensor]] = {} for physical_pos in range(total_physical_experts): logical_expert_id = int(indices[layer, physical_pos].item()) if logical_expert_id not in logical_expert_weights: # First time encountering this logical expert, save its weights logical_expert_weights[logical_expert_id] = { weight_idx: all_weights[weight_idx][physical_pos] for weight_idx in range(len(hidden_sizes)) } else: # Verify that current physical expert's weights match the # previously saved logical expert weights for weight_idx in range(len(hidden_sizes)): torch.testing.assert_close( all_weights[weight_idx][physical_pos], logical_expert_weights[logical_expert_id][weight_idx], msg=f"Layer {layer}, weight {weight_idx}," f"logical expert {logical_expert_id}: " f"Physical expert {physical_pos} has different weights" f"than expected") @pytest.mark.parametrize( "world_size,num_layers,num_local_experts,num_logical_experts", [ # 2 GPU, 2 experts per GPU # 3 logical experts, 4 physical experts, 1 redundant experts (2, 1, 2, 3), # 2 GPU, 3 experts per GPU # 4 logical experts, 6 physical experts, 2 redundant experts (2, 2, 3, 4), # 2 GPU, 8 experts per GPU # 16 logical experts, 16 physical experts, 0 redundant experts (2, 4, 8, 16), # 4 GPU, 2 experts per GPU # 6 logical experts, 8 physical experts, 2 redundant experts (4, 1, 2, 6), # 4 GPU, 2 experts per GPU # 5 logical experts, 8 physical experts, 3 redundant experts (4, 2, 2, 5), # 4 GPU, 8 experts per GPU # 16 logical experts, 32 physical experts, 16 redundant experts (4, 8, 8, 16), ]) def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, num_local_experts, num_logical_experts): """Test the functionality of rearranging expert weights with redundancy.""" if torch.cuda.device_count() < world_size: pytest.skip(f"Need at least {world_size} GPUs to run the test") @worker_fn_wrapper def worker_fn(): # Initialize model parallel (using tensor parallel as an entrypoint # to expert parallel) ensure_model_parallel_initialized( tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() device = torch.device(f"cuda:{ep_rank}") # Test parameters total_physical_experts = world_size * num_local_experts hidden_sizes = [32, 64] # Two different weight matrices # Create old expert indices (with redundancy) redundancy_config = create_redundancy_config(num_logical_experts, total_physical_experts) old_indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, total_physical_experts, redundancy_config, ) # Create new expert indices (with redundancy) new_redundancy_config = create_redundancy_config( num_logical_experts, total_physical_experts) new_indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, total_physical_experts, new_redundancy_config, ) # Create expert weights expert_weights = create_expert_weights(num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices) # Execute weight rearrangement rearrange_expert_weights_inplace( old_indices, new_indices, expert_weights, ep_group, is_profile=False, ) # Verify the rearrangement result verify_expert_weights_after_shuffle( expert_weights, new_indices, hidden_sizes, ep_rank, num_local_experts, ) verify_redundant_experts_have_same_weights( expert_weights, new_indices, hidden_sizes, world_size, num_local_experts, ) distributed_run(worker_fn, world_size) @pytest.mark.parametrize("world_size", [2, 4]) def test_rearrange_expert_weights_no_change(world_size): """ Test that when the indices do not change, the weights should remain unchanged. """ if torch.cuda.device_count() < world_size: pytest.skip(f"Need at least {world_size} GPUs to run the test") @worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() device = torch.device(f"cuda:{ep_rank}") num_layers = 2 num_local_experts = 2 total_physical_experts = world_size * num_local_experts num_logical_experts = total_physical_experts // 2 # Some redundancy hidden_sizes = [32, 64] # Create redundancy configuration redundancy_config = [2] * num_logical_experts # Same indices - no change indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, total_physical_experts, redundancy_config) expert_weights = create_expert_weights(num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices) # Save original weights original_weights = [] for layer_weights in expert_weights: layer_copy = [] for weight in layer_weights: layer_copy.append(weight.clone()) original_weights.append(layer_copy) # Execute rearrangement (should be no change) rearrange_expert_weights_inplace( indices, indices, # Same indices expert_weights, ep_group, is_profile=False) # Verify that the weights have not changed for layer in range(num_layers): for weight_idx in range(len(hidden_sizes)): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], msg=f"Layer {layer}, weight {weight_idx} should remain " f"unchanged") distributed_run(worker_fn, world_size) @pytest.mark.parametrize("world_size", [2, 4]) def test_rearrange_expert_weights_profile_mode(world_size): """Test profile mode (should not copy actual weights)""" if torch.cuda.device_count() < world_size: pytest.skip(f"Need at least {world_size} GPUs to run the test") @worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() device = torch.device(f"cuda:{ep_rank}") num_layers = 1 num_local_experts = 2 total_physical_experts = world_size * num_local_experts num_logical_experts = total_physical_experts // 2 hidden_sizes = [32] # Create different index distributions old_redundancy = create_redundancy_config(num_logical_experts, total_physical_experts) new_redundancy = create_redundancy_config(num_logical_experts, total_physical_experts) old_indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, total_physical_experts, old_redundancy) new_indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, total_physical_experts, new_redundancy) expert_weights = create_expert_weights(num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices) # Save original weights original_weights = [] for layer_weights in expert_weights: layer_copy = [] for weight in layer_weights: layer_copy.append(weight.clone()) original_weights.append(layer_copy) # Execute profile mode rearrangement rearrange_expert_weights_inplace( old_indices, new_indices, expert_weights, ep_group, is_profile=True # Profile mode ) # In profile mode, the weights should remain unchanged for layer in range(num_layers): for weight_idx in range(len(hidden_sizes)): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], msg="In profile mode, the weights should remain unchanged") distributed_run(worker_fn, world_size)