vllm/tests/kernels/moe/utils.py

250 lines
8.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
per_block_cast_to_int8)
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up
def triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_channel_quant=per_act_token_quant,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
block_shape=block_shape)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
world_size=1,
dp_size=1,
rank=0),
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
world_size=1,
dp_size=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
def naive_batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
world_size=1,
dp_size=1,
rank=0),
NaiveBatchedExperts(
max_num_tokens=max_num_tokens,
dp_size=1,
world_size=1,
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
)
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
def make_quantized_test_activations(
E: int,
m: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
a_q = a
a_scale = None
if quant_dtype is not None:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
a_q = torch.zeros_like(a, dtype=quant_dtype)
a_scale_l = [None] * E
for e in range(E):
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
a[e], None, quant_dtype, per_act_token_quant, block_shape)
a_scale = torch.stack(a_scale_l)
if not per_act_token_quant and block_shape is None:
a_scale = a_scale.view(E, 1, 1)
return a, a_q, a_scale
def moe_quantize_weights(
w: torch.Tensor,
w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype],
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
if block_shape is not None:
assert not per_token_quant
if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape)
else:
w, w_s = per_block_cast_to_fp8(w, block_shape)
else:
if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
else:
w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
return w, w_s
def make_test_weight(
e: int,
rows: int,
cols: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
if quant_dtype is not None:
w_l = [None] * e
w_s_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
if w_s.ndim == 2:
assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1)
if block_shape is not None:
block_n, block_k = block_shape
n_tiles = (rows + block_n - 1) // block_n
k_tiles = (cols + block_k - 1) // block_k
assert w_s.shape == (e, n_tiles, k_tiles)
else:
w = w_16
w_s = None
return w_16, w, w_s
def make_test_weights(
e: int,
n: int,
k: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
)