mirror of https://github.com/vllm-project/vllm.git
505 lines
18 KiB
Python
505 lines
18 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import multiprocessing
|
|
import os
|
|
import random
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed
|
|
|
|
from vllm.distributed.eplb.rebalance_execute import (
|
|
rearrange_expert_weights_inplace)
|
|
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
|
get_tp_group,
|
|
init_distributed_environment)
|
|
from vllm.utils import update_environment_variables
|
|
|
|
|
|
def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require success.

    Each child receives a single argument: a dict of the environment
    variables describing its place in the job (``RANK``, ``LOCAL_RANK``,
    ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``).

    Args:
        fn: callable invoked in every child as ``fn(env)``.
        world_size: number of child processes to start.

    Raises:
        AssertionError: if any child finishes with a non-zero exit code.
    """
    size_str = str(world_size)
    workers: list[multiprocessing.Process] = []

    for rank in range(world_size):
        rank_str = str(rank)
        env: dict[str, str] = {
            'RANK': rank_str,
            'LOCAL_RANK': rank_str,
            'WORLD_SIZE': size_str,
            'LOCAL_WORLD_SIZE': size_str,
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        worker = multiprocessing.Process(target=fn, args=(env, ))
        workers.append(worker)
        worker.start()

    # Wait for every child before inspecting any exit code, so that one
    # failure does not leave the remaining children un-joined.
    for worker in workers:
        worker.join()

    for worker in workers:
        assert worker.exitcode == 0
|
|
|
|
|
|
def worker_fn_wrapper(fn):
    """Adapt a zero-argument ``fn`` into a ``multiprocessing.Process`` target.

    ``multiprocessing.Process`` has no way to hand environment variables to
    the child, so the returned callable accepts them as a dict argument and
    installs them itself before doing anything else.

    Args:
        fn: zero-argument callable executed once the worker is set up.

    Returns:
        A callable taking one ``env`` dict; it applies the environment,
        selects the CUDA device, initializes the distributed environment,
        seeds the RNGs and finally calls ``fn()``.
    """

    def wrapped_fn(env):
        update_environment_variables(env)

        # Bind this process to the GPU named by LOCAL_RANK before the
        # distributed environment is initialized.
        gpu = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        torch.cuda.set_device(gpu)
        init_distributed_environment()

        # Every worker uses the same seed for both Python's and torch's RNG
        seed = 42
        random.seed(seed)
        torch.manual_seed(seed)

        fn()

    return wrapped_fn
|
|
|
|
|
|
def create_expert_indices_with_redundancy(
    num_layers: int,
    num_logical_experts: int,
    total_physical_experts: int,
    redundancy_config: list[int],  # redundancy for each logical expert
) -> torch.Tensor:
    """
    Build a randomly shuffled physical-to-logical expert mapping.

    Args:
        num_layers: number of layers
        num_logical_experts: number of logical experts
        total_physical_experts: total number of physical experts
        redundancy_config: number of physical replicas of each logical
            expert; must have one entry per logical expert and sum to
            `total_physical_experts`

    Returns:
        indices: long tensor of shape (num_layers, total_physical_experts)
            where `indices[layer, physical_pos]` is a logical expert ID
    """
    assert sum(redundancy_config) == total_physical_experts
    assert len(redundancy_config) == num_logical_experts

    # Unshuffled layout shared by every layer: each logical expert occupies
    # as many consecutive physical slots as it has replicas.
    ordered_row = torch.tensor(
        [
            expert_id
            for expert_id, copies in enumerate(redundancy_config)
            for _ in range(copies)
        ],
        dtype=torch.long,
    )

    indices = torch.zeros(num_layers, total_physical_experts, dtype=torch.long)

    # Every layer gets its own permutation of the slots (dim 1)
    for row in range(num_layers):
        indices[row] = ordered_row[torch.randperm(total_physical_experts)]

    return indices
|
|
|
|
|
|
def create_expert_weights(
    num_layers: int,
    num_local_experts: int,
    hidden_sizes: list[int],
    rank: int,
    device: torch.device,
    physical_to_logical_mapping: torch.Tensor,
) -> list[list[torch.Tensor]]:
    """
    Create fake expert weights tensor for testing.

    Use `arange` to generate predictable weights values, based on logical
    expert ID.
    All replicas of the same logical expert should have the same weights.

    Args:
        num_layers: number of layers to create weights for
        num_local_experts: number of physical experts hosted on this rank
        hidden_sizes: row length of each weight matrix; one weight tensor
            is created per entry
        rank: rank whose local experts are generated; local expert `i`
            sits at global physical position
            `rank * num_local_experts + i`
        device: device the weight tensors are created on
        physical_to_logical_mapping: indexed by global physical position,
            so it must cover every rank's experts, i.e. at least
            (num_layers, (rank + 1) * num_local_experts).
            mapping[layer, physical_pos] = logical_expert_id

    Returns:
        expert_weights[layer][weight_idx]: float32 tensor of shape
            (num_local_experts, hidden_sizes[weight_idx]) whose row for a
            local expert is `arange(base, base + hidden_size)` with
            `base = logical_expert_id * 1000 + layer * 100 + weight_idx * 10`
    """
    first_slot = rank * num_local_experts
    expert_weights = []

    for layer in range(num_layers):
        # The logical expert behind each local slot depends only on the
        # layer, so resolve it once instead of once per weight matrix.
        local_logical_ids = [
            physical_to_logical_mapping[layer, first_slot + offset].item()
            for offset in range(num_local_experts)
        ]

        layer_weights = []
        for weight_idx, hidden_size in enumerate(hidden_sizes):
            weight_tensor = torch.zeros(num_local_experts,
                                        hidden_size,
                                        device=device,
                                        dtype=torch.float32)

            for local_expert, logical_expert_id in enumerate(
                    local_logical_ids):
                # Derive the values from the logical expert ID so that all
                # replicas of the same logical expert get identical weights
                base_value = (logical_expert_id * 1000 + layer * 100 +
                              weight_idx * 10)
                weight_tensor[local_expert] = torch.arange(
                    base_value,
                    base_value + hidden_size,
                    device=device,
                    dtype=torch.float32)

            layer_weights.append(weight_tensor)
        expert_weights.append(layer_weights)

    return expert_weights
|
|
|
|
|
|
def create_redundancy_config(
    num_logical_experts: int,
    num_physical_experts: int,
) -> list[int]:
    """Return a random replica count for every logical expert.

    Each logical expert starts with one replica; every physical expert
    beyond that is handed to a logical expert picked with `random.choice`.
    """
    replica_counts = [1 for _ in range(num_logical_experts)]
    candidates = range(num_logical_experts)
    spare_slots = num_physical_experts - num_logical_experts

    # Distribute the spare physical experts one at a time
    for _ in range(spare_slots):
        lucky_expert = random.choice(candidates)
        replica_counts[lucky_expert] += 1

    return replica_counts
|
|
|
|
|
|
def verify_expert_weights_after_shuffle(
    expert_weights: list[list[torch.Tensor]],
    new_indices: torch.Tensor,
    hidden_sizes: list[int],
    ep_rank: int,
    num_local_experts: int,
):
    """Verify the weights after shuffling are correct.

    For every local expert slot, the weights are expected to be
    `arange(base, base + hidden_size)` where
    `base = logical_expert_id * 1000 + layer * 100 + weight_idx * 10` and
    `logical_expert_id` is what `new_indices` assigns to that slot.

    Args:
        expert_weights: expert_weights[layer][weight_idx] is this rank's
            weight tensor, indexed by local expert along dim 0
        new_indices: new_indices[layer, physical_pos] = logical_expert_id,
            indexed by global physical position
        hidden_sizes: row length of each weight matrix
        ep_rank: rank owning `expert_weights`; local expert `i` sits at
            global position `ep_rank * num_local_experts + i`
        num_local_experts: number of local expert slots to check

    Raises:
        AssertionError: if any local expert's weights differ from the
            expected values.
    """
    first_slot = ep_rank * num_local_experts

    for layer, layer_weights in enumerate(expert_weights):
        for weight_idx, hidden_size in enumerate(hidden_sizes):
            weight_tensor = layer_weights[weight_idx]

            for local_expert in range(num_local_experts):
                # Logical expert that should now live in this local slot
                global_pos = first_slot + local_expert
                expected_logical_expert = new_indices[layer, global_pos].item()

                # Rebuild the weights that logical expert was created with
                actual_weights = weight_tensor[local_expert]
                expected_base = (expected_logical_expert * 1000 + layer * 100 +
                                 weight_idx * 10)
                expected_weights = torch.arange(expected_base,
                                                expected_base + hidden_size,
                                                device=actual_weights.device,
                                                dtype=actual_weights.dtype)

                torch.testing.assert_close(
                    actual_weights,
                    expected_weights,
                    msg=f"Layer {layer}, weight {weight_idx}, "
                    f"local expert {local_expert}: "
                    f"weights do not match. "
                    f"Expected logical expert {expected_logical_expert}")
|
|
|
|
|
|
def verify_redundant_experts_have_same_weights(
    expert_weights: list[list[torch.Tensor]],
    indices: torch.Tensor,
    hidden_sizes: list[int],
    world_size: int,
    num_local_experts: int,
):
    """
    Verify that all replicas of the same logical expert have the same weights.

    The local weights of every rank are collected with
    `torch.distributed.all_gather`, so this must be called by all ranks.

    Args:
        expert_weights: expert_weights[layer][weight_idx] is this rank's
            weight tensor of shape [num_local_experts, hidden_size]
        indices: indices[layer, physical_pos] = logical_expert_id, indexed
            by global physical position
        hidden_sizes: row length of each weight matrix
        world_size: number of chunks the gathered weights are split into,
            one per rank
        num_local_experts: number of physical experts per rank

    Raises:
        AssertionError: if two physical experts mapped to the same logical
            expert hold different weights.
    """
    num_layers = len(expert_weights)
    total_physical_experts = world_size * num_local_experts

    for layer in range(num_layers):
        # Gathered weights of all physical experts, one tensor per weight
        # matrix
        all_weights: list[torch.Tensor] = []

        for weight_idx, hidden_size in enumerate(hidden_sizes):
            # [num_local_experts, hidden_size]
            local_weights = expert_weights[layer][weight_idx]

            # Destination buffer: [total_physical_experts, hidden_size]
            gathered_weights = torch.zeros(total_physical_experts,
                                           hidden_size,
                                           device=local_weights.device,
                                           dtype=local_weights.dtype)

            # `torch.chunk` returns views, so all_gather writing into the
            # per-rank chunks fills `gathered_weights` itself.
            per_rank_chunks = list(
                torch.chunk(gathered_weights, world_size, dim=0))
            torch.distributed.all_gather(per_rank_chunks, local_weights)

            all_weights.append(gathered_weights)

        # First-seen weights of each logical expert, one entry per weight
        # matrix; later replicas are compared against them.
        reference_weights: dict[int, list[torch.Tensor]] = {}

        for physical_pos in range(total_physical_experts):
            logical_expert_id = int(indices[layer, physical_pos].item())
            reference = reference_weights.get(logical_expert_id)

            if reference is None:
                # First time encountering this logical expert, save its
                # weights
                reference_weights[logical_expert_id] = [
                    gathered[physical_pos] for gathered in all_weights
                ]
                continue

            for weight_idx, expected in enumerate(reference):
                torch.testing.assert_close(
                    all_weights[weight_idx][physical_pos],
                    expected,
                    msg=f"Layer {layer}, weight {weight_idx}, "
                    f"logical expert {logical_expert_id}: "
                    f"Physical expert {physical_pos} has different weights "
                    f"than expected")
|
|
|
|
|
|
@pytest.mark.parametrize(
    "world_size,num_layers,num_local_experts,num_logical_experts",
    [
        # 2 GPUs x 2 experts: 3 logical, 4 physical, 1 redundant
        (2, 1, 2, 3),
        # 2 GPUs x 3 experts: 4 logical, 6 physical, 2 redundant
        (2, 2, 3, 4),
        # 2 GPUs x 8 experts: 16 logical, 16 physical, 0 redundant
        (2, 4, 8, 16),
        # 4 GPUs x 2 experts: 6 logical, 8 physical, 2 redundant
        (4, 1, 2, 6),
        # 4 GPUs x 2 experts: 5 logical, 8 physical, 3 redundant
        (4, 2, 2, 5),
        # 4 GPUs x 8 experts: 16 logical, 32 physical, 16 redundant
        (4, 8, 8, 16),
    ])
def test_rearrange_expert_weights_with_redundancy(world_size, num_layers,
                                                  num_local_experts,
                                                  num_logical_experts):
    """Rearranging between two random redundant placements must be correct.

    Each worker builds weights for an old placement, rearranges them in
    place to a new placement, then checks that every local slot holds the
    weights of its newly assigned logical expert and that all replicas of
    a logical expert agree across ranks.
    """

    if world_size > torch.cuda.device_count():
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        # Tensor parallel serves as the entrypoint to expert parallel
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        total_physical_experts = world_size * num_local_experts
        hidden_sizes = [32, 64]  # Two different weight matrices

        def random_placement() -> torch.Tensor:
            # A fresh redundancy configuration and its shuffled mapping
            redundancy = create_redundancy_config(num_logical_experts,
                                                  total_physical_experts)
            return create_expert_indices_with_redundancy(
                num_layers,
                num_logical_experts,
                total_physical_experts,
                redundancy,
            )

        # Old placement first, then the new one (same RNG order on
        # every worker)
        old_indices = random_placement()
        new_indices = random_placement()

        # Weights start out laid out according to the old placement
        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               old_indices)

        rearrange_expert_weights_inplace(
            old_indices,
            new_indices,
            expert_weights,
            ep_group,
            is_profile=False,
        )

        # Local slots must now hold the experts the new placement names
        verify_expert_weights_after_shuffle(
            expert_weights,
            new_indices,
            hidden_sizes,
            ep_rank,
            num_local_experts,
        )

        # Replicas of one logical expert must agree across all ranks
        verify_redundant_experts_have_same_weights(
            expert_weights,
            new_indices,
            hidden_sizes,
            world_size,
            num_local_experts,
        )

    distributed_run(worker_fn, world_size)
|
|
|
|
|
|
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_no_change(world_size):
    """
    Rearranging from a placement to the very same placement must leave
    every weight tensor untouched.
    """

    if world_size > torch.cuda.device_count():
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        num_layers = 2
        num_local_experts = 2
        hidden_sizes = [32, 64]
        total_physical_experts = world_size * num_local_experts
        # Half as many logical experts as physical ones: some redundancy
        num_logical_experts = total_physical_experts // 2

        # Exactly two replicas of every logical expert
        redundancy_config = [2] * num_logical_experts

        # A single placement, used as both the old and the new one
        indices = create_expert_indices_with_redundancy(
            num_layers, num_logical_experts, total_physical_experts,
            redundancy_config)

        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               indices)

        # Snapshot the weights before the (expected no-op) rearrangement
        original_weights = [[weight.clone() for weight in layer_weights]
                            for layer_weights in expert_weights]

        rearrange_expert_weights_inplace(
            indices,
            indices,  # Same indices
            expert_weights,
            ep_group,
            is_profile=False)

        # Every tensor must still equal its snapshot
        for layer in range(num_layers):
            current_layer = expert_weights[layer]
            snapshot_layer = original_weights[layer]
            for weight_idx in range(len(hidden_sizes)):
                torch.testing.assert_close(
                    current_layer[weight_idx],
                    snapshot_layer[weight_idx],
                    msg=f"Layer {layer}, weight {weight_idx} should remain "
                    f"unchanged")

    distributed_run(worker_fn, world_size)
|
|
|
|
|
|
@pytest.mark.parametrize("world_size", [2, 4])
def test_rearrange_expert_weights_profile_mode(world_size):
    """Profile mode is expected to leave the actual weights unchanged."""

    if world_size > torch.cuda.device_count():
        pytest.skip(f"Need at least {world_size} GPUs to run the test")

    @worker_fn_wrapper
    def worker_fn():
        ensure_model_parallel_initialized(
            tensor_model_parallel_size=world_size,
            pipeline_model_parallel_size=1)

        ep_group = get_tp_group().cpu_group
        ep_rank = torch.distributed.get_rank()
        device = torch.device(f"cuda:{ep_rank}")

        num_layers = 1
        num_local_experts = 2
        hidden_sizes = [32]
        total_physical_experts = world_size * num_local_experts
        num_logical_experts = total_physical_experts // 2

        # Two independently drawn redundancy configurations ...
        old_redundancy, new_redundancy = (
            create_redundancy_config(num_logical_experts,
                                     total_physical_experts),
            create_redundancy_config(num_logical_experts,
                                     total_physical_experts),
        )

        # ... and the shuffled placement built from each of them
        old_indices, new_indices = (
            create_expert_indices_with_redundancy(num_layers,
                                                  num_logical_experts,
                                                  total_physical_experts,
                                                  old_redundancy),
            create_expert_indices_with_redundancy(num_layers,
                                                  num_logical_experts,
                                                  total_physical_experts,
                                                  new_redundancy),
        )

        expert_weights = create_expert_weights(num_layers, num_local_experts,
                                               hidden_sizes, ep_rank, device,
                                               old_indices)

        # Snapshot the weights before the profile-mode call
        original_weights = [[weight.clone() for weight in layer_weights]
                            for layer_weights in expert_weights]

        rearrange_expert_weights_inplace(
            old_indices,
            new_indices,
            expert_weights,
            ep_group,
            is_profile=True  # Profile mode
        )

        # Even though the placements differ, nothing may have moved
        for layer in range(num_layers):
            current_layer = expert_weights[layer]
            snapshot_layer = original_weights[layer]
            for weight_idx in range(len(hidden_sizes)):
                torch.testing.assert_close(
                    current_layer[weight_idx],
                    snapshot_layer[weight_idx],
                    msg="In profile mode, the weights should remain unchanged")

    distributed_run(worker_fn, world_size)
|