[Core] Support full cuda graph in v1 (#16072)

Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
Chanh Nguyen 2025-05-07 22:30:15 -07:00 committed by GitHub
parent 3d13ca0e24
commit 7ea2adb802
5 changed files with 190 additions and 13 deletions


@ -137,3 +137,9 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You
`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`
Then it will capture cudagraphs only for the specified sizes. This can be useful for fine-grained control over cudagraph capture.
### Full Cudagraph capture
It is possible to include attention as part of the cudagraph when using an attention backend that is cudagraph-compatible. This can improve performance in some cases, such as decode speed for smaller models. Enable this with `--compilation-config "{'full_cuda_graph': True}"`
Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.
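
The same option can also be set from the Python API. Below is a minimal sketch mirroring the test added in this commit; the model, prompt, and environment variables are taken from that test, and a GPU with FlashAttention 3 support is assumed.

```python
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Full cudagraph capture currently requires the V1 engine and FlashAttention 3.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          compilation_config=CompilationConfig(full_cuda_graph=True))
outputs = llm.generate(["Hi my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```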


@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os
import pytest
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
MODEL = "Qwen/Qwen2-1.5B-Instruct"
@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
    We use this instead of monkeypatch because monkeypatch does not work
    with "module"-scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
@pytest.fixture(scope="module")
def full_cudagraph_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.2,
compilation_config=CompilationConfig(full_cuda_graph=True))
@pytest.fixture(scope="module")
def piecewise_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.5,
compilation_config=CompilationConfig())
def generate_text(llm: LLM, batch_size: int, max_tokens: int):
prompts = ["Hi my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
return llm.generate(prompts, sampling_params)
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
(16, 10), (25, 10),
(32, 10), (45, 10),
(64, 10), (8, 5),
(8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
piecewise_llm):
"""
    Load the full cudagraph model and the piecewise model once each, and
    reuse them across the test cases.
    Test various batch sizes and max_tokens to ensure that full cudagraph
    compilation also works for padded cases.
"""
piecewise_responses = generate_text(piecewise_llm,
batch_size=batch_size,
max_tokens=max_tokens)
full_cudagraph_responses = generate_text(full_cudagraph_llm,
batch_size=batch_size,
max_tokens=max_tokens)
# Check that all responses are the same
for i in range(len(piecewise_responses)):
assert piecewise_responses[i].outputs[
0].text == full_cudagraph_responses[i].outputs[0].text
def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
}), pytest.raises(RuntimeError):
LLM(model=MODEL,
compilation_config=CompilationConfig(full_cuda_graph=True))


@ -3605,6 +3605,10 @@ class CompilationConfig(BaseModel):
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- full_cuda_graph: whether to use a full cuda graph for the entire forward
pass rather than splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
@ -3649,6 +3653,7 @@ class CompilationConfig(BaseModel):
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[list[int]] = None
cudagraph_copy_inputs: bool = False
full_cuda_graph: bool = False
class PassConfig(BaseModel):
"""
@ -3871,10 +3876,14 @@ class CompilationConfig(BaseModel):
self.max_capture_size] = self.max_capture_size
def set_splitting_ops_for_v1(self):
# If default, override splitting ops for piecewise cudagraph on V1.
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")
if not self.splitting_ops:
-            self.splitting_ops = [
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
@ -4151,6 +4160,12 @@ class VllmConfig:
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION
if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.warning_once(
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
if self.model_config and self.model_config.use_mla and \
not (current_platform.is_cuda() or current_platform.is_rocm()):
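
The validation added in `set_splitting_ops_for_v1` above makes `splitting_ops` and `full_cuda_graph` mutually exclusive, and leaves the splitting-op list empty when a full graph is requested. A minimal sketch of that behavior, calling the method directly purely for illustration (in normal operation it is invoked while the engine finalizes its `VllmConfig`):

```python
from vllm.config import CompilationConfig

# Explicit splitting_ops together with full_cuda_graph is rejected.
cfg = CompilationConfig(full_cuda_graph=True,
                        splitting_ops=["vllm.unified_attention"])
try:
    cfg.set_splitting_ops_for_v1()
except ValueError as err:
    print(err)  # full_cuda_graph cannot be used together with splitting_ops...

# Without explicit splitting_ops, full_cuda_graph leaves the list empty,
# so attention stays inside the single captured graph.
cfg = CompilationConfig(full_cuda_graph=True)
cfg.set_splitting_ops_for_v1()
print(cfg.splitting_ops)  # []
```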


@ -291,6 +291,7 @@ class FlashAttentionMetadataBuilder:
def __init__(self, runner: "GPUModelRunner"):
model_config = runner.model_config
compilation_config = runner.vllm_config.compilation_config
self.runner = runner
self.num_heads_q = model_config.get_num_attention_heads(
@ -300,7 +301,14 @@ class FlashAttentionMetadataBuilder:
self.headdim = model_config.get_head_size()
self.page_size = self.runner.block_size
-        self.aot_schedule = (get_flash_attn_version() == 3)
if get_flash_attn_version() == 3:
self.aot_schedule = not compilation_config.full_cuda_graph
if not self.aot_schedule:
logger.warning(
"AOT Schedule is disabled when using full_cuda_graph")
else:
self.aot_schedule = False
# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
@ -317,8 +325,7 @@ class FlashAttentionMetadataBuilder:
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True).long()
slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)


@ -12,6 +12,7 @@ import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@ -139,6 +140,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")
if self.vllm_config.compilation_config.full_cuda_graph:
attn_backend_name = self.attn_backend.__name__
flash_attn_version = get_flash_attn_version()
if attn_backend_name != "FlashAttentionBackend" or \
flash_attn_version != 3:
raise ValueError(
f"full_cuda_graph is only supported with "
f"FA3. Current attention backend is {attn_backend_name}, "
f"FlashAttention version is {flash_attn_version}.")
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
@ -219,6 +230,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
@ -271,7 +292,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
@ -589,10 +610,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
self.slot_mapping[:total_num_scheduled_tokens].copy_(
self.slot_mapping_cpu[:total_num_scheduled_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
@ -1478,6 +1511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
skip_attn: bool = True,
) -> torch.Tensor:
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
@ -1494,6 +1528,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
if skip_attn:
attn_metadata = None
else:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
@ -1522,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
-            with set_forward_context(None,
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
outputs = model(
@ -1708,11 +1759,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
for num_tokens in reversed(self.cudagraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
self._dummy_run(num_tokens, skip_attn=skip_attn)
self._dummy_run(num_tokens, skip_attn=skip_attn)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
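
The new persistent `query_start_loc`, `seq_lens`, and `slot_mapping` device tensors exist because a captured CUDA graph replays against fixed tensor addresses: per-step inputs must be copied into pre-allocated buffers (with unused entries padded, e.g. `slot_mapping` filled with -1) rather than freshly allocated. A standalone PyTorch sketch of that pattern, illustrative only and not vLLM code:

```python
import torch

# Persistent input buffer: the captured graph records this tensor's address.
buf = torch.zeros(8, dtype=torch.int64, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    out = buf * 2
torch.cuda.current_stream().wait_stream(s)

# Capture the computation once.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = buf * 2

# Per-step inputs are copied in place; replay reuses the captured addresses.
buf.copy_(torch.arange(8, dtype=torch.int64, device="cuda"))
g.replay()
print(out)  # tensor([ 0,  2,  4,  6,  8, 10, 12, 14], device='cuda:0')
```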