[Kernel] Use flash-attn for decoding (#3648)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Stephen Krider 2024-05-13 15:50:33 -07:00 committed by GitHub
parent ce532ff45c
commit 1356df53bd
6 changed files with 313 additions and 65 deletions

View File

@ -0,0 +1,209 @@
from typing import List, Optional, Tuple

import pytest
import torch

from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        q = query[start_idx:start_idx + query_len]
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_flash_attn_with_paged_kv(
    kv_lens: List[int],
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_blocks = 128
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    output = flash_attn_with_kvcache(
        q=query.unsqueeze(1),
        k_cache=key_cache,
        v_cache=value_cache,
        softmax_scale=scale,
        causal=True,
        block_table=block_tables,
        cache_seqlens=kv_lens_tensor,
    ).squeeze(1)

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=[1] * num_seqs,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
    )
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_varlen_with_paged_kv(
    seq_lens: List[Tuple[int, int]],
    num_heads: Tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_blocks = 128
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    window_size = ((sliding_window,
                    sliding_window) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    # Normalize the scale of the key and value caches to mitigate
    # numerical instability.
    key_cache /= head_size**0.5
    value_cache /= head_size**0.5
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    cu_kv_lens = torch.tensor([0] + kv_lens,
                              dtype=torch.int32).cumsum(dim=0,
                                                        dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    output = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        cu_seqlens_k=cu_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
    )
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"

View File

@ -12,7 +12,7 @@ MODELS = [
# "Deci/DeciLM-7b", # Broken # "Deci/DeciLM-7b", # Broken
# "tiiuae/falcon-7b", # Broken # "tiiuae/falcon-7b", # Broken
"EleutherAI/gpt-j-6b", "EleutherAI/gpt-j-6b",
"mosaicml/mpt-7b", # "mosaicml/mpt-7b", # Broken
# "Qwen/Qwen1.5-0.5B" # Broken, # "Qwen/Qwen1.5-0.5B" # Broken,
] ]

View File

@ -25,18 +25,18 @@ EXPECTED_STRS_MAP = {
         'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
         'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
         'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-        'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-        'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
-        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
+        'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
+        'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
+        'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
         'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-        'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**'
+        'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o'
     ],
     "meta-llama/Meta-Llama-3-8B-Instruct": [
         'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
         'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
         'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
         'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-        'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of',
+        'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
         'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
         'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
         'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'

View File

@ -1,20 +1,16 @@
"""Attention layer with Flash and PagedAttention. """Attention layer with FlashAttention."""
NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import torch import torch
from vllm_flash_attn import flash_attn_varlen_func from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm._C import cache_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionMetadataPerStage) AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) _SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@ -38,8 +34,9 @@ class FlashAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (2, num_blocks, block_size, num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
@ -47,19 +44,26 @@ class FlashAttentionBackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        src_value_cache = src_kv_cache[1]
+        dst_value_cache = dst_kv_cache[1]
+        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)


 @dataclass
-class FlashAttentionMetadata(AttentionMetadataPerStage,
-                             PagedAttentionMetadata):
+class FlashAttentionMetadata(AttentionMetadataPerStage):
     """Metadata for FlashAttentionBackend.

     NOTE: Any python object stored here is not updated when it is
@ -105,6 +109,14 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool

+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
+

 class FlashAttentionImpl(AttentionImpl):
     """
@ -156,11 +168,15 @@ class FlashAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
-            raise ValueError(
-                f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+        if sliding_window is not None:
+            # NOTE(woosuk): flash-attn's sliding window does not work with
+            # paged KV cache.
+            raise ValueError(
+                "Sliding window is not supported in FlashAttention.")
+        if head_size not in _SUPPORTED_HEAD_SIZES:
+            raise ValueError(
+                f"Head size {head_size} is not supported by FlashAttention. "
+                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")

     def forward(
         self,
@ -171,17 +187,20 @@ class FlashAttentionImpl(AttentionImpl):
         attn_metadata: AttentionMetadata[FlashAttentionMetadata],
         kv_scale: float = 1.0,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention and PagedAttention.
+        """Forward pass with FlashAttention.

         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
+        assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
+
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@ -189,16 +208,20 @@ class FlashAttentionImpl(AttentionImpl):
         value = value.view(-1, self.num_kv_heads, self.head_size)

         if kv_cache is not None:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size)
+            key_cache = kv_cache[0]
+            value_cache = kv_cache[1]

             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype, kv_scale)
+            cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping.flatten(),
+                self.kv_cache_dtype,
+            )

         num_prefill_tokens = attn_metadata.num_prefill_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
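
For readers unfamiliar with reshape_and_cache_flash, a rough pure-PyTorch sketch of what the kernel does is shown below. It assumes the usual vLLM convention that each slot_mapping entry is block_index * block_size + block_offset and ignores padding slots; the real op is a fused CUDA kernel, so this is only a conceptual reference:

import torch

def ref_reshape_and_cache_flash(
    key: torch.Tensor,          # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_kv_heads, head_size]
    key_cache: torch.Tensor,    # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,  # same shape as key_cache
    slot_mapping: torch.Tensor, # [num_tokens], flat slot index per token
) -> None:
    block_size = key_cache.shape[1]
    block_idx = slot_mapping // block_size
    block_offset = slot_mapping % block_size
    # Scatter each token's key/value into its (block, offset) cell.
    key_cache[block_idx, block_offset] = key
    value_cache[block_idx, block_offset] = value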
@ -218,7 +241,8 @@ class FlashAttentionImpl(AttentionImpl):
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            if (kv_cache is None or prefill_meta.block_tables is None
+                    or prefill_meta.block_tables.numel() == 0):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
@ -239,38 +263,32 @@ class FlashAttentionImpl(AttentionImpl):
                 output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
-                # TODO(Hai) this triton kernel has regression issue (broke) to
-                # deal with different data types between KV and FP8 KV cache,
-                # to be addressed separately.
-                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.subquery_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                    self.sliding_window[0],
-                )
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.subquery_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=prefill_meta.max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
+                )

         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = PagedAttention.forward_decode(
-                decode_query,
-                key_cache,
-                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_seq_len,
-                self.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_scale,
-            )
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
+                key_cache,
+                value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+            ).squeeze(1)

         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)

View File

@ -93,6 +93,20 @@ def _which_attn_to_use(
"torch.float16 or torch.bfloat16.") "torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS return _Backend.XFORMERS
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
return _Backend.XFORMERS
if block_size % 16 != 0:
logger.info("Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
return _Backend.XFORMERS
if sliding_window is not None:
logger.info(
"Cannot use FlashAttention-2 backend due to sliding window.")
return _Backend.XFORMERS
try: try:
import vllm_flash_attn # noqa: F401 import vllm_flash_attn # noqa: F401
except ImportError: except ImportError:

View File

@ -266,20 +266,27 @@ class ModelRunner:
                 # Prefix is not supported with sliding_window
                 context_len = len(computed_block_nums) * self.block_size
                 prompt_tokens = prompt_tokens[context_len:]
-                prefix_block_tables.append(computed_block_nums)
+                if self.attn_backend.get_name() == "flash-attn":
+                    # NOTE(woosuk): For flash-attn, the block table should
+                    # include the entries for the incoming prefill tokens.
+                    # TODO(woosuk): This is a temporary fix. We should
+                    # provide a unified interface for different backends.
+                    block_table = seq_group_metadata.block_tables[seq_id]
+                else:
+                    block_table = computed_block_nums
             elif self.scheduler_config.chunked_prefill_enabled:
                 if seq_group_metadata.block_tables is not None:
                     # Prefill has chunked before.
                     block_table = seq_group_metadata.block_tables[seq_id]
-                    prefix_block_tables.append(block_table)
                 else:
                     # The first prefill.
-                    prefix_block_tables.append([])
+                    block_table = []
             else:
-                prefix_block_tables.append([])
+                block_table = []
                 # Right now, prefill start is always 0. However, this
                 # assumption can be changed once chunked prefill is introduced.
                 assert context_len == 0
+            prefix_block_tables.append(block_table)

             # actual prompt lens
             context_lens.append(context_len)
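
To illustrate the block-table sizing that the flash-attn branch above accounts for, here is a tiny sketch with made-up numbers (block size, prefix length, and prompt length are assumptions, not values from the diff):

block_size = 16
computed_block_nums = [0, 1, 2]   # 3 blocks of cached prefix = 48 tokens
num_new_prompt_tokens = 20        # incoming prefill tokens

context_len = len(computed_block_nums) * block_size                 # 48
total_kv_len = context_len + num_new_prompt_tokens                  # 68
num_blocks_needed = (total_kv_len + block_size - 1) // block_size   # 5

# flash_attn_varlen_func attends over the full KV range (cached prefix plus
# the new prompt tokens), so its block table must cover all 5 blocks; the
# prefix-only table (computed_block_nums) would cover just the first 3.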