# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE layers.
|
|
|
|
Run `pytest tests/kernels/test_pplx_moe.py`.
|
|
"""
|
|
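# The pplx tests below require the optional `pplx_kernels` package (see the
# `requires_pplx` marker) and are launched across multiple worker processes
# via `parallel_launch`, so they assume a multi-GPU CUDA environment.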
from typing import Optional

import pytest
import torch

try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                      nvshmem_finalize, nvshmem_get_unique_id,
                                      nvshmem_init)
    has_pplx = True
except ImportError:
    has_pplx = False

from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import round_up

from .parallel_utils import ProcessGroupInfo, parallel_launch

requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)

PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
                       (222, 2048, 1024)]

PPLX_MOE_COMBOS = [
    (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (222, 1024, 2048),
]

NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [1, 2, 6]

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


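# Reference implementation of the "prepare" step: scatter each token's hidden
# state into a per-expert batch of shape (num_experts, max_num_tokens,
# hidden_dim) and return the number of tokens routed to each expert.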
def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]

    num_tokens, hidden_dim = a.shape
    topk = topk_ids.shape[1]

    tokens_per_expert = torch.bincount(topk_ids.view(-1),
                                       minlength=num_experts)

    assert tokens_per_expert.numel() == num_experts

    if max_num_tokens is None:
        max_num_tokens = int(tokens_per_expert.max().item())

    b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
                      dtype=a.dtype,
                      device=a.device)

    token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)

    for token in range(num_tokens):
        for j in range(topk):
            expert_id = topk_ids[token, j]
            idx = token_counts[expert_id]
            b_a[expert_id, idx:idx + 1, :] = a[token, :]
            token_counts[expert_id] = token_counts[expert_id] + 1

    return b_a, tokens_per_expert


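# Reference implementation of the "finalize" step: gather the per-expert
# outputs back into token order, applying the top-k routing weights. This is
# the inverse of torch_prepare.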
def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
                   topk_ids: torch.Tensor) -> torch.Tensor:
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    K = b_out.shape[-1]
    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
    expert_counts = torch.zeros(num_experts,
                                dtype=torch.int,
                                device=b_out.device)
    for token in range(num_tokens):
        expert_ids = topk_ids[token]
        for i in range(expert_ids.numel()):
            expert_id = expert_ids[i]
            idx = expert_counts[expert_id]
            out[token, :] = out[token, :] + b_out[expert_id, idx:idx +
                                                  1, :] * topk_weight[token, i]
            expert_counts[expert_id] = expert_counts[expert_id] + 1

    return out


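# Reference MoE in the batched (expert-major) layout: prepare, run each
# expert's SiLU-and-mul MLP on its slice of tokens, then finalize.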
def torch_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    num_experts = w1.shape[0]
    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3
    num_tokens, topk = topk_ids.shape
    _, max_num_tokens, K = b_a.shape
    assert num_experts == b_a.shape[0] and w2.shape[1] == K
    out = torch.zeros((num_experts, max_num_tokens, K),
                      dtype=b_a.dtype,
                      device=b_a.device)
    tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
                      dtype=b_a.dtype,
                      device=b_a.device)
    for expert in range(num_experts):
        num = tokens_per_expert[expert]
        if num > 0:
            torch.ops._C.silu_and_mul(
                tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
            out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

    return torch_finalize(out, topk_weight, topk_ids)


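# Sanity check (single GPU): the batched reference implementation above and
# naive_batched_moe must agree with the plain torch_experts reference.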
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
|
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
|
@pytest.mark.parametrize("k", [128, 512, 1024])
|
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
|
@pytest.mark.parametrize("topk", TOP_KS)
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
|
def test_fused_moe_batched_experts(
|
|
m: int,
|
|
n: int,
|
|
k: int,
|
|
e: int,
|
|
topk: int,
|
|
dtype: torch.dtype,
|
|
):
|
|
current_platform.seed_everything(7)
|
|
|
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
|
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
|
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
|
|
|
with set_current_vllm_config(vllm_config):
|
|
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
|
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
|
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
|
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
|
|
|
torch.testing.assert_close(baseline_output,
|
|
torch_output,
|
|
atol=2e-2,
|
|
rtol=0)
|
|
torch.testing.assert_close(baseline_output,
|
|
batched_output,
|
|
atol=2e-2,
|
|
rtol=0)
|
|
|
|
|
|
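# Helpers for splitting tensors across ranks. rank_chunk gives the number of
# rows rank `r` owns when `num` rows are divided across `w` ranks, with the
# remainder spread over the lowest ranks; chunk_by_rank slices out that rank's
# contiguous chunk along dim 0. For example (illustrative values only):
# rank_chunk(10, 0, 4) == 3 and rank_chunk(10, 3, 4) == 2.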
def rank_chunk(num: int, r: int, w: int) -> int:
    rem = num % w
    return (num // w) + (1 if r < rem else 0)


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]


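# Exercises PplxPrepareAndFinalize in isolation on one rank: build the pplx
# AllToAll (internode via nvshmem, or intranode via a named process group),
# dispatch this rank's chunk of tokens, scale the dispatched activations in
# place of running real experts, and combine the results back.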
def pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    group_name: Optional[str],
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)

    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
        hidden_dim_scale_bytes=0,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens,
        world_size,
        rank,
        dp_size,
    )

    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
        a_chunk,
        None,
        None,
        chunk_topk_weight,
        chunk_topk_ids,
        num_experts,
        None,
        False,
        FusedMoEQuantConfig(),
    )

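    # Stand-in for the expert computation: scale the dispatched activations by
    # a known constant so the finalize step has something to combine. The
    # reference output in _pplx_prepare_finalize applies the same 1.5 factor.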
    b_a = b_a * 1.5

    out = torch.full(
        (max_num_tokens, hidden_dim),
        torch.nan,
        dtype=a.dtype,
        device=device,
    )

    prepare_finalize.finalize(
        out,
        b_a,
        chunk_topk_weight,
        chunk_topk_ids,
        False,
    )

    torch.cuda.synchronize()

    ata.destroy()

    num_tokens = a_chunk.shape[0]

    return out[:num_tokens]


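# Per-rank worker body for test_pplx_prepare_finalize: set up nvshmem (or a
# gloo group for the intranode path), compute the expected output analytically
# with the same 1.5 scaling, and compare against this rank's chunk of the pplx
# round-trip output.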
def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    use_internode: bool,
):
    if use_internode:
        uid = nvshmem_get_unique_id(
        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
        torch.distributed.broadcast(uid, src=0)
        nvshmem_init(uid, pgi.rank, pgi.world_size)
        group_name = None
    else:
        group_ranks = list(range(pgi.world_size))
        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
        group_name = cpu_group.group_name

    device = pgi.device

    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
    k = a.shape[1]

    a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)

    torch_output = (a_rep.view(-1, topk, k) * 1.5 *
                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
                        a.dtype)

    pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
                                        num_experts, group_name)

    torch_output = chunk_by_rank(torch_output, pgi.rank,
                                 pgi.world_size).to(pplx_output.device)

    torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)

    if use_internode:
        nvshmem_finalize()


# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
# TODO (bnell): add fp8 tests
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_prepare_finalize(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    score = torch.randn((m, e), device=device, dtype=dtype)

    parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
                    topk, e, use_internode)


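# Runs the full pplx-backed MoE on one rank: PplxPrepareAndFinalize plus
# BatchedTritonExperts composed into a FusedMoEModularKernel, with optional
# torch.compile and CUDA graph coverage.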
def pplx_moe(
    group_name: Optional[str],
    rank: int,
    world_size: int,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    qtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
) -> torch.Tensor:
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)

    device = torch.device("cuda", rank)
    hidden_dim = a.shape[1]
    num_experts = w1.shape[0]
    topk = topk_ids.shape[1]
    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        a.dtype,
        qtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim_bytes,
        hidden_dim_scale_bytes=scale_bytes,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens,
        world_size,
        rank,
        dp_size,
    )

    experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                   world_size=world_size,
                                   dp_size=dp_size,
                                   use_fp8_w8a8=qtype == torch.float8_e4m3fn,
                                   block_shape=block_shape)

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    # Chunking weights like this only works for batched format
    w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

    if w1_scale is not None:
        w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
        w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
    else:
        w1_scale_chunk = None
        w2_scale_chunk = None

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    if use_compile:
        _fused_experts = torch.compile(fused_experts,
                                       backend='inductor',
                                       fullgraph=True)
        torch._dynamo.mark_dynamic(a_chunk, 0)
        torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
        torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
    else:
        _fused_experts = fused_experts

    out = _fused_experts(a_chunk,
                         w1_chunk,
                         w2_chunk,
                         chunk_topk_weight,
                         chunk_topk_ids,
                         w1_scale=w1_scale_chunk,
                         w2_scale=w2_scale_chunk,
                         global_num_experts=num_experts)

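    # Optionally re-run the kernel under CUDA graph capture and replay it once,
    # to check that the pplx path is capture-safe.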
    if use_cudagraphs:
        out.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            out = _fused_experts(a_chunk,
                                 w1_chunk,
                                 w2_chunk,
                                 chunk_topk_weight,
                                 chunk_topk_ids,
                                 w1_scale=w1_scale_chunk,
                                 w2_scale=w2_scale_chunk,
                                 global_num_experts=num_experts)

        torch.cuda.synchronize()
        graph.replay()

    torch.cuda.synchronize()

    ata.destroy()

    return out


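# Single-rank reference path built from the same modular-kernel pieces
# (BatchedPrepareAndFinalize + NaiveBatchedExperts). Currently unused while
# the batched comparison in _pplx_moe is disabled.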
def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
    assert torch.cuda.current_device() == pgi.local_rank

    num_experts = w1.shape[0]
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size
    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

    prepare_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        world_size=world_size,
        dp_size=dp_size,
        rank=rank,
    )

    experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
                                  world_size=1,
                                  dp_size=1)

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    out = fused_experts(
        a_chunk,
        # Chunking weights like this only works for batched format
        chunk_by_rank(w1, rank, world_size).to(device),
        chunk_by_rank(w2, rank, world_size).to(device),
        chunk_topk_weight,
        chunk_topk_ids,
        global_num_experts=num_experts)

    return out


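# Per-rank worker body for test_pplx_moe: compute the full torch_experts
# reference on this rank, run the pplx-backed kernel on this rank's chunk, and
# compare it against the matching chunk of the reference output.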
def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    w1_s: Optional[torch.Tensor] = None,
    w2_s: Optional[torch.Tensor] = None,
    qtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    use_internode: bool = False,
):
    if use_internode:
        uid = nvshmem_get_unique_id(
        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
        torch.distributed.broadcast(uid, src=0)
        nvshmem_init(uid, pgi.rank, pgi.world_size)
        group_name = None
    else:
        group_ranks = list(range(pgi.world_size))
        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
        group_name = cpu_group.group_name

    m, k = a.shape
    e, _, n = w2.shape

    moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

    device = torch.device("cuda", pgi.rank)
    a = a.to(device)
    w1 = w1.to(device)
    w2 = w2.to(device)
    w1_s = w1_s.to(device) if w1_s is not None else None
    w2_s = w2_s.to(device) if w2_s is not None else None

    with set_current_vllm_config(vllm_config), override_config(moe_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        torch_output = torch_experts(a,
                                     w1,
                                     w2,
                                     topk_weight,
                                     topk_ids,
                                     w1_scale=w1_s,
                                     w2_scale=w2_s,
                                     quant_dtype=qtype,
                                     per_act_token_quant=per_act_token_quant,
                                     block_shape=block_shape)
        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
                               a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
                               qtype, per_act_token_quant, block_shape)
        # TODO (bnell): fix + re-enable
        #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
        #                              topk_ids)

    torch_output = chunk_by_rank(torch_output, pgi.rank,
                                 pgi.world_size).to(pplx_output.device)

    torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
    #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)

    if use_internode:
        nvshmem_finalize()


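# End-to-end test: launches world_size workers, each running the pplx-backed
# modular kernel and checking its shard of the output against torch_experts.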
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
|
|
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
|
@pytest.mark.parametrize("topk", TOP_KS)
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
|
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
|
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
|
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
|
@pytest.mark.parametrize("use_internode", [False])
|
|
@requires_pplx
|
|
def test_pplx_moe(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    use_internode: bool,
):
    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size

    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
                                                 n,
                                                 k,
                                                 quant_dtype=quant_dtype,
                                                 block_shape=block_shape)

    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
                    w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
                    use_internode)