# SPDX-License-Identifier: Apache-2.0
"""
Test deepep dispatch-combine logic
"""

import dataclasses
import importlib.util  # explicit: `importlib.util` is used for find_spec
from typing import Optional, Union

import pytest
import torch.distributed
from torch.distributed import ProcessGroup

from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.platforms import current_platform

from .deepep_utils import ProcessGroupInfo, parallel_launch

has_deep_ep = importlib.util.find_spec("deep_ep") is not None

if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a

requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep,
    reason="Requires deep_ep kernels",
)

# Passed as `max_tokens_per_rank` to the low-latency DeepEP args and as
# `max_num_tokens` to BatchedTritonExperts; the low-latency path in
# deep_ep_moe_impl feeds tokens in chunks of this size.
MAX_TOKENS_PER_RANK = 64


def make_weights(
    e, n, k, dtype
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
           Optional[torch.Tensor]]:
    """
    Return weights w1, w2, w1_scale, w2_scale.

    Args:
        e: number of experts.
        n: intermediate size. w1 produces 2 * n features per expert, which
            SiluAndMul reduces to n before w2 is applied.
        k: hidden size.
        dtype: float16 / bfloat16 (unquantized; both scales are None) or
            float8_e4m3fn (per-output-channel quantized weights).

    Returns:
        w1 of shape (e, 2 * n, k) and w2 of shape (e, k, n). For fp8,
        also float32 scales of shape (e, 2 * n, 1) and (e, k, 1).
        Tensors are allocated on "cuda".
    """
    if dtype in [torch.float16, torch.bfloat16]:
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        return w1, w2, None, None

    # per-out-channel weight quantization
    assert dtype == torch.float8_e4m3fn
    # Draw seeded random values (as in the unquantized path). Quantizing
    # torch.empty() would read uninitialized memory, giving nondeterministic
    # and possibly NaN/Inf weights, which makes the comparison flaky.
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.float16) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.float16) / 10

    w1_q = torch.empty_like(w1, dtype=dtype)
    w2_q = torch.empty_like(w2, dtype=dtype)
    # One float32 scale per output row of each expert's matrix.
    w1_scale = torch.empty((e, 2 * n, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((e, k, 1), device="cuda", dtype=torch.float32)

    for expert in range(e):
        # use_per_token_if_dynamic=True -> one scale per row here.
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
            w1[expert], use_per_token_if_dynamic=True)
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
            w2[expert], use_per_token_if_dynamic=True)

    return w1_q, w2_q, w1_scale, w2_scale


@dataclasses.dataclass
class TestConfig:
    """Problem sizes and dtype shared by every rank in one test case."""
    # Not a pytest test class despite the `Test` prefix; stops pytest from
    # trying (and warning about) collecting it.
    __test__ = False

    dtype: torch.dtype
    topk: int
    m: int  # tokens per rank
    k: int  # hidden size
    n: int  # intermediate size (see make_weights)
    num_experts: int


@dataclasses.dataclass
class TestTensors:
    """Per-rank inputs: tokens, optional activation scales, and routing."""
    # Not a pytest test class despite the `Test` prefix.
    __test__ = False

    # (config.m, config.k) tokens; every rank makes config.m tokens.
    rank_tokens: torch.Tensor
    # Activation scales for fp8 runs, None otherwise.
    rank_token_scales: Optional[torch.Tensor]
    # (config.m, config.topk) int64 expert ids.
    topk: torch.Tensor
    # (config.m, config.topk) float32 routing weights.
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
        """Build random CUDA inputs for `config`.

        For fp8 configs the tokens stay bf16; only the activation scales
        are computed (per-token unless low_latency_mode).
        """
        # TODO (varun) - check that float16 works ?
        assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
        token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn
                       else config.dtype)
        rank_tokens = torch.randn(
            (config.m, config.k), device="cuda", dtype=token_dtype) / 10
        rank_token_scales = None
        if config.dtype == torch.float8_e4m3fn:
            # low_latency_mode kernels don't support per-token quant.
            _, rank_token_scales = ops.scaled_fp8_quant(
                rank_tokens, use_per_token_if_dynamic=not low_latency_mode)

        # NOTE(review): randint may repeat an expert id within a token's
        # topk row; confirm the kernels under test are meant to see that.
        topk = torch.randint(low=0,
                             high=config.num_experts,
                             size=(config.m, config.topk),
                             device="cuda").to(dtype=torch.int64)
        topk_weights = torch.randn(topk.shape,
                                   dtype=torch.float32,
                                   device="cuda")
        return TestTensors(rank_tokens=rank_tokens,
                           rank_token_scales=rank_token_scales,
                           topk=topk,
                           topk_weights=topk_weights,
                           config=config)


