# SPDX-License-Identifier: Apache-2.0
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Compare the DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""

import importlib
import math

import pytest
import torch

# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.utils import cdiv

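# DeepGEMM is optional: probe for the package and, when present, derive the
# quantization tile from the M-alignment it requires for its contiguous
# grouped-GEMM layout (typically 128).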
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None

if has_deep_gemm:
    import deep_gemm
    BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
    BLOCK_SIZE = [BLOCK_M, BLOCK_M]

requires_deep_gemm = pytest.mark.skipif(
    not has_deep_gemm,
    reason="Requires deep_gemm kernels",
)


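# Relative-difference helper: returns 1 - 2*sum(x*y) / sum(x*x + y*y), i.e.
# 0.0 for identical tensors, growing as they diverge. Not used by the
# assertions below, but handy when debugging tolerance failures.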
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


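# Quantize a 2-D tensor in (128 x block_size_n) blocks: each block is scaled so
# its absolute maximum maps to 448.0 (the float8_e4m3fn maximum), and the
# matching per-block dequantization scales are returned alongside the values.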
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
        dtype=x.dtype,
        device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // block_size_n,
                           block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
):
    """
    Generate (w1, w2) expert weights and their per-block scale tensors
    in FP8 block-quantized format.

    w1 shape: (E, 2N, K)
    w2 shape: (E, K, N)
    """
    dtype = torch.bfloat16
    fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
        torch.float8_e4m3fn).min

    # bf16 reference weights
    w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
    w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10
    w1_bf16.clamp_(fp8_min, fp8_max)
    w2_bf16.clamp_(fp8_min, fp8_max)

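    # One scale per (block_n x block_k) tile. Per expert, w1 is (2N, K) and
    # w2 is (K, N), hence the swapped tile counts below.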
    block_n, block_k = block_size
    n_tiles_w1 = math.ceil((2 * n) / block_n)
    k_tiles_w1 = math.ceil(k / block_k)
    n_tiles_w2 = math.ceil(k / block_n)
    k_tiles_w2 = math.ceil(n / block_k)

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
    w1_s = torch.empty(e,
                       n_tiles_w1,
                       k_tiles_w1,
                       device="cuda",
                       dtype=torch.float32)
    w2_s = torch.empty(e,
                       n_tiles_w2,
                       k_tiles_w2,
                       device="cuda",
                       dtype=torch.float32)

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])

    return w1, w2, w1_s, w2_s


def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    Run one (M, N, K) configuration on a single GPU and assert that the
    DeepGEMM output matches the Triton baseline within tolerance.
    """
    tokens_bf16 = torch.randn(
        m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
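    # The quantized activations are discarded; only the per-token-group scales
    # are kept and passed to fused_experts below as a1_scale.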
    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])

    # expert weight tensors
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
                                                      block_size)

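    # Simple random router: pick top-k experts per token and renormalize the
    # selected logits with a softmax so the routing weights sum to 1.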
    router_logits = torch.randn(m,
                                num_experts,
                                device="cuda",
                                dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # Triton reference (DeepGEMM explicitly disabled)
    out_triton = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
        allow_deep_gemm=False,
    )

    # DeepGEMM path (identical arguments, only allow_deep_gemm differs)
    out_deepgemm = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
        allow_deep_gemm=True,
    )

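    # FP8 block quantization and different accumulation orders mean the two
    # paths are not bit-identical, so compare with a scale-aware tolerance.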
    base = out_triton.abs().mean()
    atol = 0.1 * base.clamp(min=1e-2)  # 10% of mean, but not lower than 1e-3
    rtol = 0.05
    # ----- Compare -----
    torch.testing.assert_close(
        out_deepgemm.to(torch.float32),
        out_triton.to(torch.float32),
        rtol=rtol,
        atol=float(atol),
    )


# Note: W1 has shape (E, 2N, K), so N = 512
# is sufficient to trigger the DeepGEMM path.
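# Each tuple is (M, N, K): M tokens with hidden size K; w1 projects K -> 2N and
# w2 projects N -> K.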
MNKs = [
    (1024, 512, 128),
    (1024, 512, 512),
    (2048, 512, 512),
    (512, 1024, 1024),
    (512, 2048, 2048),
    (4096, 4096, 1024),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@requires_deep_gemm
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):

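    # VLLM_USE_DEEP_GEMM opts fused_experts into the DeepGEMM kernels; the
    # Triton baseline in run_single_case is still forced via
    # allow_deep_gemm=False.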
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_DEEP_GEMM", "1")

        _fused_moe_mod = importlib.import_module(
            "vllm.model_executor.layers.fused_moe.fused_moe")

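        # Spy on deep_gemm_moe_fp8 so we can assert afterwards that the
        # DeepGEMM path, not the Triton fallback, actually handled the call.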
        call_counter = {"cnt": 0}

        orig_fn = _fused_moe_mod.deep_gemm_moe_fp8

        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
            call_counter["cnt"] += 1
            return orig_fn(*args, **kwargs)

        monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
                            _spy_deep_gemm_moe_fp8)

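        # `m` is reused here for the token count, shadowing the monkeypatch
        # context variable bound above; it is not referenced again, so the
        # shadowing is harmless.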
        m, n, k = mnk

        if topk > num_experts:
            pytest.skip(f"topk={topk} > num_experts={num_experts}")

        run_single_case(
            m=m,
            n=n,
            k=k,
            topk=topk,
            num_experts=num_experts,
            block_size=BLOCK_SIZE,
        )

        # ensure that the DeepGEMM path was indeed taken.
        assert call_counter["cnt"] == 1, \
            "DeepGEMM path was not executed during the test. " \
            f"Call counter: {call_counter['cnt']}"