# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ``cutlass_moe_fp8``.

Each test builds random fp16 MoE tensors, quantizes them to fp8, and checks
that ``cutlass_moe_fp8`` agrees (within tolerance) with ``fused_experts``
run on the dequantized tensors. Variants cover eager execution, CUDA graph
capture/replay, and splitting the experts into shards via ``expert_map``.
"""
import dataclasses
from typing import Optional

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                            fused_topk)
from vllm.platforms import current_platform

NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]

MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 3072, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
    (224, 3072, 1024),
    (224, 3072, 1536),
    (32768, 1024, 1024),
    # These sizes trigger wrong answers.
    #(7232, 2048, 5120),
    #(40000, 2048, 5120),
]

vllm_config = VllmConfig(parallel_config=ParallelConfig(
    pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def _grouped_gemm_unsupported() -> bool:
    """Return True when the current device cannot run CUTLASS grouped gemm."""
    capability = current_platform.get_device_capability()
    return (capability is None
            or not ops.cutlass_group_gemm_supported(capability.to_int()))


# Shared skip marker; evaluated once at import time instead of once per test.
requires_grouped_gemm = pytest.mark.skipif(
    _grouped_gemm_unsupported(),
    reason="Grouped gemm is not supported on this GPU type.")


@dataclasses.dataclass
class MOETensors:
    """Unquantized MoE inputs plus per-expert stride tensors."""

    a: torch.Tensor  # activations, created with shape (m, k)
    w1: torch.Tensor  # first expert weights, created with shape (e, 2n, k)
    w2: torch.Tensor  # second expert weights, created with shape (e, k, n)
    ab_strides1: torch.Tensor  # (e,) int64, every entry equals k
    c_strides1: torch.Tensor  # (e,) int64, every entry equals 2 * n
    ab_strides2: torch.Tensor  # (e,) int64, every entry equals n
    c_strides2: torch.Tensor  # (e,) int64, every entry equals k

    @staticmethod
    def make_moe_tensors(m: int, k: int, n: int, e: int,
                         dtype: torch.dtype) -> "MOETensors":
        """Create random CUDA tensors for an MoE problem of the given size.

        Args:
            m: Number of rows of the activation tensor.
            k: Inner dimension shared by ``a`` and ``w1``.
            n: Half the number of ``w1`` output rows; also the inner
                dimension of ``w2``.
            e: Number of experts.
            dtype: Element type of ``a``, ``w1`` and ``w2``.

        Returns:
            A ``MOETensors`` whose values are ``randn / 10``.
        """
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

        def constant_strides(value: int) -> torch.Tensor:
            return torch.full((e, ), value, device="cuda", dtype=torch.int64)

        return MOETensors(a=a,
                          w1=w1,
                          w2=w2,
                          ab_strides1=constant_strides(k),
                          c_strides1=constant_strides(2 * n),
                          ab_strides2=constant_strides(n),
                          c_strides2=constant_strides(k))


@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    """``MOETensors`` extended with fp8-quantized and dequantized copies."""

    # quantized
    a_q: Optional[torch.Tensor] = None  # a -> a_q
    w1_q: Optional[torch.Tensor] = None  # w1 -> w1_q
    w2_q: Optional[torch.Tensor] = None  # w2 -> w2_q
    a_scale: Optional[torch.Tensor] = None
    w1_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    # dequantized
    a_d: Optional[torch.Tensor] = None  # a -> a_q -> a_d
    w1_d: Optional[torch.Tensor] = None  # w1 -> w1_q -> w1_d
    w2_d: Optional[torch.Tensor] = None  # w2 -> w2_q -> w2_d

    @staticmethod
    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
                              per_act_token: bool,
                              per_out_channel: bool) -> "MOETensors8Bit":
        """Create fp16 MoE tensors together with fp8 and dequantized copies.

        Args:
            m: Number of rows of the activation tensor.
            k: Inner dimension shared by ``a`` and ``w1``.
            n: Half the number of ``w1`` output rows.
            e: Number of experts.
            per_act_token: Forwarded as ``use_per_token_if_dynamic`` when
                quantizing the activations.
            per_out_channel: Forwarded as ``use_per_token_if_dynamic`` when
                quantizing each expert's weights; also selects whether the
                weight scale tensors hold one entry per weight row or a
                single entry per expert.

        Returns:
            A fully populated ``MOETensors8Bit``.
        """
        dtype = torch.half
        q_dtype = torch.float8_e4m3fn

        moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype)

        # a -> a_q, w1 -> w1_q, w2 -> w2_q
        n_b_scales = 2 * n if per_out_channel else 1
        k_b_scales = k if per_out_channel else 1
        # Get the right scale for tests.
        a_q, a_scale = ops.scaled_fp8_quant(
            moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
        w1_scale = torch.empty((e, n_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)

        # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
        a_d = a_q.float().mul(a_scale).to(dtype)
        w1_d = torch.empty_like(moe_tensors_fp16.w1)
        w2_d = torch.empty_like(moe_tensors_fp16.w2)

        # Quantize each expert separately, then immediately dequantize it so
        # the reference path sees exactly the values the fp8 path encodes.
        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w1[expert],
                use_per_token_if_dynamic=per_out_channel)
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w2[expert],
                use_per_token_if_dynamic=per_out_channel)
            w1_d[expert] = (w1_q[expert].float() *
                            w1_scale[expert]).to(dtype)
            w2_d[expert] = (w2_q[expert].float() *
                            w2_scale[expert]).to(dtype)

        return MOETensors8Bit(a=moe_tensors_fp16.a,
                              w1=moe_tensors_fp16.w1,
                              w2=moe_tensors_fp16.w2,
                              ab_strides1=moe_tensors_fp16.ab_strides1,
                              c_strides1=moe_tensors_fp16.c_strides1,
                              ab_strides2=moe_tensors_fp16.ab_strides2,
                              c_strides2=moe_tensors_fp16.c_strides2,
                              a_q=a_q,
                              w1_q=w1_q,
                              w2_q=w2_q,
                              a_scale=a_scale,
                              w1_scale=w1_scale,
                              w2_scale=w2_scale,
                              a_d=a_d,
                              w1_d=w1_d,
                              w2_d=w2_d)


def run_with_expert_maps(num_experts: int, num_local_experts: int,
                         **cutlass_moe_kwargs):
    """Run ``cutlass_moe_fp8`` once per expert shard and sum the outputs.

    The experts are split into consecutive groups of ``num_local_experts``.
    For each group the per-expert tensors are sliced to that group and an
    ``expert_map`` is passed that maps the group's global expert ids to
    ``0..num_local_experts-1`` and every other expert id to ``-1``.

    Args:
        num_experts: Total number of experts.
        num_local_experts: Number of experts per shard; must divide
            ``num_experts``.
        **cutlass_moe_kwargs: Keyword arguments for ``cutlass_moe_fp8``.
            Must contain ``"a"``, which determines the output shape.

    Returns:
        The elementwise sum of the per-shard outputs.

    Raises:
        ValueError: If ``num_experts`` is not a multiple of
            ``num_local_experts``.
    """
    if num_experts % num_local_experts != 0:
        # Otherwise the slice assignment below would silently grow the
        # expert map beyond ``num_experts`` entries.
        raise ValueError(
            f"num_experts ({num_experts}) must be a multiple of "
            f"num_local_experts ({num_local_experts})")

    # Arguments indexed by expert along dim 0; only those actually supplied
    # by the caller are sliced.
    sliced_param_names = ("w1_q", "w2_q", "ab_strides1", "ab_strides2",
                          "c_strides1", "c_strides2", "w1_scale", "w2_scale")
    full_tensors = {
        name: cutlass_moe_kwargs[name]
        for name in sliced_param_names if name in cutlass_moe_kwargs
    }

    def shard_kwargs():
        for start in range(0, num_experts, num_local_experts):
            end = start + num_local_experts

            # make expert map
            expert_map = [-1] * num_experts
            expert_map[start:end] = list(range(num_local_experts))

            # Fresh dict per shard so earlier yields are never mutated.
            shard = dict(cutlass_moe_kwargs)
            shard["expert_map"] = torch.tensor(expert_map,
                                               dtype=torch.int32,
                                               device="cuda")
            for name, tensor in full_tensors.items():
                shard[name] = tensor[start:end]
            yield shard

    out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
    for kwargs in shard_kwargs():
        out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)

    return out_tensor


def run_8_bit(moe_tensors: MOETensors8Bit,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              per_act_token: bool,
              num_local_experts: Optional[int] = None) -> torch.Tensor:
    """Invoke ``cutlass_moe_fp8`` on the quantized tensors.

    Args:
        moe_tensors: Tensors from ``MOETensors8Bit.make_moe_tensors_8bit``;
            the quantized weights and all scales must be populated.
        topk_weights: Routing weights passed through to the kernel.
        topk_ids: Selected expert ids passed through to the kernel.
        per_act_token: Passed through to the kernel.
        num_local_experts: When given (including when it equals the total
            number of experts), the call is split into shards via
            ``run_with_expert_maps``. When ``None`` a single direct call is
            made without an ``expert_map``.

    Returns:
        The kernel output tensor.
    """
    required = (moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
                moe_tensors.w2_scale, moe_tensors.a_scale)
    assert all(t is not None for t in required)

    kwargs = {
        'a': moe_tensors.a,
        'w1_q': moe_tensors.w1_q,  # type: ignore[union-attr]
        'w2_q': moe_tensors.w2_q,  # type: ignore[union-attr]
        'topk_weights': topk_weights,
        'topk_ids': topk_ids,
        'w1_scale': moe_tensors.w1_scale,
        'w2_scale': moe_tensors.w2_scale,
        'per_act_token': per_act_token,
        'a1_scale': None  #moe_tensors.a_scale
    }

    # NOTE: the previous condition
    #   ``num_local_experts is not None or num_local_experts == num_experts``
    # reduced to ``num_local_experts is not None`` (the second clause could
    # only be evaluated with ``None``); the guard below states that directly.
    if num_local_experts is None:
        return cutlass_moe_fp8(**kwargs)

    num_experts = moe_tensors.w1.size(0)
    return run_with_expert_maps(num_experts, num_local_experts, **kwargs)


def _prepare_case(m: int, n: int, k: int, e: int, topk: int,
                  per_act_token: bool, per_out_ch: bool):
    """Build the tensors, routing and reference output shared by all tests.

    Returns:
        ``(mt, topk_weights, topk_ids, triton_output)`` where
        ``triton_output`` is the ``fused_experts`` result on the dequantized
        tensors.
    """
    mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                              per_out_ch)

    score = torch.randn((m, e), device="cuda", dtype=torch.half)
    topk_weights, topk_ids, _ = fused_topk(mt.a,
                                           score,
                                           topk,
                                           renormalize=False)

    # Note that we are using the dequantized versions of the tensors.
    # Using a, w1 and w2 directly results in minor output differences.
    triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                  topk_ids)
    return mt, topk_weights, topk_ids, triton_output


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@requires_grouped_gemm
def test_cutlass_moe_8_bit_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
):
    """Eager ``cutlass_moe_fp8`` output matches the reference output."""
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        mt, topk_weights, topk_ids, triton_output = _prepare_case(
            m, n, k, e, topk, per_act_token, per_out_ch)

        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)

        # Note 5.5 only needed for larger problem sizes, 5 works ok for
        # the rest.
        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=5.5e-2,
                                   rtol=1e-2)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@requires_grouped_gemm
def test_cutlass_moe_8_bit_cuda_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
):
    """``cutlass_moe_fp8`` captured in a CUDA graph matches the reference."""
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        mt, topk_weights, topk_ids, triton_output = _prepare_case(
            m, n, k, e, topk, per_act_token, per_out_ch)

        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
                                       per_act_token)

        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=9e-2,
                                   rtol=1e-2)


@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@requires_grouped_gemm
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
):
    """Summed per-shard ``cutlass_moe_fp8`` outputs match the reference."""
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        mt, topk_weights, topk_ids, triton_output = _prepare_case(
            m, n, k, e, topk, per_act_token, per_out_channel)

        assert e % ep_size == 0, "Cannot distribute experts evenly"
        cutlass_output = run_8_bit(mt,
                                   topk_weights,
                                   topk_ids,
                                   per_act_token,
                                   num_local_experts=e // ep_size)

        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=5e-2,
                                   rtol=1e-2)