# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul,
                                                   ReLUSquaredActivation,
                                                   SiluAndMul)
from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
                                                            vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.layernorm import (
    RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
    rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform


# Registered subclass for test
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
    pass


@pytest.mark.parametrize(
    "env, torch_level, use_inductor, ops_enabled, default_on",
    [
        # Default values based on compile level
        # - All by default (no Inductor compilation)
        ("", 0, False, [True] * 4, True),
        ("", 1, True, [True] * 4, True),
        ("", 2, False, [True] * 4, True),
        # - None by default (with Inductor)
        ("", 3, True, [False] * 4, False),
        ("", 4, True, [False] * 4, False),
        # - All by default (without Inductor)
        ("", 3, False, [True] * 4, True),
        ("", 4, False, [True] * 4, True),
        # Explicitly enabling/disabling
        #
        # Default: all
        #
        # All but SiluAndMul
        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
        # Only ReLU3
        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
        # All but SiluAndMul
        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
        # All but ReLU3 (even if ReLU2 is on)
        ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
        # RMSNorm and SiluAndMul
        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
        # All but RMSNorm
        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
        #
        # Default: none
        #
        # Only ReLU3
        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
        # All but RMSNorm
        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
    ])
def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
                     ops_enabled: list[int], default_on: bool):
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
                                             level=torch_level,
                                             custom_ops=env.split(",")))
    with set_current_vllm_config(vllm_config):
        assert CustomOp.default_on() == default_on

        ops_enabled = [bool(x) for x in ops_enabled]

        assert RMSNorm(1024).enabled() == ops_enabled[0]
        assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

        assert SiluAndMul().enabled() == ops_enabled[1]
        assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]

        assert GeluAndMul().enabled() == ops_enabled[2]
        assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]

        # If registered, subclasses should follow their own name
        assert Relu3().enabled() == ops_enabled[3]
        assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

        # Unregistered subclass
        class SiluAndMul2(SiluAndMul):
            pass

        # Subclasses should not require registration
        assert SiluAndMul2().enabled() == SiluAndMul().enabled()


@pytest.mark.parametrize(
    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
    with pytest.raises(Exception):  # noqa
        vllm_config = VllmConfig(compilation_config=CompilationConfig(
            custom_ops=env.split(",")))
        with set_current_vllm_config(vllm_config):
            RMSNorm(1024).enabled()


@pytest.mark.skipif(
    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
    reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
                                  monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
                       use_rocm_aiter_gemm_w8a8_blockscale)

    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
        int(use_rocm_aiter_gemm_w8a8_blockscale)))

    block_scale_func = dispatch_w8a8_blockscale_func(
        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)

    if use_cutlass:
        assert block_scale_func == cutlass_scaled_mm
    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
            use_rocm_aiter_gemm_w8a8_blockscale):
        assert block_scale_func == (
            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
    else:
        assert block_scale_func == w8a8_block_fp8_matmul


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    topk_func = dispatch_topk_func()
    is_rocm_aiter_moe_enabled.cache_clear()
    if current_platform.is_rocm() and int(use_rocm_aiter):
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_topk_softmax)
        assert topk_func == rocm_aiter_topk_softmax
    else:
        assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="AITER is a feature exclusive for ROCm")
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
                           use_rocm_aiter_norm: str, monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
    rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)

    if not add_residual:
        if current_platform.is_rocm() and int(use_rocm_aiter) and int(
                use_rocm_aiter_norm):
            assert rms_norm_func == rocm_aiter_rms_norm
        else:
            assert rms_norm_func == rms_norm
    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
            use_rocm_aiter_norm):
        assert rms_norm_func == rocm_aiter_fused_add_rms_norm
    else:
        assert rms_norm_func == fused_add_rms_norm