def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                        low_latency_mode: bool, hidden_size: int,
                        dp_size: int, num_experts: int,
                        num_local_experts: int,
                        q_dtype: Optional[torch.dtype],
                        use_fp8_dispatch: bool) -> FusedMoEModularKernel:
    """
    Build a FusedMoEModularKernel from a DeepEP prepare/finalize (a2a)
    object and matching Triton experts.

    Low-latency mode: DeepEPLLArgs + BatchedTritonExperts.
    High-throughput mode: DeepEPHTArgs + TritonExperts; FP8 dispatch is
    rejected with an AssertionError.
    A non-None q_dtype enables the fp8 w8a8 expert path.
    """
    is_quantized = q_dtype is not None

    ht_args: Optional[DeepEPHTArgs] = None
    ll_args: Optional[DeepEPLLArgs] = None

    if low_latency_mode:
        ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK,
                               hidden_size=hidden_size,
                               num_experts=num_experts,
                               use_fp8_dispatch=use_fp8_dispatch)
    else:
        assert not use_fp8_dispatch, (
            "FP8 Dispatch is valid only for low-latency kernels")
        ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)

    a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \
        make_deepep_a2a(pg=pg,
                        pgi=pgi,
                        dp_size=dp_size,
                        q_dtype=q_dtype,
                        block_shape=None,
                        deepep_ht_args=ht_args,
                        deepep_ll_args=ll_args)

    if low_latency_mode:
        fused_experts = BatchedTritonExperts(
            max_num_tokens=MAX_TOKENS_PER_RANK,
            world_size=pgi.world_size,
            dp_size=dp_size,
            use_fp8_w8a8=is_quantized,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False)
    else:
        fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
                                      use_int8_w8a8=False,
                                      use_int8_w8a16=False,
                                      use_int4_w4a16=False,
                                      per_channel_quant=False)

    mk = FusedMoEModularKernel(prepare_finalize=a2a,
                               fused_experts=fused_experts)
    return mk


def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
                     low_latency_mode: bool, dp_size: int,
                     test_tensors: TestTensors, w1: torch.Tensor,
                     w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
                     w2_scale: Optional[torch.Tensor], num_experts: int,
                     use_fp8_dispatch: bool) -> torch.Tensor:
    """
    Run this rank's tokens through the DeepEP modular kernel.

    w1/w2 (and their scales) hold only this rank's experts. These are the
    contiguous global ids [rank * L, (rank + 1) * L) with L = w1.size(0).

    Returns a tensor with the shape and dtype of test_tensors.rank_tokens.
    """
    num_local_experts = w1.size(0)

    # Global expert id -> local index for experts owned by this rank, -1
    # for experts owned by other ranks. Chunk-invariant, so built once.
    expert_map = torch.full((num_experts, ),
                            fill_value=-1,
                            dtype=torch.int32)
    first_expert = pgi.rank * num_local_experts
    expert_map[first_expert:first_expert + num_local_experts] = torch.arange(
        num_local_experts, dtype=torch.int32)
    expert_map = expert_map.to(device=torch.cuda.current_device(),
                               dtype=torch.int32)

    hidden_size = test_tensors.rank_tokens.size(1)
    is_quantized = w1.dtype == torch.float8_e4m3fn
    q_dtype = torch.float8_e4m3fn if is_quantized else None

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
        num_local_experts, q_dtype, use_fp8_dispatch)

    out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
    total_num_tokens = test_tensors.rank_tokens.size(0)

    def process_chunk(chunk_start: int, chunk_end: int) -> None:
        """Run mk.forward on tokens [chunk_start, chunk_end) and store it."""
        rank_token_scales_chunk = test_tensors.rank_token_scales
        if (rank_token_scales_chunk is not None
                and rank_token_scales_chunk.size(0) == total_num_tokens):
            # One scale per activation token: slice along with the tokens.
            # Otherwise the (per-tensor) scale is passed through unchanged.
            rank_token_scales_chunk = rank_token_scales_chunk[
                chunk_start:chunk_end]

        out = mk.forward(
            hidden_states=test_tensors.rank_tokens[chunk_start:chunk_end],
            w1=w1,
            w2=w2,
            topk_weights=test_tensors.topk_weights[chunk_start:chunk_end],
            topk_ids=test_tensors.topk[chunk_start:chunk_end],
            inplace=False,
            activation="silu",
            global_num_experts=num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=None,
            w2_zp=None,
            a1_scale=rank_token_scales_chunk,
            a2_scale=None,
            apply_router_weight_on_input=False)

        out_hidden_states[chunk_start:chunk_end, :].copy_(out,
                                                          non_blocking=True)

    # Low-latency mode is configured for MAX_TOKENS_PER_RANK tokens per
    # rank, so tokens go in chunks of that size; high-throughput mode runs
    # everything in one call. All ranks share config.m, so every rank
    # issues the same number of forward() calls.
    max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK
                             if low_latency_mode else total_num_tokens)

    for chunk_start in range(0, total_num_tokens, max_num_tokens_per_dp):
        chunk_end = min(chunk_start + max_num_tokens_per_dp, total_num_tokens)
        process_chunk(chunk_start, chunk_end)

    return out_hidden_states


