# vllm/tests/kernels/moe/test_batched_moe.py
#
# 299 lines
# 9.2 KiB
# Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import pytest
import torch
import triton.language as tl
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, triton_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
# (m, n, k) problem sizes used to parametrize test_fused_moe_batched_experts.
# Entries are grouped by m; the order matches the pytest parametrization order.
MNK_FACTORS = [
    # m = 1
    (1, 128, 128), (1, 128, 2048), (1, 512, 512),
    (1, 1024, 128), (1, 1024, 2048),
    # m = 32
    (32, 128, 128), (32, 512, 512), (32, 1024, 2048),
    # m = 45
    (45, 128, 128), (45, 128, 2048), (45, 512, 512),
    (45, 1024, 128), (45, 1024, 2048),
    # m = 64
    (64, 128, 128), (64, 512, 512), (64, 1024, 2048),
    # m = 222
    (222, 128, 128), (222, 128, 2048), (222, 512, 512),
    (222, 1024, 128), (222, 1024, 2048),
]

# Expert counts ("e") exercised by test_fused_moe_batched_experts.
NUM_EXPERTS = [8, 64]

# Top-k routing widths ("topk") exercised by test_fused_moe_batched_experts.
TOP_KS = [1, 2, 6]
# Module-level vLLM config shared by the tests in this file;
# test_fused_moe_batched_experts installs it via set_current_vllm_config().
vllm_config = VllmConfig()
# Scheduler limits are pinned explicitly instead of relying on the defaults.
# NOTE(review): the reason for 128 / 8192 specifically is not evident from
# this file — confirm before changing.
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass
class BatchedMMConfig:
    """Dtype and shape parameters describing one batched matmul problem.

    BatchedMMTensors.make_tensors() reads these fields to allocate its
    input and output tensors.
    """
    # dtype of the A and B input tensors.
    in_dtype: torch.dtype
    # Optional quantization dtype.
    # NOTE(review): None presumably means "unquantized" — confirm with users.
    quant_dtype: Optional[torch.dtype]
    # dtype of the C output tensor.
    out_dtype: torch.dtype
    # Number of experts (leading batch dimension of A, B and C).
    num_experts: int
    # Row capacity of each expert's slice of A and C.
    max_tokens_per_expert: int
    # Reduction dimension shared by A and B.
    K: int
    # Output feature dimension of B and C.
    N: int
@dataclass
class BatchedMMTensors:
    """Input and output tensors for one batched (per-expert) matmul problem."""
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, N, K] as stored, i.e. a column-major [E, K, N]
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
        """Allocate random inputs and a zeroed output on the CUDA device.

        Args:
            config: Supplies the dtypes and the E / max_tokens / K / N sizes.

        Returns:
            A BatchedMMTensors holding A, B, C and num_expert_tokens.
        """
        # Activations are drawn from a normal distribution and divided by 10.
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.in_dtype) / 10
        # Weights are allocated as [E, N, K].
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.in_dtype)
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.out_dtype)
        # Per-expert token counts; torch.randint's upper bound is exclusive,
        # so each count lies in [0, max_tokens_per_expert).
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype,
                    block_shape: Optional[list[int]],
                    per_act_token_quant: bool):
    """Check invoke_moe_batched_triton_kernel against native references.

    The kernel output is compared with two results produced by
    native_batched_masked_quant_matmul: one from the original A/B tensors
    without scales, and one from the (possibly quantized) A_q/B_q tensors
    with their scales.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    # Quantization options only apply to the fp8 path.
    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")

    # Per-token and block quantization are not combined.
    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # A 1-byte dtype is treated as the quantized dtype, with bf16 activations;
    # any wider dtype is used directly with no quantization.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Per-expert token counts in [0, max_tokens_per_expert) (upper bound is
    # exclusive for torch.randint).
    num_expert_tokens = torch.randint(low=0,
                                      high=max_tokens_per_expert,
                                      size=(num_experts, ),
                                      device="cuda",
                                      dtype=torch.int32)

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant)

    # Only the first weight triple is used here.
    # NOTE(review): N // 2 presumably compensates for make_test_weights
    # doubling the first weight's output dimension — confirm in
    # tests/kernels/moe/utils.py.
    B, B_q, B_scale, _, _, _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
    )

    out_shape = (num_experts, max_tokens_per_expert, N)
    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

    # Triton compute dtype that matches the output tensor's torch dtype.
    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]

    assert A_q.dtype == B_q.dtype

    invoke_moe_batched_triton_kernel(
        A_q,
        B_q,
        test_output,
        num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        A_scale,
        B_scale,
        None,
        # Quantization schemes
        use_fp8_w8a8,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            # 1-byte dtypes use a K block of 32, wider dtypes 16.
            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
        },
        block_shape=block_shape,
    )

    # Reference from the original tensors, no scales.
    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
        num_expert_tokens,
        None,
        None,
        None,
    )

    # Reference from the same inputs (and scales) the kernel received.
    q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                      num_expert_tokens,
                                                      A_scale, B_scale,
                                                      block_shape)

    # Tolerances are selected by output dtype.
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
):
    """Check that batched_moe, triton_moe and torch_experts agree.

    All three implementations receive the same activations, weights and
    fused_topk routing; the Triton output is compared with both the torch
    baseline and the batched output.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    # `and` binds tighter than `or`; the grouping is spelled out here. Note
    # that the skip reason is also reported when topk exceeds the number of
    # experts.
    if (per_act_token_quant and block_shape is not None) or topk > e:
        pytest.skip("Skip illegal quantization test.")

    # Activations [m, k] and router scores [m, e], both bf16 on CUDA.
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # A 1-byte dtype is treated as the quantized dtype, with bf16 activations;
    # any wider dtype is used directly with no quantization.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Positions 0 and 3 of the returned tuple are discarded; the tensors kept
    # are the ones passed to the MoE implementations together with their
    # scales.
    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
                                                 n,
                                                 k,
                                                 block_shape=block_shape,
                                                 in_dtype=act_dtype,
                                                 quant_dtype=quant_dtype)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        batched_output = batched_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

        baseline_output = torch_experts(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape)

        triton_output = triton_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

    torch.testing.assert_close(triton_output,
                               baseline_output,
                               atol=2e-2,
                               rtol=2e-2)

    torch.testing.assert_close(triton_output,
                               batched_output,
                               atol=2e-2,
                               rtol=2e-2)