# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional

import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.utils import get_open_port

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
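    """Entry point executed in each spawned process.

    Computes the global rank from the node rank and local rank, binds the
    process to its local CUDA device, initializes the gloo+nccl process
    group, runs a small all_reduce as a barrier, then invokes ``worker``
    with a populated ``ProcessGroupInfo``. The process group is destroyed
    on exit even if the worker raises.
    """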
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
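    """Spawn ``world_size`` worker processes on a single node and join them.

    Workers rendezvous over TCP on localhost using a free port. Keyword
    arguments cannot be forwarded through ``torch.multiprocessing.spawn``,
    hence the assert below.
    """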
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
    num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):
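    """Build a DeepEP high-throughput all2all buffer and wrap it in a
    DeepEPHTPrepareAndFinalize for use in tests.

    ``q_dtype`` and ``block_shape`` are accepted for interface parity with
    the low-latency variant but are not used here.
    """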
    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(group=pg,
                            num_nvl_bytes=num_nvl_bytes,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=low_latency_mode,
                            num_qps_per_rank=num_qps_per_rank)
    return DeepEPHTPrepareAndFinalize(buffer=buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=pgi.rank *
                                      ht_args.num_local_experts)


def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):
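    """Build a DeepEP low-latency all2all buffer sized from
    ``deepep_ll_args`` and wrap it in a DeepEPLLPrepareAndFinalize.

    ``q_dtype`` and ``block_shape`` are accepted for interface parity with
    the high-throughput variant but are not used here.
    """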
    import deep_ep

    # low-latency a2a
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
        pgi.world_size, deepep_ll_args.num_experts)

    buffer = deep_ep.Buffer(group=pg,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=deepep_ll_args.num_experts //
                            pgi.world_size)

    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    )


def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
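    """Dispatch to the high-throughput or low-latency constructor.

    Exactly one of ``deepep_ht_args`` and ``deepep_ll_args`` must be
    provided; the other must be None.
    """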
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                              block_shape)
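

# Illustrative sketch (not used by the tests in this file): a hypothetical
# worker wiring these helpers together. The worker name, dp_size value and
# expert count below are assumptions made for the example only.
#
#     def _example_worker(pgi: ProcessGroupInfo, dp_size: int) -> None:
#         pg = torch.distributed.new_group(list(range(pgi.world_size)))
#         a2a = make_deepep_a2a(pg=pg,
#                               pgi=pgi,
#                               dp_size=dp_size,
#                               deepep_ht_args=DeepEPHTArgs(num_local_experts=8),
#                               deepep_ll_args=None)
#         ...  # build a fused-MoE kernel around `a2a` and run it
#
#     if has_deep_ep:
#         parallel_launch(2, _example_worker, 1)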