# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# This is a test for the AITER ops.
# It verifies that the AITER ops:
# 1. are correctly registered as custom ops,
# 2. correctly pair the implementation with its fake function,
# 3. can be used with torch.compile.
# This file is skipped if AITER is not installed
# or the platform is not ROCm.
import importlib.util

import pytest
import torch

# this import statement is needed to ensure the ops are registered
import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe  # noqa: F401
from vllm.platforms import current_platform

# Check if the aiter package is installed
aiter_available = importlib.util.find_spec("aiter") is not None

pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and aiter_available),
    reason="AITER ops are only available on ROCm with aiter package installed")


def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
    """Test that the custom op is correctly registered."""
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)


def test_rocm_aiter_grouped_topk_custom_op_registration():
    """Test that the custom op is correctly registered."""
    # Check if the op exists in torch.ops.vllm
    assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')

    # Check if the op is callable
    assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)


def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
    """Test that the op can be used with torch.compile."""
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scale_factor = 1.0

    gating_output = torch.randn((token, expert),
                                dtype=torch.bfloat16,
                                device="cuda")
    e_score_correction_bias = torch.randn((expert, ),
                                          dtype=torch.bfloat16,
                                          device="cuda")

    device = gating_output.device
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk),
                               dtype=torch.float32,
                               device=device)

    # Define a function that uses the op
    def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
                               topk_weights, topk_ids):
        return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output, e_score_correction_bias, topk_weights, topk_ids,
            num_expert_group, topk_group, renormalize, scale_factor)

    # Verify the op's fake implementation
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_biased_grouped_topk,
        (gating_output, e_score_correction_bias, topk_weights, topk_ids),
        kwargs={
            "num_expert_group": num_expert_group,
            "topk_group": topk_group,
            "need_renorm": renormalize,
            "routed_scaling_factor": scale_factor
        },
        test_utils=("test_faketensor", ))

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(biased_grouped_topk_fn,
                                fullgraph=True,
                                backend="inductor",
                                mode="reduce-overhead",
                                dynamic=False)

    topk_weights_original = torch.empty((token, topk),
                                        dtype=torch.float32,
                                        device=device)
    topk_ids_original = torch.empty((token, topk),
                                    dtype=torch.int32,
                                    device=device)

    topk_weights_compiled = torch.empty((token, topk),
                                        dtype=torch.float32,
                                        device=device)
    topk_ids_compiled = torch.empty((token, topk),
                                    dtype=torch.int32,
                                    device=device)

    # Run both the compiled (V1 graph mode) and uncompiled (V1 eager mode)
    # versions
    biased_grouped_topk_fn(gating_output, e_score_correction_bias,
                           topk_weights_original, topk_ids_original)
    compiled_fn(gating_output, e_score_correction_bias,
                topk_weights_compiled, topk_ids_compiled)

    # Sort the results for comparison since the order of experts within a
    # token might not be deterministic
    topk_ids_original, indices_original = torch.sort(topk_ids_original)
    topk_weights_original = torch.gather(topk_weights_original, 1,
                                         indices_original)

    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
                                         indices_compiled)

    # Verify results match
    assert torch.allclose(topk_weights_original,
                          topk_weights_compiled,
                          rtol=1e-2,
                          atol=1e-2)
    assert torch.allclose(topk_ids_original, topk_ids_compiled)


def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
    """Test that the op can be used with torch.compile."""
    # Create test tensors
    token = 64
    expert = 256
    num_expert_group = 8
    topk = 8
    topk_group = 4
    renormalize = True
    scoring_func = "softmax"
    scale_factor = 1.0

    gating_output = torch.randn((token, expert),
                                dtype=torch.bfloat16,
                                device="cuda")

    device = gating_output.device
    topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
    topk_weights = torch.empty((token, topk),
                               dtype=torch.float32,
                               device=device)

    # Define a function that uses the op
    def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
        return torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output, topk_weights, topk_ids, num_expert_group,
            topk_group, renormalize, scoring_func, scale_factor)

    # Verify the op's fake implementation
    torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
                          (gating_output, topk_weights, topk_ids),
                          kwargs={
                              "num_expert_group": num_expert_group,
                              "topk_group": topk_group,
                              "need_renorm": renormalize,
                              "scoring_func": scoring_func,
                              "routed_scaling_factor": scale_factor
                          },
                          test_utils=("test_faketensor", ))

    # Compile the function with appropriate settings
    compiled_fn = torch.compile(grouped_topk_fn,
                                fullgraph=True,
                                backend="inductor",
                                mode="reduce-overhead",
                                dynamic=False)

    topk_weights_original = torch.empty((token, topk),
                                        dtype=torch.float32,
                                        device=device)
    topk_ids_original = torch.empty((token, topk),
                                    dtype=torch.int32,
                                    device=device)

    topk_weights_compiled = torch.empty((token, topk),
                                        dtype=torch.float32,
                                        device=device)
    topk_ids_compiled = torch.empty((token, topk),
                                    dtype=torch.int32,
                                    device=device)

    # Run both the compiled (V1 graph mode) and uncompiled (V1 eager mode)
    # versions
    grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
                    scoring_func)
    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
                scoring_func)

    # Sort the results for comparison since the order of experts within a
    # token might not be deterministic
    topk_ids_original, indices_original = torch.sort(topk_ids_original)
    topk_weights_original = torch.gather(topk_weights_original, 1,
                                         indices_original)

    topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
                                         indices_compiled)

    # Verify results match
    assert torch.allclose(topk_weights_original,
                          topk_weights_compiled,
                          rtol=1e-2,
                          atol=1e-2)
    assert torch.allclose(topk_ids_original, topk_ids_compiled)
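

# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the implementation/fake pairing that
# opcheck verifies above. This is NOT vLLM's actual registration code: the
# namespace "aiter_sketch", the op name "fill_topk", and the op body are
# hypothetical. It illustrates the out-variant pattern these tests rely on:
# the op fills preallocated output buffers in place and returns None, so the
# fake implementation only needs to be traceable for torch.compile.
# ---------------------------------------------------------------------------


@torch.library.custom_op("aiter_sketch::fill_topk",
                         mutates_args=("topk_weights", "topk_ids"))
def _fill_topk_sketch(gating_output: torch.Tensor, topk_weights: torch.Tensor,
                      topk_ids: torch.Tensor) -> None:
    # Real implementation: computes top-k over the gating scores and writes
    # the results into the preallocated output buffers in place.
    values, ids = gating_output.float().topk(topk_weights.shape[-1], dim=-1)
    topk_weights.copy_(values)
    topk_ids.copy_(ids.to(torch.int32))


@_fill_topk_sketch.register_fake
def _fill_topk_sketch_fake(gating_output: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor) -> None:
    # Fake implementation: produces no values because the outputs are
    # preallocated; it exists so torch.compile can propagate shapes and
    # dtypes through the graph without running the real kernel.
    return None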