def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
                   w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
                   w2_scale: Optional[torch.Tensor],
                   using_fp8_dispatch: bool) -> torch.Tensor:
    """
    Reference MoE computed with plain torch ops over all experts.

    For each token i:
        out[i] = sum_j topk_weights[i, j]
                 * (SiluAndMul(a[i] @ w1[e].T) @ w2[e].T), e = topk[i, j]

    fp8 weights are dequantized (and the math done in float32) before the
    result is cast back to the activation dtype.
    """
    a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
                                 test_tensors.topk_weights)
    if using_fp8_dispatch:
        # The DeepEP implementation is requested to dispatch using FP8.
        # For numerical stability for testing, emulate the fp8 dispatch by
        # blockwise (groups of 128) quant and de-quant.
        aq, aq_scale = per_token_group_quant_fp8(a, 128)
        a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
            a.shape).to(a.dtype)

    is_quantized = w1.dtype == torch.float8_e4m3fn
    a_dtype = a.dtype
    if is_quantized:
        w1 = w1.to(dtype=torch.float32) * w1_scale
        w2 = w2.to(dtype=torch.float32) * w2_scale
        a = a.to(dtype=torch.float32)

    m, _ = a.shape
    topk = topk_ids.size(1)
    activation = SiluAndMul()  # stateless here; construct once
    out = torch.zeros_like(a)

    for i in range(m):
        a_i = a[i]
        o_i = out[i]  # a view: the in-place += below writes into `out`
        for j in range(topk):
            e = topk_ids[i][j]
            e_w = topk_weights[i][j]
            o_i += (activation(a_i @ w1[e].transpose(0, 1)) @
                    w2[e].transpose(0, 1)) * e_w

    if is_quantized:
        out = out.to(dtype=a_dtype)

    return out


def _deep_ep_moe(
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: Optional[torch.Tensor],
    w2_scale: Optional[torch.Tensor],
    use_fp8_dispatch: bool,
):
    """
    Per-rank worker passed to parallel_launch: compare the torch reference
    MoE with the DeepEP modular kernel on this rank's tokens.

    The reference uses all experts. The DeepEP path gets this rank's
    contiguous slice of config.num_experts // world_size experts.
    """
    if not low_latency_mode:
        assert not use_fp8_dispatch, (
            "FP8 dispatch interface is available only in low-latency mode")

    is_quantized = w1.dtype == torch.float8_e4m3fn
    w1 = w1.to(device=torch.cuda.current_device())
    w2 = w2.to(device=torch.cuda.current_device())
    if is_quantized:
        w1_scale = w1_scale.to(  # type: ignore
            device=torch.cuda.current_device())
        w2_scale = w2_scale.to(  # type: ignore
            device=torch.cuda.current_device())

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, low_latency_mode)

    with set_current_vllm_config(VllmConfig()):
        # Reference
        torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
                                        w2_scale, use_fp8_dispatch)

        # Splice experts for this rank.
        num_local_experts = config.num_experts // pgi.world_size
        e_start = num_local_experts * pgi.rank
        e_end = e_start + num_local_experts
        w1_ep = w1[e_start:e_end]
        w2_ep = w2[e_start:e_end]

        w1_scale_ep, w2_scale_ep = None, None
        if is_quantized:
            w1_scale_ep = w1_scale[e_start:e_end]  # type: ignore
            w2_scale_ep = w2_scale[e_start:e_end]  # type: ignore
        deepep_combined = deep_ep_moe_impl(
            pg,
            pgi,
            low_latency_mode,
            dp_size,
            test_tensors,
            w1_ep,
            w2_ep,
            w1_scale_ep,
            w2_scale_ep,
            config.num_experts,
            use_fp8_dispatch,
        )

    torch.testing.assert_close(
        torch_combined,
        deepep_combined,
        atol=6e-2,
        rtol=6e-2,
    )


MNKs = [
    (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (222, 1024, 2048),
]

DTYPES = [torch.bfloat16, torch.float8_e4m3fn]


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
                     num_experts: int, topk: int, world_dp_size: tuple[int,
                                                                       int]):
    """High-throughput DeepEP MoE vs. the torch reference."""
    low_latency_mode = False
    use_fp8_dispatch = False
    m, n, k = mnk

    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    config = TestConfig(dtype=dtype,
                        topk=topk,
                        m=m,
                        k=k,
                        n=n,
                        num_experts=num_experts)

    w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

    parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)


# Rebinds MNKs / DTYPES for the low-latency test below. The decorators of
# test_deep_ep_moe were evaluated at definition time and keep the old lists.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
USE_FP8_DISPATCH = [True, False]


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@requires_deep_ep
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int,
                                                                int],
                                 num_experts: int, topk: int,
                                 world_dp_size: tuple[int, int],
                                 use_fp8_dispatch: bool):
    """Low-latency DeepEP MoE vs. the torch reference."""
    low_latency_mode = True
    m, n, k = mnk

    if (low_latency_mode
            and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
        pytest.skip(
            f"Skipping test as hidden size {k} is not in list of supported "
            f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
        )

    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    config = TestConfig(dtype=dtype,
                        topk=topk,
                        m=m,
                        k=k,
                        n=n,
                        num_experts=num_experts)

    w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)

    parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
                    config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)