mirror of https://github.com/vllm-project/vllm.git
[BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) (#16998)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
b2f195c429
commit
d0da99fb70
|
@ -105,6 +105,7 @@ class FlashAttentionMetadata:
|
|||
local_block_table: torch.Tensor
|
||||
local_max_query_len: int
|
||||
local_max_seq_len: int
|
||||
local_scheduler_metadata: Optional[torch.Tensor]
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
|
||||
|
@ -282,7 +283,9 @@ class FlashAttentionMetadataBuilder:
|
|||
|
||||
self.runner = runner
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
self.num_heads = model_config.get_num_attention_heads(
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
runner.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.page_size = self.runner.block_size
|
||||
|
@ -304,6 +307,23 @@ class FlashAttentionMetadataBuilder:
|
|||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
|
||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
||||
max_seq_len, causal):
|
||||
if self.aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=batch_size,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_seq_len,
|
||||
cache_seqlens=seqlens,
|
||||
num_heads_q=self.num_heads_q,
|
||||
num_heads_kv=self.num_heads_kv,
|
||||
headdim=self.headdim,
|
||||
page_size=self.page_size,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
causal=causal,
|
||||
)
|
||||
return None
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
|
@ -315,36 +335,31 @@ class FlashAttentionMetadataBuilder:
|
|||
block_table,
|
||||
self.runner.block_size,
|
||||
)
|
||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max()
|
||||
local_max_seq_len = virt_k_seqlens_np.max()
|
||||
local_scheduler_metadata = schedule(
|
||||
batch_size=local_query_start_loc.shape[0] - 1,
|
||||
cu_query_lens=local_query_start_loc,
|
||||
max_query_len=local_max_query_len,
|
||||
seqlens=local_seqused_k,
|
||||
max_seq_len=local_max_seq_len,
|
||||
causal=True)
|
||||
|
||||
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=torch.from_numpy(
|
||||
virt_q_cu_seqlens_np).to(self.runner.device,
|
||||
non_blocking=True),
|
||||
local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True),
|
||||
local_query_start_loc=local_query_start_loc,
|
||||
local_seqused_k=local_seqused_k,
|
||||
local_block_table=virt_block_table,
|
||||
local_max_query_len=seqlens_q_local_np.max(),
|
||||
local_max_seq_len=virt_k_seqlens_np.max(),
|
||||
local_max_query_len=local_max_query_len,
|
||||
local_max_seq_len=local_max_seq_len,
|
||||
local_scheduler_metadata=local_scheduler_metadata,
|
||||
)
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
|
||||
causal):
|
||||
if self.aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=num_reqs,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_seq_len,
|
||||
cache_seqlens=seqlens,
|
||||
num_heads_q=self.num_heads,
|
||||
num_heads_kv=self.num_heads,
|
||||
headdim=self.headdim,
|
||||
page_size=self.page_size,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
causal=causal,
|
||||
)
|
||||
return None
|
||||
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
|
@ -357,12 +372,14 @@ class FlashAttentionMetadataBuilder:
|
|||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
prefix_scheduler_metadata = schedule(
|
||||
batch_size=num_reqs,
|
||||
cu_query_lens=cu_prefix_query_lens,
|
||||
max_query_len=num_actual_tokens,
|
||||
seqlens=prefix_kv_lens,
|
||||
max_seq_len=common_prefix_len,
|
||||
causal=False)
|
||||
scheduler_metadata = schedule(cu_query_lens=query_start_loc,
|
||||
scheduler_metadata = schedule(batch_size=num_reqs,
|
||||
cu_query_lens=query_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=suffix_kv_lens,
|
||||
max_seq_len=max_seq_len -
|
||||
|
@ -373,7 +390,8 @@ class FlashAttentionMetadataBuilder:
|
|||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
scheduler_metadata = schedule(cu_query_lens=query_start_loc,
|
||||
scheduler_metadata = schedule(batch_size=num_reqs,
|
||||
cu_query_lens=query_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
|
@ -540,12 +558,14 @@ class FlashAttentionImpl(AttentionImpl):
|
|||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
block_table = local_metadata.local_block_table
|
||||
scheduler_metadata = local_metadata.local_scheduler_metadata
|
||||
else:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
scheduler_metadata = attn_metadata.scheduler_metadata
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
|
@ -564,7 +584,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
|
|
Loading…
Reference in New Issue