[BugFix] llama4 fa3 fix - RuntimeError: scheduler_metadata must have shape (metadata_size) (#16998)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-04-23 00:49:24 -04:00 committed by GitHub
parent b2f195c429
commit d0da99fb70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 48 additions and 28 deletions

View File

@ -105,6 +105,7 @@ class FlashAttentionMetadata:
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_scheduler_metadata: Optional[torch.Tensor]
local_attn_metadata: Optional[LocalAttentionMetadata] = None
@ -282,7 +283,9 @@ class FlashAttentionMetadataBuilder:
self.runner = runner
self.aot_schedule = (get_flash_attn_version() == 3)
self.num_heads = model_config.get_num_attention_heads(
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.page_size = self.runner.block_size
@ -304,6 +307,23 @@ class FlashAttentionMetadataBuilder:
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.aot_schedule:
return get_scheduler_metadata(
batch_size=batch_size,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
cache_seqlens=seqlens,
num_heads_q=self.num_heads_q,
num_heads_kv=self.num_heads_kv,
headdim=self.headdim,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
)
return None
# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
@ -315,36 +335,31 @@ class FlashAttentionMetadataBuilder:
block_table,
self.runner.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
max_query_len=local_max_query_len,
seqlens=local_seqused_k,
max_seq_len=local_max_seq_len,
causal=True)
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(
virt_q_cu_seqlens_np).to(self.runner.device,
non_blocking=True),
local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True),
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table,
local_max_query_len=seqlens_q_local_np.max(),
local_max_seq_len=virt_k_seqlens_np.max(),
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=local_scheduler_metadata,
)
use_cascade = common_prefix_len > 0
def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
causal):
if self.aot_schedule:
return get_scheduler_metadata(
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
cache_seqlens=seqlens,
num_heads_q=self.num_heads,
num_heads_kv=self.num_heads,
headdim=self.headdim,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
)
return None
if use_cascade:
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
@ -357,12 +372,14 @@ class FlashAttentionMetadataBuilder:
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
prefix_scheduler_metadata = schedule(
batch_size=num_reqs,
cu_query_lens=cu_prefix_query_lens,
max_query_len=num_actual_tokens,
seqlens=prefix_kv_lens,
max_seq_len=common_prefix_len,
causal=False)
scheduler_metadata = schedule(cu_query_lens=query_start_loc,
scheduler_metadata = schedule(batch_size=num_reqs,
cu_query_lens=query_start_loc,
max_query_len=max_query_len,
seqlens=suffix_kv_lens,
max_seq_len=max_seq_len -
@ -373,7 +390,8 @@ class FlashAttentionMetadataBuilder:
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None
scheduler_metadata = schedule(cu_query_lens=query_start_loc,
scheduler_metadata = schedule(batch_size=num_reqs,
cu_query_lens=query_start_loc,
max_query_len=max_query_len,
seqlens=seq_lens,
max_seq_len=max_seq_len,
@ -540,12 +558,14 @@ class FlashAttentionImpl(AttentionImpl):
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
scheduler_metadata = local_metadata.local_scheduler_metadata
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
@ -564,7 +584,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=attn_metadata.scheduler_metadata,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),