# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops
from vllm.platforms import current_platform

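# The CUTLASS MLA decode kernel targets compute capability 10.0 (NVIDIA
# Blackwell / SM100), so skip the whole module on older GPUs.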
if not current_platform.has_device_capability(100):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10 or above.",
        allow_module_level=True)


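# Reference implementation of MLA decode: for each request, gather its pages
# from the paged KV cache into a dense (1, seq_len, head_dim) tensor and run
# torch's scaled_dot_product_attention against it.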
def ref_mla(
        out: Tensor,  # (bs, num_heads, v_head_dim)
        query: Tensor,  # (bs, num_heads, head_dim)
        kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
        scale: float,
        block_tables: Tensor,  # (bs, max_num_blocks)
        seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[
            block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1,
                     head_dim)[:, :seq_lens[i]]  # (1, seq_len, head_dim)
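        # MLA caches a single joint entry per token: the key is the full
        # head_dim vector, the value is its first v_head_dim channels.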
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
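        # Single decode step: one query token per head attends over the
        # gathered prefix; enable_gqa broadcasts the one shared KV head
        # across all query heads.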
        o = F.scaled_dot_product_attention(q,
                                           kv,
                                           v,
                                           scale=scale,
                                           enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [16, 64, 128])
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
                            varlen: bool, block_size: int):
    torch.set_default_dtype(dtype)
    torch.set_default_device('cuda')
    torch.manual_seed(42)

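    # MLA decode shapes (DeepSeek-style): each cached token entry is
    # d = 576 = 512 latent + 64 rotary dims, there are h_q = 128 query heads,
    # and the value / latent width is dv = 512.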
    d = 576
    h_q = 128
    dv = 512

    q_nope_dim = 128
    q_pe_dim = 64
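    # Softmax scale follows the per-head query width (128 no-PE + 64 rotary
    # dims), not the 576-dim cached entry.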
    scale = (q_nope_dim + q_pe_dim)**(-0.5)
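    # With varlen, draw per-request lengths from N(mean, mean / 2), clipped to
    # at least 2 tokens; otherwise every request has the same length.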
    if varlen:
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized
    # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small
    # blocks.
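    # For example, with block_size=16 a tile holds 128 // 16 = 8 blocks, so
    # block_num is rounded up to the next multiple of 8.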
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    # Amplify input values to ensure test coverage of edge cases where CUTLASS
    # kernel errors occur with split_k settings.
    q = torch.randn(bs, h_q, d) * 100
    block_table = torch.randint(0,
                                bs * block_num, (bs, block_num),
                                dtype=torch.int32)

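    # The cache holds bs * block_num physical blocks, so every randomly drawn
    # table index is valid; repeated indices are harmless because the
    # reference and the kernel read the same table.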
    kv_cache = torch.randn(block_table.numel(), block_size, d)

    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
    out_ans = torch.zeros_like(out_ref)
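    # Split the query into its latent ("nope") part (first dv = 512 dims) and
    # its rotary part (remaining 64 dims), which the kernel takes separately.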
    q_nope = q[:, :, :dv].clone()
    q_pe = q[:, :, dv:].clone()
    ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens,
                           block_table, scale)

    torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)