# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights (tensor scaled to fp8)
and 16-bit activations.
"""

import nvtx
import torch
import torch.utils.benchmark as benchmark

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
from vllm.utils import FlexibleArgumentParser

WEIGHT_SHAPES_MOE = {
    "nvidia/DeepSeek-R1-FP4": [
        [256, 8, 2048, 7168],
    ],
}

DEFAULT_MODELS = [
    "nvidia/DeepSeek-R1-FP4",
]

DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half
    device = "cuda"

    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    _, a_fp8_scale = ops.scaled_fp8_quant(a)

    w1_fp8q = torch.empty(
        (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
    )
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
    w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
    w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])

    w1_fp8q_notransp = w1_fp8q.clone()
    w2_fp8q_notransp = w2_fp8q.clone()
    w1_fp8q = w1_fp8q.transpose(1, 2)
    w2_fp8q = w2_fp8q.transpose(1, 2)

    score = torch.randn((m, num_experts), device=device, dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

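    # NVFP4 path: weights are quantized per expert to packed FP4 (two values per
    # uint8) with FP8 block scales every `quant_blocksize` elements. Per-expert
    # global scales are derived from each expert's weight amax below, while the
    # activation global scales are fixed at 1.0 for this benchmark.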
    quant_blocksize = 16
    w1_blockscale = torch.empty(
        (num_experts, 2 * n, k // quant_blocksize),
        device=device,
        dtype=torch.float8_e4m3fn,
    )
    w2_blockscale = torch.empty(
        (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
    )

    # n_b_scales = 2 * n if per_out_ch else 1
    # k_b_scales = k if per_out_ch else 1
    w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
    w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)

    w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
    a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_e = w1[expert]
        w2_e = w2[expert]
        w1_amax = torch.abs(w1_e).max().to(torch.float32)
        w2_amax = torch.abs(w2_e).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
            w1_e, w1_gs[expert]
        )

        w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
            w2_e, w2_gs[expert]
        )

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
        num_repeats: int,
    ):
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_fp8_scale,
            )

    def run_cutlass_moe_fp4(
        a: torch.Tensor,
        w1_fp4: torch.Tensor,
        w2_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w1_gs: torch.Tensor,
        w2_gs: torch.Tensor,
        a1_gs: torch.Tensor,
        a2_gs: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
        num_repeats: int,
    ):
        for _ in range(num_repeats):
            with nvtx.annotate("cutlass_moe_fp4", color="green"):
                cutlass_moe_fp4(
                    a=a,
                    a1_gscale=a1_gs,
                    a2_gscale=a2_gs,
                    w1_fp4=w1_fp4,
                    w1_blockscale=w1_blockscale,
                    w1_alphas=w1_gs,
                    w2_fp4=w2_fp4,
                    w2_blockscale=w2_blockscale,
                    w2_alphas=w2_gs,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    m=m,
                    n=n,
                    k=k,
                    e=num_experts,
                    device=device,
                )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a1_gscale: torch.Tensor,
        w1_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w1_alphas: torch.Tensor,
        a2_gscale: torch.Tensor,
        w2_fp4: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w2_alphas: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
    ):
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return cutlass_moe_fp4(
                a=a,
                a1_gscale=a1_gs,
                w1_fp4=w1_fp4,
                w1_blockscale=w1_blockscale,
                w1_alphas=w1_alphas,
                a2_gscale=a2_gs,
                w2_fp4=w2_fp4,
                w2_blockscale=w2_blockscale,
                w2_alphas=w2_alphas,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                m=m,
                n=n,
                k=k,
                e=num_experts,
                device=device,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
    ):
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_fp8_scale,
            )

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

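    # Capture a single iteration of each implementation into a CUDA graph; the
    # "*_cuda_graphs" measurements below time graph replay rather than individual
    # eager kernel launches.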
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a=a,
            a1_gscale=a1_gs,
            w1_fp4=w1_fp4,
            w1_blockscale=w1_blockscale,
            w1_alphas=w1_gs,
            a2_gscale=a2_gs,
            w2_fp4=w2_fp4,
            w2_blockscale=w2_blockscale,
            w2_alphas=w2_gs,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=m,
            n=n,
            k=k,
            e=num_experts,
            device=device,
        )
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_fp8q_notransp,
            w2_fp8q_notransp,
            topk_weights,
            topk_ids,
            w1_fp8scale,
            w2_fp8scale,
            a_fp8_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_fp8q_notransp": w1_fp8q_notransp,
        "w2_fp8q_notransp": w2_fp8q_notransp,
        "w1_fp8scale": w1_fp8scale,
        "w2_fp8scale": w2_fp8scale,
        "a_fp8_scale": a_fp8_scale,
        # Cutlass params
        "a": a,
        "a1_gscale": a1_gs,
        "w1_fp4": w1_fp4,
        "w1_blockscale": w1_blockscale,
        "w1_alphas": w1_gs,
        "a2_gscale": a2_gs,
        "w2_fp4": w2_fp4,
        "w2_blockscale": w2_blockscale,
        "w2_alphas": w2_gs,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "m": m,
        "n": n,
        "k": k,
        "e": num_experts,
        "device": device,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe_fp4": run_cutlass_moe_fp4,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_fp8q_notransp,
        w2_fp8q_notransp,
        topk_weights,
        topk_ids,
        w1_fp8scale,
        w2_fp8scale,
        a_fp8_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe_fp4(
        a,
        w1_fp4,
        w2_fp4,
        w1_blockscale,
        w2_blockscale,
        w1_gs,
        w2_gs,
        a1_gs,
        a2_gs,
        topk_weights,
        topk_ids,
        m,
        n,
        k,
        num_experts,
        device,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )


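# Each WEIGHT_SHAPES_MOE entry is interpreted below as [num_experts, topk, K, N];
# N is additionally divided by the tensor-parallel size.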
def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
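
# Example invocation (a sketch: the script filename below is assumed, and an
# NVFP4-capable GPU plus a vLLM build providing cutlass_moe_fp4 are required):
#
#   python benchmark_cutlass_moe_fp4.py \
#       --models nvidia/DeepSeek-R1-FP4 \
#       --tp-sizes 1 \
#       --batch-sizes 4 64 1024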