# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Pallas MOE implementation.

Run `pytest tests/kernels/moe/test_moe_pallas.py`.
"""
import pytest
import torch

# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.fused_moe.moe_pallas import (
    fused_moe as pallas_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as torch_moe)
# yapf: enable
from vllm.platforms import current_platform

if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)

NUM_EXPERTS = [8, 64]
EP_SIZE = [1]
TOP_KS = [2, 6]


# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    import torch_xla.core.xla_model as xm
    with torch.device(xm.xla_device()):
        a = torch.randn((m, k), dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
        w2 = torch.randn((e, k, n), dtype=dtype) / 10
        score = torch.randn((m, e), dtype=dtype)

        # TODO: Support ep
        if ep_size > 1:
            pytest.skip("No support for ep_size > 1 yet")
        else:
            e_map = None

        # Run both implementations
        torch_output = torch_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        pallas_output = pallas_moe(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )
        xm.mark_step()

        # Compare outputs
        torch.testing.assert_close(
            pallas_output.cpu(),
            torch_output.cpu(),
            atol=2e-2,
            rtol=0,
        )