vllm/tests/tpu/test_moe_pallas.py

89 lines
2.4 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Pallas MOE implementation.
Run `pytest tests/tpu/test_moe_pallas.py`.
"""
import pytest
import torch
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.fused_moe.moe_pallas import (
fused_moe as pallas_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as torch_moe)
# yapf: enable
from vllm.platforms import current_platform
# Every test in this module drives torch_xla and the Pallas kernels, so the
# whole module is skipped on anything other than a TPU host.
if not current_platform.is_tpu():
    pytest.skip(
        "This test needs a TPU.",
        allow_module_level=True,
    )

# Sweep values consumed by test_pallas_moe's parametrization.
NUM_EXPERTS = [8, 64]  # total number of experts (``e``)
EP_SIZE = [1]  # expert-parallel sizes; ep_size > 1 is skipped by the test
TOP_KS = [2, 6]  # experts routed per token (``topk``)
# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16.
# Every (m, topk) combination below satisfies this: each m is a multiple of 8
# and each topk is even.
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_pallas_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
):
    """Compare the Pallas fused-MoE kernel with the iterative torch reference.

    Both implementations receive the same random inputs on the XLA device.
    Their outputs must agree within an absolute tolerance of 2e-2.

    Args:
        m: Number of tokens (rows of the hidden states).
        n: Per-expert intermediate size; ``w1`` is built with ``2 * n`` rows.
        k: Hidden size.
        e: Total number of experts.
        topk: Number of experts selected per token.
        ep_size: Expert-parallel size; values greater than 1 are skipped.
        dtype: Dtype used for every input tensor.
    """
    # TODO: Support ep
    # Decide on skipping before importing torch_xla or allocating anything on
    # the device, so unsupported configurations cost nothing.
    if ep_size > 1:
        pytest.skip("No support for ep_size > 1 yet")
    e_map = None  # no expert remapping without expert parallelism

    import torch_xla.core.xla_model as xm

    with torch.device(xm.xla_device()):
        # Inputs are scaled down by 10 to keep bfloat16 products small.
        a = torch.randn((m, k), dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
        w2 = torch.randn((e, k, n), dtype=dtype) / 10
        score = torch.randn((m, e), dtype=dtype)

        # One shared argument set guarantees that the reference and the
        # kernel under test see identical inputs and options.
        moe_kwargs = dict(
            hidden_states=a,
            w1=w1,
            w2=w2,
            gating_output=score,
            topk=topk,
            global_num_experts=e,
            expert_map=e_map,
            renormalize=False,
        )

        # Run both implementations.
        torch_output = torch_moe(**moe_kwargs)
        pallas_output = pallas_moe(**moe_kwargs)
        # Flush the lazily traced XLA graph before reading results back.
        xm.mark_step()

        # Compare outputs on the host.
        torch.testing.assert_close(
            pallas_output.cpu(),
            torch_output.cpu(),
            atol=2e-2,
            rtol=0,
        )