mirror of https://github.com/vllm-project/vllm.git
[Core] Support full cuda graph in v1 (#16072)
Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
parent 3d13ca0e24
commit 7ea2adb802
@@ -137,3 +137,9 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You
`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`

Then vLLM will capture cudagraphs only for the specified sizes, which is useful when you need fine-grained control over cudagraph capture.
+
+### Full Cudagraph capture
+
+It is possible to include attention as part of the cudagraph when using an attention backend that is cudagraph compatible. This can improve performance in some cases, such as decode speed for smaller models. Enable it with `--compilation-config "{'full_cuda_graph': True}"`.
+
+Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.
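For illustration, the same flag can also be set through the offline `LLM` API. A minimal sketch, following the test added in this commit (the environment variables and `CompilationConfig(full_cuda_graph=True)` usage are taken from that test; the model name is simply the one from the docs example):

```python
# Hedged sketch: offline equivalent of the CLI flag above.
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

os.environ["VLLM_USE_V1"] = "1"              # full cuda graph is a V1 feature
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"  # only FA3 is compatible

llm = LLM(model="meta-llama/Llama-3.2-1B",
          compilation_config=CompilationConfig(full_cuda_graph=True))
outputs = llm.generate(["Hi my name is"], SamplingParams(max_tokens=10))
print(outputs[0].outputs[0].text)
```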
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os

import pytest

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

MODEL = "Qwen/Qwen2-1.5B-Instruct"


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Temporarily set environment variables and restore them afterward.
    We have to do this vs monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    original_env = {k: os.environ.get(k) for k in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v


@pytest.fixture(scope="module")
def full_cudagraph_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.2,
                   compilation_config=CompilationConfig(full_cuda_graph=True))


@pytest.fixture(scope="module")
def piecewise_llm():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.5,
                   compilation_config=CompilationConfig())


def generate_text(llm: LLM, batch_size: int, max_tokens: int):
    prompts = ["Hi my name is"] * batch_size
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=max_tokens,
                                     top_p=0.95)

    return llm.generate(prompts, sampling_params)


@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                        (16, 10), (25, 10),
                                                        (32, 10), (45, 10),
                                                        (64, 10), (8, 5),
                                                        (8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
                        piecewise_llm):
    """
    Load the full cudagraph model and the piecewise model once each (keeping
    both loaded at the same time) and reuse them across the test cases.

    Test various batch sizes and max_tokens to ensure that the full cudagraph
    compilation works for padded cases too.
    """
    piecewise_responses = generate_text(piecewise_llm,
                                        batch_size=batch_size,
                                        max_tokens=max_tokens)
    full_cudagraph_responses = generate_text(full_cudagraph_llm,
                                             batch_size=batch_size,
                                             max_tokens=max_tokens)

    # Check that all responses are the same
    for i in range(len(piecewise_responses)):
        assert piecewise_responses[i].outputs[
            0].text == full_cudagraph_responses[i].outputs[0].text


def test_full_cudagraph_with_invalid_backend():
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "2"  # FA2 not supported with full_cuda_graph
    }), pytest.raises(RuntimeError):
        LLM(model=MODEL,
            compilation_config=CompilationConfig(full_cuda_graph=True))
@@ -3605,6 +3605,10 @@ class CompilationConfig(BaseModel):
        are always used, it can set this to False. Otherwise, it should
        set this to True, and the compiler will copy the input to an
        internally managed buffer. Default is False.
+    - full_cuda_graph: whether to use a full cuda graph for the entire forward
+      pass rather than splitting certain operations such as attention into
+      subgraphs. Thus this flag cannot be used together with splitting_ops.
+      This may provide performance benefits for smaller models.
    - Inductor compilation:
        - use_inductor: whether to use inductor compilation.
            - False: inductor compilation is not used. graph runs in eager.
@@ -3649,6 +3653,7 @@ class CompilationConfig(BaseModel):
    cudagraph_num_of_warmups: int = 0
    cudagraph_capture_sizes: Optional[list[int]] = None
    cudagraph_copy_inputs: bool = False
+    full_cuda_graph: bool = False

    class PassConfig(BaseModel):
        """
@@ -3871,10 +3876,14 @@ class CompilationConfig(BaseModel):
                self.max_capture_size] = self.max_capture_size

    def set_splitting_ops_for_v1(self):
        # If default, override splitting ops for piecewise cudagraph on V1.
        # NOTE: this function needs to be called
+        if self.splitting_ops and self.full_cuda_graph:
+            raise ValueError("full_cuda_graph cannot be used together with "
+                             "splitting_ops, as Full CUDA graph will override "
+                             f"the splitting_ops: {self.splitting_ops}")

        if not self.splitting_ops:
-            self.splitting_ops = [
+            self.splitting_ops = [] if self.full_cuda_graph else [
                "vllm.unified_attention",
                "vllm.unified_attention_with_output",
            ]
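For illustration, a small sketch of what the new guard rejects. It assumes `set_splitting_ops_for_v1()` can be invoked directly on a freshly constructed `CompilationConfig` (the NOTE above indicates the V1 path calls it during setup):

```python
from vllm.config import CompilationConfig

# Both options set at once: the guard above rejects this combination.
cfg = CompilationConfig(
    full_cuda_graph=True,
    splitting_ops=["vllm.unified_attention"],
)
try:
    cfg.set_splitting_ops_for_v1()
except ValueError as e:
    print(e)  # full_cuda_graph cannot be used together with splitting_ops...
```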
@@ -4151,6 +4160,12 @@ class VllmConfig:
                "Disabling `torch.compile`.")
            self.compilation_config.level = CompilationLevel.NO_COMPILATION

+        if self.compilation_config.full_cuda_graph and \
+                not self.model_config.disable_cascade_attn:
+            logger.warning_once(
+                "full_cuda_graph is not supported with "
+                "cascade attention. Disabling cascade attention.")
+            self.model_config.disable_cascade_attn = True

        if self.model_config and self.model_config.use_mla and \
                not (current_platform.is_cuda() or current_platform.is_rocm()):
@@ -291,6 +291,7 @@ class FlashAttentionMetadataBuilder:

    def __init__(self, runner: "GPUModelRunner"):
        model_config = runner.model_config
+        compilation_config = runner.vllm_config.compilation_config

        self.runner = runner
        self.num_heads_q = model_config.get_num_attention_heads(
@@ -300,7 +301,14 @@ class FlashAttentionMetadataBuilder:
        self.headdim = model_config.get_head_size()
        self.page_size = self.runner.block_size

-        self.aot_schedule = (get_flash_attn_version() == 3)
+        if get_flash_attn_version() == 3:
+            self.aot_schedule = not compilation_config.full_cuda_graph
+            if not self.aot_schedule:
+                logger.warning(
+                    "AOT Schedule is disabled when using full_cuda_graph")
+        else:
+            self.aot_schedule = False

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None
@@ -317,8 +325,7 @@ class FlashAttentionMetadataBuilder:
        seq_lens = common_attn_metadata.seq_lens
        block_table = (
            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True).long()
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
@@ -12,6 +12,7 @@ import torch.nn as nn

from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
+from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -139,6 +140,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

+        if self.vllm_config.compilation_config.full_cuda_graph:
+            attn_backend_name = self.attn_backend.__name__
+            flash_attn_version = get_flash_attn_version()
+            if attn_backend_name != "FlashAttentionBackend" or \
+                    flash_attn_version != 3:
+                raise ValueError(
+                    f"full_cuda_graph is only supported with "
+                    f"FA3. Current attention backend is {attn_backend_name}, "
+                    f"FlashAttention version is {flash_attn_version}.")
+
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self))
        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
@@ -219,6 +230,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None
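For context, these become persistent device tensors because CUDA graph replay reuses the exact memory addresses recorded at capture time, so per-step inputs must be copied into fixed buffers rather than freshly allocated. A minimal, generic PyTorch sketch of that constraint (not vLLM code; it uses only the public `torch.cuda.CUDAGraph` API):

```python
import torch

# Persistent buffers: their addresses are baked into the captured graph.
static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture (standard torch.cuda.graph pattern).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out.copy_(static_in * 2)

# New inputs are copied *into* the persistent buffer (much like the runner's
# copy_() calls in the hunk further below), then the graph is replayed.
static_in.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
graph.replay()
print(static_out)  # tensor([ 0.,  2.,  4., ..., 14.])
```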
@@ -271,7 +292,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                            pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                           dtype=torch.int32,
+                                           dtype=torch.int64,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
@@ -589,10 +610,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            self.positions_cpu[:total_num_scheduled_tokens],
            non_blocking=True)

-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+
+        query_start_loc = self.query_start_loc[:num_reqs + 1]
+        seq_lens = self.seq_lens[:num_reqs]

        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc, seq_lens=seq_lens)
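For illustration, the fills above keep the padded tail of the persistent buffers in a well-defined state, so one captured graph shape can serve smaller batches. A tiny self-contained sketch of the slot-mapping padding, with made-up values:

```python
import torch

# Hypothetical numbers: the graph was captured for 8 tokens, but this step
# only schedules 5. Real tokens keep their KV-cache slots; the padded tail is
# filled with -1 (per the comment above, reshape_and_cache relies on that
# sentinel for unused entries).
captured_num_tokens = 8
num_actual_tokens = 5
slot_mapping = torch.full((captured_num_tokens,), -1, dtype=torch.int64)
slot_mapping[:num_actual_tokens] = torch.tensor([3, 4, 5, 9, 10])
print(slot_mapping)  # tensor([ 3,  4,  5,  9, 10, -1, -1, -1])
```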
@@ -1478,6 +1511,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
    def _dummy_run(
        self,
        num_tokens: int,
+        skip_attn: bool = True,
    ) -> torch.Tensor:

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
@@ -1494,6 +1528,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)

+        if skip_attn:
+            attn_metadata = None
+        else:
+            query_start_loc = self.query_start_loc[:num_reqs + 1]
+            seq_lens = self.seq_lens[:num_reqs]
+
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+            attn_metadata = self.attn_metadata_builder.build(
+                num_reqs=num_tokens,
+                num_actual_tokens=num_tokens,
+                max_query_len=num_tokens,
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            model = self.model
@@ -1522,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    for k, v in self.intermediate_tensors.items()
                })

-            with set_forward_context(None,
+            with set_forward_context(attn_metadata,
                                     self.vllm_config,
                                     num_tokens=num_tokens):
                outputs = model(
@@ -1708,11 +1759,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
+            skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
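For context, the `reversed(...)` iteration captures the largest batch size first so the smaller graphs can reuse its memory pool. A generic PyTorch sketch of that idea (not vLLM code; it assumes only the public CUDA graph pool API and uses a trivial stand-in for the forward pass):

```python
import torch

pool = torch.cuda.graph_pool_handle()   # one shared pool for all graphs
x = torch.zeros(64, device="cuda")      # persistent input buffer
graphs = {}

for n in sorted([8, 32, 64], reverse=True):   # largest shape first
    # one warmup run per shape (mirrors cudagraph_num_of_warmups)
    y = x[:n] * 2
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):   # capture into the shared pool
        y = x[:n] * 2
    graphs[n] = g

graphs[8].replay()   # replay the graph captured for the smallest shape
```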