Compare commits


1 commit: main ... v0.9.2

Author: simon-mo
SHA1: a5dd03c1eb
Message: Revert "[V0 deprecation] Remove V0 CPU/XPU/TPU backends (#20412)"
         This reverts commit e202dd2736.
Date: 2025-07-06 14:02:36 -07:00
20 changed files with 5034 additions and 46 deletions

View File

@ -66,10 +66,10 @@ function cpu_tests() {
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
# Run AWQ test
# docker exec cpu-test-"$NUMA_NODE" bash -c "
# set -e
# VLLM_USE_V1=0 pytest -s -v \
# tests/quantization/test_ipex_quant.py"
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
VLLM_USE_V1=0 pytest -s -v \
tests/quantization/test_ipex_quant.py"
# Run chunked-prefill and prefix-cache test
docker exec cpu-test-"$NUMA_NODE" bash -c "

View File

@ -26,5 +26,7 @@ docker run \
--name "${container_name}" \
"${image_name}" \
sh -c '
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
'

View File

@ -8,7 +8,7 @@ image:
# -- Image tag
tag: "latest"
# -- Container launch command
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
# -- Container port
containerPort: 8000

View File

@ -36,8 +36,7 @@ DEVICE_REGULAR_ATTN_BACKENDS = {
DEVICE_MLA_BLOCK_SIZES = {
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
"hip": [16, 1], # HIP requires special handling for block_size=1
# "cpu": [16] # CPU uses fixed block size from test cases
"cpu": [] # FIXME(woosuk): Temporarily disable CPU tests
"cpu": [16] # CPU uses fixed block size from test cases
}
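For reference, the entry restored above feeds the device/block-size combinations that the MLA selector tests iterate over; a minimal sketch (not taken from the diff) of that expansion, using a plain itertools-based flattening:

from itertools import chain

DEVICE_MLA_BLOCK_SIZES = {
    "cuda": [16, 64],
    "hip": [16, 1],
    "cpu": [16],  # restored by this revert
}

# Expands to [("cuda", 16), ("cuda", 64), ("hip", 16), ("hip", 1), ("cpu", 16)]
combos = list(
    chain.from_iterable(((device, bs) for bs in sizes)
                        for device, sizes in DEVICE_MLA_BLOCK_SIZES.items()))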
@ -82,14 +81,14 @@ def test_env(
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if device == "cpu":
if not use_v1:
pytest.skip("CPU backend only supports V1")
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
block_size, False)
if use_v1:
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
else:
assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
@ -205,14 +204,12 @@ def test_fp32_fallback(
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
if device == "cpu":
if not use_v1:
pytest.skip("CPU backend only supports V1")
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
assert backend.get_name() == ("TORCH_SDPA_VLLM_V1"
if use_v1 else "TORCH_SDPA")
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",

View File

@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
class CPUMLABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "CPU_MLA"
@staticmethod
def get_metadata_cls() -> Type["CPUMLAMetadata"]:
return CPUMLAMetadata
@staticmethod
def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
return CPUMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["MLACommonState"]:
return MLACommonState
@staticmethod
def get_impl_cls() -> Type["CPUMLAImpl"]:
return CPUMLAImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
ops.copy_blocks_mla(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [576]
@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
# New for MLA
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor = None
# required by MLACommonImpl
is_profile_run: bool = False
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill
self.input_builder = input_builder
assert not self.chunked_prefill, \
"chunked prefill is currently not supported"
def prepare(self):
self.input_data = self.input_builder.input_data
def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
input_data = self.input_data
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
prefill_query_lens = query_lens[0:input_data.num_prefills]
slot_mapping = torch.tensor(input_data.slot_mapping,
dtype=torch.long,
device="cpu")
# metadata for prefill
if input_data.num_prefills > 0:
query_lens_tensor = torch.tensor(prefill_query_lens,
dtype=torch.int32,
device="cpu")
kv_lens_tensor = torch.tensor(prefill_seq_lens,
dtype=torch.int32,
device="cpu")
query_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(query_lens_tensor,
dim=0,
dtype=torch.int32,
out=query_start_loc[1:])
torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
# for chunked-prefill
if self.chunked_prefill:
prefill_block_tables = make_tensor_with_pad(
self.input_data.prefill_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
prefill_block_tables = None
else:
query_start_loc = None
kv_start_loc = None
max_query_len = None
max_kv_len = None
prefill_block_tables = None
# metadata for decode
if input_data.num_decode_tokens != 0:
seq_lens_tensor = torch.tensor(
input_data.seq_lens[input_data.num_prefills:],
dtype=torch.int32,
device="cpu",
)
block_tables = make_tensor_with_pad(
self.input_data.decode_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor(
input_data.seq_lens[:input_data.num_prefills],
dtype=torch.int32,
device="cpu",
)
# For multi-modal models
placeholder_index_maps = None
if len(input_data.multi_modal_inputs_list) != 0:
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
input_data.multi_modal_placeholder_maps.items()
}
return CPUMLAMetadata(
chunked_prefill=self.chunked_prefill,
seq_lens=prefill_seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
prefill_query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
num_prefill_tokens=input_data.num_prefill_tokens,
num_decode_tokens=input_data.num_decode_tokens,
block_tables=block_tables,
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
input_positions=torch.tensor([self.input_data.input_positions]))
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"CPUMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CPUMLAImpl")
# states is implemented.
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CPUMLAImpl with FP8 KV cache not yet supported")
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: CPUMLAMetadata, # type: ignore[override]
) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
output = torch.empty_like(q)
ipex_ops.varlen_attention(
query=q,
key=k,
value=v_padded,
out=output,
seqlen_q=prefill_metadata.prefill_query_start_loc,
seqlen_k=prefill_metadata.prefill_query_start_loc,
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.max_query_len,
pdropout=0.0,
softmax_scale=self.scale,
zero_tensors=False,
is_causal=True,
return_softmax=False,
gen_=None,
logits_soft_cap=0.0,
window_size_left=-1,
window_size_right=-1,
alibi_slopes=None,
)
# remove padding
output = output.view(-1, self.num_heads,
q.shape[-1])[..., :v.shape[-1]]
return output.reshape(-1, self.num_heads * v.shape[-1])
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: CPUMLAMetadata, # type: ignore[override]
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
q = torch.cat([q_nope, q_pe], dim=-1)
o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)
# Run MQA
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
decode_meta.block_tables,
decode_meta.seq_lens_tensor)
return self._v_up_proj(o)
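As a quick illustration of the latent-KV layout this backend defines, here is a minimal sketch (not part of the diff) that exercises the static helpers above; the import path follows the CPU MLA backend string registered elsewhere in this revert:

# Module added by this revert; path matches the backend string used in cpu.py.
from vllm.attention.backends.cpu_mla import CPUMLABackend

# MLA stores a single compressed latent per token slot, so the cache has no
# separate key/value planes and num_kv_heads (assumed 1) is folded away.
shape = CPUMLABackend.get_kv_cache_shape(num_blocks=128,
                                         block_size=16,
                                         num_kv_heads=1,
                                         head_size=576)
assert shape == (128, 16, 576)
assert CPUMLABackend.get_supported_head_sizes() == [576]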

View File

@ -0,0 +1,403 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
_PARTITION_SIZE = 512
class IpexAttnBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "IPEX"
@staticmethod
def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
return IpexAttnBackendImpl
@staticmethod
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
from vllm._ipex_ops import ipex_ops as ops
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
from vllm._ipex_ops import ipex_ops as ops
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for IpexAttnBackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
seqlen_q: Optional[torch.Tensor]
max_seqlen: Optional[int]
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
@property
def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
return None
@property
def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
return None
return self
class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if use_irope:
logger.warning_once(
"Using irope in Ipex is not supported yet, it will fall"
" back to global attention for long context.")
if blocksparse_params is not None:
raise ValueError(
"IPEX backend does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.sliding_window is not None)
if logits_soft_cap is None:
logits_soft_cap = -1
self.logits_soft_cap = logits_soft_cap
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
def split_kv_cache(
self,
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 1
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for IpexAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache.numel() > 0:
key_cache, value_cache = self.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
ipex_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
if attn_metadata.attn_bias is None:
if self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, None, dtype=query.dtype)
attn_metadata.attn_bias = att_masks
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype,
device=query.device)
ipex_ops.varlen_attention(
query,
key,
value,
output,
attn_metadata.seqlen_q,
attn_metadata.seqlen_q,
self.alibi_slopes,
attn_metadata.max_seqlen,
attn_metadata.max_seqlen,
pdropout=0.0,
softmax_scale=self.scale,
zero_tensors=False,
is_causal=True,
return_softmax=False,
gen_=None,
window_size_left=-1,
window_size_right=-1,
logits_soft_cap=self.logits_soft_cap,
)
else:
# prefix-enabled attention
raise RuntimeError(
"IPEX backend doesn't support prefix decoding.")
else:
# Decoding run.
max_seq_len = attn_metadata.max_decode_seq_len
output = torch.empty_like(query)
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory
# shortage.
use_v1 = (max_seq_len <= 8192 and
(max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
ipex_ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
self.num_kv_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
block_size,
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ipex_ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
self.num_kv_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
block_size,
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype,
device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
attn_biases.append((bias + inf_mask).to(dtype))
return attn_biases
def _make_sliding_window_bias(
seq_lens: List[int],
window_size: Optional[int],
dtype: torch.dtype,
) -> List[torch.Tensor]:
attn_biases = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
fill_value=1,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
return attn_biases
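The decode path above chooses between the two PagedAttention kernels with a small heuristic; the following standalone sketch (not part of the diff) restates that heuristic with illustrative numbers so the thresholds are easy to verify:

_PARTITION_SIZE = 512

def use_paged_attention_v1(max_seq_len: int, num_seqs: int, num_heads: int) -> bool:
    # Mirror of the heuristic in IpexAttnBackendImpl.forward: prefer V1 for
    # short contexts when a single partition suffices or there is already
    # plenty of parallel work.
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_seq_len <= 8192 and (max_num_partitions == 1
                                    or num_seqs * num_heads > 512)

# Short context, one partition -> V1 (skips the cross-partition reduction).
assert use_paged_attention_v1(max_seq_len=400, num_seqs=4, num_heads=32)
# Context longer than 8192 -> V2, avoiding the shared-memory shortage noted above.
assert not use_paged_attention_v1(max_seq_len=16384, num_seqs=4, num_heads=32)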

View File

@ -0,0 +1,356 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.logger import init_logger
logger = init_logger(__name__)
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS"
@staticmethod
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@dataclass
class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables: Optional[torch.Tensor] = None
context_lens: Optional[torch.Tensor] = None
effective_query_lens: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
if self.num_prefills == 0:
return None
assert self.num_decode_tokens == 0
return self
@property
def decode_metadata(self) -> Optional["PallasMetadata"]:
if self.num_decode_tokens == 0:
return None
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.block_tables is not None
assert self.context_lens is not None
return self
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.logits_soft_cap = logits_soft_cap
if head_size % 128 != 0:
raise NotImplementedError(
f"Head size must be a multiple of 128, found {head_size}.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
or tpu_env.get("TYPE", None)
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
assert tpu_type is not None
tpu_type = tpu_type.lower()
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self.megacore_mode = "batch"
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
if attn_metadata.num_prefills > 0:
if attn_metadata.block_tables is None:
# Prefill without paged KV cache.
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv,
dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention kernel requires the input shape to be
# [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Prefill with paged KV cache.
# TODO(woosuk): Tune the below knobs.
num_kv_pages_per_compute_block = 16
num_queries_per_compute_block = 16
assert seq_len % num_queries_per_compute_block == 0
output = torch.ops.xla.multi_queries_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.effective_query_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
attn_logits_soft_cap=self.logits_soft_cap,
)
else:
# Decoding run.
assert kv_cache[0].numel() > 0
query = query.squeeze(dim=1)
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
assert attn_metadata.block_tables is not None
assert attn_metadata.context_lens is not None
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# the kernel compilation will fail. To avoid this, we split the
# batch dimension into smaller chunks and run the kernel multiple
# times.
MAX_SMEM_USAGE = 512 * 1024
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
max_num_seq = MAX_SMEM_USAGE // size_per_seq
if batch_size <= max_num_seq:
output = paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
self.megacore_mode,
attn_logits_soft_cap=self.logits_soft_cap,
)
else:
chunk_size = max_num_seq
# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size
output = torch.empty_like(query)
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = chunk_start + chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output = paged_attention(
query[chunk_start:chunk_end],
key_cache,
value_cache,
attn_metadata.context_lens[chunk_start:chunk_end],
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
attn_logits_soft_cap=self.logits_soft_cap,
)
output[chunk_start:chunk_end] = chunk_output
# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
pages_per_compute_block: int,
megacore_mode: Optional[str],
*,
attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
batch_size = query.shape[0]
if megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None
else:
megacore_mode = megacore_mode
return torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
attn_logits_soft_cap=attn_logits_soft_cap,
)
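The decode branch above splits the batch so that the int32 block table fits in SMEM; a standalone sketch (not part of the diff) of that sizing arithmetic, with illustrative numbers:

MAX_SMEM_USAGE = 512 * 1024  # bytes assumed available for the block table in SMEM

def max_seqs_per_kernel_call(blocks_per_seq: int) -> int:
    # Each block id is an int32, so one sequence costs 4 * blocks_per_seq bytes.
    size_per_seq = 4 * blocks_per_seq
    return MAX_SMEM_USAGE // size_per_seq

# With 2048 blocks per sequence at most 64 sequences fit per kernel call, so a
# batch of 200 is processed in ceil(200 / 64) = 4 chunks (the code above also
# rounds the chunk size down to a multiple of 2).
assert max_seqs_per_kernel_call(2048) == 64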

View File

@ -3,24 +3,78 @@
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from torch.nn.functional import scaled_dot_product_attention
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionMetadata, AttentionType,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.logger import init_logger
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
logger = init_logger(__name__)
class TorchSDPABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "TORCH_SDPA"
@staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
return TorchSDPAMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise NotImplementedError("Swap is not supported in TorchSDPABackend.")
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend.
@ -233,6 +287,113 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill
self.input_builder = input_builder
def prepare(self):
self.input_data = self.input_builder.input_data
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
input_data = self.input_data
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
prefill_query_lens = query_lens[0:input_data.num_prefills]
slot_mapping = torch.tensor(input_data.slot_mapping,
dtype=torch.long,
device="cpu")
# For chunked-prefill
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
prefill_block_tables = make_tensor_with_pad(
self.input_data.prefill_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
query_lens_tensor = torch.tensor(prefill_query_lens,
dtype=torch.int32,
device="cpu")
kv_lens_tensor = torch.tensor(prefill_seq_lens,
dtype=torch.int32,
device="cpu")
query_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(query_lens_tensor,
dim=0,
dtype=torch.int32,
out=query_start_loc[1:])
torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
else:
prefill_block_tables = None
query_start_loc = None
kv_start_loc = None
max_query_len = None
max_kv_len = None
# For paged attention
if input_data.num_decode_tokens != 0:
seq_lens_tensor = torch.tensor(
input_data.seq_lens[input_data.num_prefills:],
dtype=torch.int32,
device="cpu",
)
block_tables = make_tensor_with_pad(
self.input_data.decode_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor(
input_data.seq_lens[:input_data.num_prefills],
dtype=torch.int32,
device="cpu",
)
# For multi-modal models
placeholder_index_maps = None
if len(input_data.multi_modal_inputs_list) != 0:
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
input_data.multi_modal_placeholder_maps.items()
}
attn_metadata = TorchSDPAMetadata(
chunked_prefill=self.chunked_prefill,
seq_lens=prefill_seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
prefill_query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
num_prefill_tokens=input_data.num_prefill_tokens,
num_decode_tokens=input_data.num_decode_tokens,
block_tables=block_tables,
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)
return attn_metadata
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def __init__(

View File

@ -64,11 +64,13 @@ class CpuPlatform(Platform):
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
logger.info("Using CPU MLA backend.")
return "vllm.attention.backends.cpu_mla.CPUMLABackend"
logger.info("Using Torch SDPA backend.")
if not use_v1:
raise ValueError("CPU backend only supports V1.")
if use_v1:
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
else:
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
@ -145,14 +147,26 @@ class CpuPlatform(Platform):
parallel_config.distributed_executor_backend)
parallel_config.distributed_executor_backend = "mp"
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
if vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.cpu_worker.CPUWorker"
# Note: workaround for v1 gpu_model_runner
from vllm.config import CompilationLevel
vllm_config.compilation_config.cudagraph_capture_sizes = []
compilation_config = vllm_config.compilation_config
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE):
# Note: vLLM V1 is using PIECEWISE level compilation, which will
# take time to compile kernels just-in-time with the inductor

View File

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union, cast
import torch
from tpu_info import device
import vllm.envs as envs
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
@ -49,10 +50,12 @@ class TpuPlatform(Platform):
and selected_backend != _Backend.PALLAS_VLLM_V1):
logger.info("Cannot use %s backend on TPU.", selected_backend)
if not use_v1:
raise ValueError("TPU backend only supports V1.")
if use_v1:
logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
else:
logger.info("Using Pallas backend.")
return "vllm.attention.backends.pallas.PallasAttentionBackend"
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
@ -65,7 +68,7 @@ class TpuPlatform(Platform):
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False
return not envs.VLLM_USE_V1
@classmethod
def get_punica_wrapper(cls) -> str:
@ -114,7 +117,9 @@ class TpuPlatform(Platform):
"Using bfloat16 instead.", vllm_config.model_config.dtype)
vllm_config.model_config.dtype = torch.bfloat16
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
if envs.VLLM_USE_V1:
from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
@ -122,11 +127,21 @@ class TpuPlatform(Platform):
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Multi-step scheduling is not supported (and not "
"needed) on vLLM V1. Please launch without "
"--num-scheduler-steps.")
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.tpu_worker.TPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.tpu_worker.TPUWorker"
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
@ -174,9 +189,13 @@ class TpuPlatform(Platform):
processed_inputs: ProcessorInputs,
) -> None:
"""Raises if this request is unsupported on this platform"""
if (isinstance(params, SamplingParams)
and params.sampling_type == SamplingType.RANDOM_SEED):
raise ValueError("Torch XLA does not support per-request seed.")
if isinstance(params, SamplingParams):
if params.guided_decoding is not None and not envs.VLLM_USE_V1:
raise ValueError("Structured output is not supported on "
f"{cls.device_name} V0.")
if params.sampling_type == SamplingType.RANDOM_SEED:
raise ValueError(
"Torch XLA does not support per-request seed.")
try:

View File

@ -39,10 +39,12 @@ class XPUPlatform(Platform):
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
use_v1 = envs.VLLM_USE_V1
if not use_v1:
raise ValueError("XPU backend only supports V1.")
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
else:
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
@classmethod
def get_device_capability(
@ -75,7 +77,10 @@ class XPUPlatform(Platform):
cache_config = vllm_config.cache_config
# in V1(or with ipex chunked prefill) block_size is 64
if cache_config and cache_config.block_size is None:
if envs.VLLM_USE_V1:
cache_config.block_size = 64
else:
cache_config.block_size = 16
# Instances created using VllmConfig() typically have model_config as
# None by default. The modification involves adding a check to prevent
@ -101,7 +106,11 @@ class XPUPlatform(Platform):
# check and update parallel config
parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls =\
"vllm.v1.worker.xpu_worker.XPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
if parallel_config.distributed_executor_backend is None:
if parallel_config.world_size > 1:

View File

@ -0,0 +1,326 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch
from vllm.attention import AttentionMetadata
from vllm.forward_context import set_forward_context
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
ModelInputForCPUBuilder,
ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
"""
Used by the EncoderDecoderModelRunner.
"""
encoder_input_tokens: Optional[torch.Tensor] = None
encoder_input_positions: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "EncoderDecoderModelInputForCPU":
return cast(
EncoderDecoderModelInputForCPU,
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
class CPUEncoderDecoderModelRunner(
CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
EncoderDecoderModelInputForCPU)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def _list_to_int32_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.int32, device=self.device)
def _list_to_long_tensor(
self,
_list: List[int],
) -> torch.Tensor:
return torch.tensor(_list, dtype=torch.long, device=self.device)
def _empty_int32_tensor(self) -> torch.Tensor:
return self._list_to_int32_tensor([])
def _empty_long_tensor(self) -> torch.Tensor:
return self._list_to_long_tensor([])
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str,
Any]) -> EncoderDecoderModelInputForCPU:
return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInputForCPU:
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
(
attn_metadata,
encoder_input_tokens_tensor,
encoder_input_positions_tensor,
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
return dataclasses.replace(
model_input,
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
virtual_engine=virtual_engine,
)
def _prepare_encoder_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: EncoderDecoderModelInputForCPU,
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""Helper method to prepare the encoder- and cross-attn-related
model inputs based on a given sequence group. These additional inputs
are used to augment an already-computed `EncoderDecoderModelInput`
data structure which already has decoder-related model inputs
populated.
Sets the following attn_metadata fields:
* `num_encoder_tokens`
* `encoder_seq_lens`
* `encoder_seq_lens_tensor`
* `max_encoder_seq_len`
* `cross_slot_mapping`
* `cross_block_tables`
Constructs a new model inputs data structure, based on
(1) the existing fields in the `model_inputs` argument,
and (2) the following additional fields which are
computed (or in the case of `attn_metadata`, updated)
by this function:
* attn_metadata
* encoder_input_tokens
* encoder_input_positions
Arguments:
* seq_group_metadata_list: list of sequence groups for which to
compute inputs
* model_inputs: model inputs data structure with decoder-oriented
fields already computed.
Return:
* Updated model inputs data structure
"""
if len(seq_group_metadata_list) == 0:
return (model_input.attn_metadata, None, None)
# Since we are not supporting chunked prefill, either the entire
# batch is prefill or it is decode
is_prompt = seq_group_metadata_list[0].is_prompt
# Build encoder inputs
encoder_seq_lens: List[int] = []
if is_prompt:
# Prefill phase.
cross_block_tables = self._empty_int32_tensor().view(
len(seq_group_metadata_list), -1)
# Extract input tokens/positions, cross-attention slot-mapping,
# & seq len from each sequence group metadata
(
encoder_input_tokens,
encoder_input_positions,
cross_slot_mapping,
) = (
[],
[],
[],
)
for seq_group_metadata in seq_group_metadata_list:
# Build seq lens
seq_len = seq_group_metadata.encoder_seq_data.get_len()
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
encoder_seq_lens.append(seq_len)
# Build slot mapping
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[
i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
cross_slot_mapping.append(slot)
# Build encoder input tokens
encoder_input_tokens.extend(token_ids)
encoder_input_positions.extend(list(range(0, seq_len)))
# Convert tokens/positions & cross-attention
# slot-mapping to encoder input tensors
encoder_input_tokens_tensor = self._list_to_long_tensor(
encoder_input_tokens)
encoder_input_positions_tensor = self._list_to_long_tensor(
encoder_input_positions)
cross_slot_mapping_tensor = self._list_to_long_tensor(
cross_slot_mapping)
else:
# Decode phase.
encoder_input_tokens_tensor = self._empty_long_tensor()
encoder_input_positions_tensor = self._empty_long_tensor()
cross_slot_mapping_tensor = self._empty_long_tensor()
# Extract cross-attention block tables &
# seq len from each sequence group metadata.
# Cross-attention block tables are empty
# during vLLM memory profiling.
cross_block_tables = []
for seq_group_metadata in seq_group_metadata_list:
for _ in range(len(seq_group_metadata.seq_data)):
encoder_seq_lens.append(
seq_group_metadata.encoder_seq_data.get_len())
cross_block_table = seq_group_metadata.cross_block_table
cross_block_tables.append([] if (
cross_block_table is None) else cross_block_table)
max_len_of_block_table = max(
len(block_table) for block_table in cross_block_tables)
cross_block_tables = make_tensor_with_pad(
cross_block_tables,
max_len=max_len_of_block_table,
pad=0,
dtype=torch.int32,
device=self.device,
)
# Compute encoder sequence lengths & encoder
# sequence starting offset tensors
max_encoder_seq_len = max(encoder_seq_lens, default=0)
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
1,
dtype=torch.int32,
device=self.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])
# Update attention metadata with encoder-oriented attributes
attn_metadata = model_input.attn_metadata
assert attn_metadata is not None
(
attn_metadata.num_encoder_tokens,
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
cross_slot_mapping_tensor,
cross_block_tables,
)
return (attn_metadata, encoder_input_tokens_tensor,
encoder_input_positions_tensor)
@torch.no_grad()
def execute_model(
self,
model_input: EncoderDecoderModelInputForCPU,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"encoder_input_ids":
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
"intermediate_tensors":
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return [output]
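The prefill branch above maps each encoder token to a physical cross-attention cache slot; here is a minimal standalone sketch (not part of the diff) of that slot-mapping arithmetic with small illustrative numbers:

def cross_slot_mapping(cross_block_table: list, seq_len: int,
                       block_size: int) -> list:
    # Same arithmetic as the prefill loop above: map the i-th encoder token to
    # a physical slot inside its cross-attention block.
    slots = []
    for i in range(seq_len):
        block_number = cross_block_table[i // block_size]
        block_offset = i % block_size
        slots.append(block_number * block_size + block_offset)
    return slots

# With block_size=16 and cross blocks [7, 3]: token 0 -> slot 112 (7 * 16) and
# token 16, the first token of the second block, -> slot 48 (3 * 16).
assert cross_slot_mapping([7, 3], seq_len=18, block_size=16)[0] == 112
assert cross_slot_mapping([7, 3], seq_len=18, block_size=16)[16] == 48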

View File

@ -0,0 +1,671 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
TypeVar, Union)
import torch
from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
_PAD_SLOT_ID = -1
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
"""
Base class contains metadata needed for the base model forward pass on CPU
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_type_ids: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForCPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> TModelInputForCPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"token_type_ids": self.token_type_ids,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForCPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
class ModelInputData:
def __init__(self, use_mrope: bool):
self.use_mrope = use_mrope
self.input_tokens: List[int] = []
self.input_positions: List[int] = []
self.token_type_ids: Optional[List[int]] = []
self.seq_lens: List[int] = []
self.query_lens: List[int] = []
self.prefill_block_tables: List[List[int]] = []
self.decode_block_tables: List[List[int]] = []
self.max_decode_seq_len: int = 0
self.num_prefills: int = 0
self.num_prefill_tokens: int = 0
self.num_decode_tokens: int = 0
self.slot_mapping: List[int] = []
self.multi_modal_inputs_list: List[MultiModalKwargs] = []
self.multi_modal_placeholder_maps: Dict[
str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap)
self.input_mrope_positions: List[List[int]] = [[]
for _ in range(3)]
def __init__(self,
runner: "CPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.runner = runner
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
or runner.cache_config.enable_prefix_caching)
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
self.enable_lora = self.runner.lora_config is not None
if self.runner.attn_backend is not None:
            # Speculative decoding (e.g. Medusa) does not have an attention backend.
attn_backend = self.runner.attn_backend
self.att_metadata_builder = attn_backend.get_builder_cls()(self)
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.input_data = ModelInputForCPUBuilder.ModelInputData(
self.runner.model_config.uses_mrope)
self.att_metadata_builder.prepare()
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
def set_seq_group_list(
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
self.seq_group_metadata_list = seq_group_metadata_list
def build(self) -> ModelInputForCPU:
self._build_input_data()
input_data = self.input_data
input_tokens = torch.tensor(input_data.input_tokens,
dtype=torch.long,
device="cpu")
input_positions = torch.tensor(
input_data.input_positions
if not any(input_data.input_mrope_positions) else
input_data.input_mrope_positions,
dtype=torch.long,
device="cpu")
token_type_ids = torch.tensor(input_data.token_type_ids,
dtype=torch.long,
device="cpu") \
if input_data.token_type_ids else None
# For multi-modal models
multi_modal_kwargs = None
if len(input_data.multi_modal_inputs_list) != 0:
multi_modal_kwargs = MultiModalKwargs.batch(
input_data.multi_modal_inputs_list)
attn_metadata = self.att_metadata_builder.build(
input_data.seq_lens, input_data.query_lens, -1, -1)
is_prompt = (self.seq_group_metadata_list[0].is_prompt
if self.seq_group_metadata_list else None)
# LoRA data.
lora_requests = set()
lora_mapping = None
if self.enable_lora:
lora_requests = set(seq.lora_request
for seq in self.seq_group_metadata_list
if seq.lora_request is not None)
lora_mapping = self._prepare_lora_input(
self.seq_group_metadata_list, is_prompt)
return self.model_input_cls(input_tokens=input_tokens,
input_positions=input_positions,
token_type_ids=token_type_ids,
seq_lens=input_data.seq_lens,
query_lens=input_data.query_lens,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
lora_mapping=lora_mapping,
lora_requests=lora_requests)
def _build_input_data(self):
for seq_group_metadata in self.seq_group_metadata_list:
for seq_id, seq_data in seq_group_metadata.seq_data.items():
if seq_group_metadata.is_prompt:
self._compute_prompt_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
if seq_group_metadata.multi_modal_data:
self._compute_multi_modal_input(
seq_group_metadata, seq_data)
else:
self._compute_decode_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
def _compute_decode_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute decode input tokens, positions, block table and slot mapping.
"""
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
tokens = seq_data.get_last_token_id()
token_positions = seq_len - 1
block_number = block_table[token_positions // block_size]
block_offset = token_positions % block_size
slot = block_number * block_size + block_offset
# For paged_attention kernel
if self.runner.sliding_window:
start_idx = max(0, seq_len - self.runner.sliding_window)
start_block = start_idx // block_size
start_idx = start_block * block_size
seq_len = seq_len - start_idx
block_table = block_table[start_block:]
# For MRotaryEmbedding
if seq_data.mrope_position_delta is not None:
next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
for idx in range(3):
data.input_mrope_positions[idx].extend( # type: ignore
next_pos[idx])
else:
data.input_positions.append(token_positions) # type: ignore
# Update fields
data.input_tokens.append(tokens)
data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
data.num_decode_tokens += 1
data.slot_mapping.append(slot)
data.decode_block_tables.append(block_table)
data.query_lens.append(1)
data.seq_lens.append(seq_len)
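        # Illustrative worked example (hypothetical values): with
        # block_size = 16, block_table = [7, 3] and seq_len = 18, the new token
        # sits at position 17, so block_number = block_table[17 // 16] = 3,
        # block_offset = 17 % 16 = 1, and its KV entry is written to
        # slot = 3 * 16 + 1 = 49.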
def _compute_prompt_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute prompt input tokens, positions, block table and slot mapping.
"""
token_chunk_size = seq_group_metadata.token_chunk_size
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
seq_len = min(seq_len, context_len + token_chunk_size)
# For prefix caching
prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
if prefix_cache_block_num > 0:
prefix_cache_len = (prefix_cache_block_num *
self.runner.block_size)
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
context_len = prefix_cache_len
token_chunk_size = seq_len - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
context_len = seq_len - 1
token_chunk_size = 1
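        # Illustrative example (hypothetical values): with block_size = 16 and
        # two computed blocks, prefix_cache_len = 32. If context_len = 10 and
        # seq_len = 50, this is a partial hit, so context_len is advanced to 32
        # and token_chunk_size shrinks to 50 - 32 = 18; only the uncached tail
        # of the prompt is recomputed below.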
tokens = seq_data.get_token_ids()
tokens = tokens[context_len:seq_len]
token_positions = range(context_len, seq_len)
token_types = seq_group_metadata.token_type_ids
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
if block_table is not None:
slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
for i, pos in enumerate(token_positions):
block_number = block_table[pos // block_size]
block_offset = pos % block_size
slot = block_number * block_size + block_offset
slot_mapping[i] = slot
data.slot_mapping.extend(slot_mapping)
# The MROPE positions are prepared in _compute_multi_modal_input
data.input_positions.extend(token_positions)
if data.token_type_ids is not None:
data.token_type_ids.extend(token_types if token_types else [])
# Update fields
data.input_tokens.extend(tokens)
data.num_prefills += 1
data.num_prefill_tokens += len(tokens)
data.query_lens.append(len(tokens))
data.prefill_block_tables.append(block_table)
data.seq_lens.append(seq_len)
def _compute_multi_modal_input(self,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData):
computed_len = seq_data.get_num_computed_tokens()
seq_len = self.input_data.seq_lens[-1]
# NOTE: mm_kwargs only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata, range(computed_len, seq_len))
if not mm_kwargs:
return
# special processing for mrope position deltas.
if self.runner.model_config.uses_mrope:
assert not self.chunked_prefill, \
"MROPE on CPU does not support chunked-prefill."
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
None)
assert (
image_grid_thw is not None or video_grid_thw is not None
or audio_feature_lengths is not None), (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw' or "
"'audio_feature_lengths'.")
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
mrope_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
context_len=computed_len,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
seq_data.mrope_position_delta = mrope_position_delta
for i in range(3):
self.input_data.input_mrope_positions[ # type: ignore
i].extend(mrope_positions[i])
self.input_data.multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
self.input_data.multi_modal_placeholder_maps[modality].extend(
placeholder_map)
def _prepare_lora_input(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
is_prefill: bool) -> LoRAMapping:
index_mapping = []
prompt_mapping = []
for seq in seq_group_metadata_list:
lora_id = seq.lora_int_id
query_len = seq.token_chunk_size
index_mapping += [lora_id] * query_len
prompt_mapping += [lora_id] * (
query_len if seq.sampling_params
and seq.sampling_params.prompt_logprobs is not None else 1)
return LoRAMapping(index_mapping=tuple(index_mapping),
prompt_mapping=tuple(prompt_mapping),
is_prefill=is_prefill)
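    # Illustrative example (hypothetical values): for two sequence groups with
    # lora_int_id 3 and 0, token_chunk_size 4 and 1, and no prompt_logprobs,
    # index_mapping = (3, 3, 3, 3, 0) maps every token to its adapter while
    # prompt_mapping = (3, 0) keeps a single entry per sequence group.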
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
"""
Helper class for shared methods between CPU model runners.
"""
_model_input_cls: Type[TModelInputForCPU]
_builder_cls: Type[ModelInputForCPUBuilder]
builder: ModelInputForCPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
*args,
**kwargs,
):
ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.pin_memory = False
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
needs_attn_backend = (num_attn_heads != 0
or self.model_config.is_attention_free)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None
# Lazy initialization.
        self.model: nn.Module  # Set after load_model.
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.sampler = get_sampler()
if hasattr(self, "_builder_cls"):
# multi-step model runner does not have `_builder_cls`
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
assert supports_lora(
self.model
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# Use get_text_config() in case of multimodal models
text_config = self.model_config.hf_config.get_text_config()
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=text_config.max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForCPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
self.builder.prepare(finished_requests_ids)
self.builder.set_seq_group_list(seq_group_metadata_list)
return self.builder.build() # type: ignore
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.list_adapters()
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine,
is_prompt=is_prompt)
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForCPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
previous_hidden_states: Optional[torch.Tensor] = None,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
model_executable = self.model
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs,
device=self.device,
)
execute_model_kwargs = {}
if previous_hidden_states is not None:
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**execute_model_kwargs,
**multimodal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
return [output]
def generate_proposals(self, *args, **kwargs):
return self.model.generate_proposals(*args, **kwargs)

View File

@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
ModelInputForCPUBuilder)
@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
"""
Used by the CPUPoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class CPUPoolingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForCPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model
cross_enc_kwargs = {}
if model_input.token_type_ids is not None:
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
**cross_enc_kwargs,
"intermediate_tensors":
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return []
return [
self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForCPUWithPoolingMetadata:
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
return dataclasses.replace(model_input,
virtual_engine=virtual_engine,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)
return pooling_metadata

452
vllm/worker/cpu_worker.py Normal file
View File

@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A CPU worker class."""
import os
from importlib import util
from typing import List, Optional, Set, Tuple, Type
import torch
import torch.distributed
import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class CPUCacheEngine:
"""Manages the KV cache for CPU backend.
This class is responsible for initializing and managing CPU KV
caches. It also provides methods for performing KV cache operations, such
as copying.
"""
def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
assert device_config.device_type == "cpu"
self.cache_config = cache_config
self.model_config = model_config
self.parallel_config = parallel_config
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
# for CPU backend, because we want to reuse KV cache management
# in the scheduler.
self.num_cpu_blocks = cache_config.num_gpu_blocks
self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
model_config)
# Get attention backend.
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
def _allocate_kv_cache(
self,
num_blocks: int,
) -> List[torch.Tensor]:
"""Allocates KV cache on CPU."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size)
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
kv_cache.append(
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
return kv_cache
def swap_in(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def swap_out(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
@staticmethod
def get_kv_cache_dtype(cache_config: CacheConfig,
model_config: ModelConfig):
if cache_config.cache_dtype == "auto":
return model_config.dtype
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
return torch.float8_e5m2
else:
raise NotImplementedError(f"Unsupported KV cache type "
f"{cache_config.cache_dtype}.")
@staticmethod
def get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block if not model_config.use_mla else 0
total = num_layers * (key_cache_block + value_cache_block)
dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
dtype_size = torch.tensor([], dtype=dtype).element_size()
return dtype_size * total
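    # Illustrative worked arithmetic (hypothetical model): with block_size = 16,
    # 12 KV heads, head_size = 64, 12 layers and a bfloat16 cache (2 bytes per
    # element), key_cache_block = 16 * 12 * 64 = 12288 elements, key + value is
    # 24576 per layer, total = 12 * 24576 = 294912 elements, and the block size
    # is 2 * 294912 = 589824 bytes (576 KiB).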
class CPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
responsible for maintaining the KV cache and executing the model on the
CPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[CPUModelRunner]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
vllm_config.parallel_config.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Setup OpenMP threads affinity.
omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
self.local_omp_cpuid = "all"
if omp_cpuids == "auto":
self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
)
else:
self.local_omp_cpuid = omp_cpuids.split("|")[rank]
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.runner_type == "pooling":
ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
vllm_config=vllm_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CPUCacheEngine]
# Initialize cpu_cache as pooling models don't initialize kv_caches
self.cpu_cache: Optional[List[List[torch.Tensor]]] = None
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def init_device(self) -> None:
if self.local_omp_cpuid != "all":
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
if ret:
logger.info(ret)
# Note: unique identifier for creating allreduce shared memory
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
":")[-1]
self.device = torch.device("cpu")
self.init_distributed_environment()
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured CPU
KV cache space.
Note that since vLLM assumes a block resides on GPU if it can be
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
This allows us to reuse the scheduler of vLLM without generalizing it
to different devices.
"""
# For CPU device, the block number will be calculated based on the
# cpu_kvcache_space.
cache_block_size = self.get_cache_block_size_bytes()
num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
cache_block_size)
num_cpu_blocks = max(num_cpu_blocks, 0)
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_gpu_blocks = num_cpu_blocks
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
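    # Illustrative example (hypothetical values): with VLLM_CPU_KVCACHE_SPACE
    # set to 4 GiB and a cache block size of 589824 bytes (see the worked
    # arithmetic in CPUCacheEngine above), num_cpu_blocks =
    # 4 * 1024**3 // 589824 = 7281, which is then reported to the scheduler
    # as num_gpu_blocks.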
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache. Currently, swappable CPU memory is not
supported.
Since this worker does not support GPUs, we use the num_gpu_blocks to
determine how many non-swappable CPU blocks to allocate.
"""
assert (num_cpu_blocks == 0
), f"{type(self)} does not support swappable cache"
# Note: To reuse the cache management procedure,
# use cpu cache as 'gpu cache'.
num_cpu_blocks = num_gpu_blocks
self._validate_num_cpu_blocks(num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_cpu_blocks
self.cache_config.num_cpu_blocks = 0
# Initialize the cache.
self._init_cache_engine()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
"""Raise errors if the num_cpu_blocks is invalid.
"""
if num_cpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_cpu_blocks
if self.model_config.max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({self.model_config.max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
"initializing the engine.")
def _init_cache_engine(self) -> None:
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
bind_kv_cache(self.compilation_config.static_forward_context,
self.cpu_cache)
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory
for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
def execute_worker(
self,
worker_input: WorkerInput,
) -> None:
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None
virtual_engine: int = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
backend="gloo",
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cpu())
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block.
"""
return CPUCacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
"""Return CPUs id binding based on NUMA nodes.
"""
rank_to_cpus = self.local_omp_cpuid
# Setup OpenMP thread affinity based on NUMA nodes automatically
world_size = self.vllm_config.parallel_config.world_size
libnuma_found = util.find_spec("numa") is not None
psutil_found = util.find_spec("psutil") is not None
if libnuma_found and psutil_found:
import psutil
from numa import info
cpu_count = psutil.cpu_count(logical=False)
cpus_allow_list = psutil.Process().cpu_affinity()
numa_size = info.get_num_configured_nodes()
cpu_count_per_numa = cpu_count // numa_size
num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
cpu_count_per_numa // 2)
# check allow node_to_cpus list
node_to_cpus = []
for i in range(numa_size):
node_intersect = set(
info.node_to_cpus(i)).intersection(cpus_allow_list)
if bool(node_intersect):
node_to_cpus.append(list(node_intersect))
if world_size > len(node_to_cpus):
logger.error(
"Auto thread-binding failed due to "
"world size: %d is larger than "
"allowed NUMA nodes number: %d."
"Please try to bind threads manually.", world_size,
len(node_to_cpus))
else:
end = cpu_count_per_numa - num_of_reserved_cpu
rank_to_cpus_list = node_to_cpus[self.rank][:end]
rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
logger.info("auto thread-binding list: %s", rank_to_cpus)
else:
logger.warning(
"Auto thread-binding is not supported due to "
"the lack of package numa and psutil,"
"fallback to no thread-binding. To get better performance,"
"please try to manually bind threads.")
return rank_to_cpus
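    # Illustrative example (hypothetical host): with 2 NUMA nodes, 32 physical
    # cores (16 per node), VLLM_CPU_NUM_OF_RESERVED_CPU = 1 and world_size = 2,
    # num_of_reserved_cpu = min(1, 16 // 2) = 1, so rank 0 is bound to 15 of
    # the 16 cores of node 0 (one core stays reserved) and rank 1 to the
    # matching slice of node 1.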

View File

@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Dict, Optional, Tuple
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.tpu_model_runner import ModelInputForTPU
from vllm.worker.tpu_worker import TPUWorker
from vllm.worker.worker_base import WorkerInput
class MultiStepTPUWorker(TPUWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cached_model_input: Optional[ModelInputForTPU] = None
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
assert self.is_driver_worker
assert execute_model_req.virtual_engine == 0
is_first_multi_step = execute_model_req.is_first_multi_step
is_last_step = execute_model_req.is_last_step
if is_first_multi_step:
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
worker_input = dataclasses.replace(
worker_input,
num_steps=execute_model_req.num_lookahead_slots + 1)
model_input: ModelInputForTPU = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if execute_model_req.async_callback:
model_input = dataclasses.replace(
model_input,
async_callback=execute_model_req.async_callback)
else:
assert self.cached_model_input is not None
model_input = self.cached_model_input
worker_input = WorkerInput()
model_input = dataclasses.replace(
model_input,
is_first_multi_step=is_first_multi_step,
is_last_step=is_last_step)
if self.do_metadata_broadcast:
if is_first_multi_step:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
broadcast_data = {
"is_first_multi_step": is_first_multi_step,
"is_last_step": is_last_step,
}
broadcast_tensor_dict(broadcast_data, src=0)
        # Returning an empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return model_input, worker_input, {}
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
torch.Tensor]]]:
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
broadcast_tensor_dict({}, src=0)
return None
model_input, worker_input, _ = self._get_driver_input_and_broadcast(
execute_model_req)
if model_input.is_first_multi_step:
self.cached_model_input = model_input
return model_input, worker_input, {}
else:
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
if len(broadcast_data) == 2:
assert self.cached_model_input is not None
self.cached_model_input = dataclasses.replace(
self.cached_model_input,
is_first_multi_step=broadcast_data["is_first_multi_step"],
is_last_step=broadcast_data["is_last_step"])
empty_worker_input = WorkerInput()
return self.cached_model_input, empty_worker_input, {}
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.cached_model_input = model_input
return model_input, worker_input, {}

View File

@ -0,0 +1,909 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, Union)
from unittest.mock import patch
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
# Here we utilize the behavior that an out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
class ExecutionMode(enum.Enum):
PREFILL = enum.auto()
DECODE = enum.auto()
PREFIX_PREFILL = enum.auto()
def is_prefill(self) -> bool:
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: AttentionMetadata
input_lens: torch.Tensor
t: torch.Tensor
p: torch.Tensor
num_samples: int
n: List[int]
seq_groups: List[List[int]]
is_first_multi_step: bool = True
is_last_step: bool = True
virtual_engine: int = 0
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"token_ids": self.token_ids,
"position_ids": self.position_ids,
"input_lens": self.input_lens,
"t": self.t,
"p": self.p,
"num_samples": self.num_samples,
"n": self.n,
"seq_groups": self.seq_groups,
"is_first_multi_step": self.is_first_multi_step,
"is_last_step": self.is_last_step,
"virtual_engine": self.virtual_engine,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForTPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForTPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
def __init__(
self,
vllm_config: VllmConfig,
is_driver_worker: bool = False,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
self.block_size = self.cache_config.block_size
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
self.block_size)
self.block_tables = np.zeros(
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
dtype=np.int32)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
False,
)
self.cached_step_outputs: List[torch.Tensor] = []
smem_size = 512 * 1024
block_table_size = 4 * self.block_tables.size
if block_table_size >= smem_size:
logger.warning(
"The max_model_len (%d) is too large. This may degrade the "
"performance due to the insufficient smem size. Consider "
"setting --max-model-len to a smaller value, like %d.",
self.model_config.max_model_len,
self.model_config.max_model_len /
(block_table_size / smem_size))
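        # Illustrative worked arithmetic (hypothetical config): with
        # max_num_seqs = 256, max_model_len = 8192 and block_size = 16,
        # max_num_blocks_per_seq = 512, block_tables holds 256 * 512 = 131072
        # int32 entries, and block_table_size = 4 * 131072 = 524288 bytes,
        # which reaches the 512 KiB smem budget and triggers the warning above.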
def load_model(self) -> None:
self.device = self.device_config.device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
model = ModelWrapper(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
def get_model(self) -> nn.Module:
return self.model.model
def _dummy_run(
self,
batch_size: int,
seq_len: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
exec_mode: ExecutionMode,
) -> None:
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=batch_size,
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros(
(batch_size, self.max_num_blocks_per_seq),
dtype=torch.int32,
device=self.device)
context_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shape. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if exec_mode.is_prefill():
            # Prefill
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(input_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
torch._dynamo.mark_dynamic(t, 0)
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(token_ids, position_ids, input_lens, t, p, num_samples,
kv_caches)
def warmup_model(
self,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> None:
# Prefill
logger.info("Compiling the model with different input shapes...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefill done in %.2f s.", end - start)
# Prefix prefill
if self.cache_config.enable_prefix_caching:
logger.info("Compiling the model with different input shapes for "
"prefix prefill...")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info("Compilation for prefix prefill done in %.2f s.",
end - start)
# Decode
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
while True:
self._dummy_run(batch_size,
seq_len,
kv_caches,
exec_mode=ExecutionMode.DECODE)
xm.wait_device_ops()
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info("Compilation for decode done in %.2f s.", end - start)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
prompt_lens: List[int] = []
context_lens: List[int] = []
slot_mapping: List[int] = []
for batch_idx, seq_group_metadata in enumerate(
seq_group_metadata_list):
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
# Could include output tokens when a request is preempted.
prompt_tokens = seq_data.get_token_ids()
seq_len = len(prompt_tokens)
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
num_computed_tokens = num_computed_blocks * self.block_size
if num_computed_tokens > 0:
prompt_tokens = prompt_tokens[num_computed_tokens:]
context_lens.append(seq_len)
else:
context_lens.append(0)
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens)
input_positions.extend(range(num_computed_tokens, seq_len))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(num_computed_tokens, seq_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if num_computed_tokens > 0:
self.block_tables[batch_idx, :len(block_table)] = block_table
            # Pad EACH prompt length up to the smallest power of 2 that is
# greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len = _get_padded_prefill_len(prompt_len)
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
prompt_lens = torch.tensor(prompt_lens,
dtype=torch.int32,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:num_prefills],
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
)
return input_tokens, input_positions, attn_metadata, prompt_lens
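    # Illustrative example (hypothetical prompt): a 37-token prompt is padded
    # up to the next power of two, 64, so 27 dummy tokens with position 0 and
    # slot _PAD_SLOT_ID are appended; their KV writes land in the out-of-bound
    # pad slot and are ignored, while the compiled prefill graph for
    # seq_len = 64 is reused.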
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
context_lens.append(seq_len)
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[batch_idx, :len(block_table)] = block_table
batch_idx += 1
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
context_lens = context_lens + [0] * num_paddings
input_tokens = torch.tensor(input_tokens,
dtype=torch.int32,
device="cpu")
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
context_lens = torch.tensor(context_lens,
dtype=torch.int32,
device="cpu")
block_tables = torch.tensor(self.block_tables[:batch_size],
dtype=torch.int32,
device="cpu")
input_lens = torch.tensor([1] * batch_size,
dtype=torch.int32,
device="cpu")
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
return input_tokens, input_positions, attn_metadata, input_lens
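    # Illustrative example (hypothetical batch, assuming _get_padded_batch_size
    # pads small batches up to 8, consistent with the warmup loop above): with
    # 5 decoding sequences, 3 dummy rows with token 0, position 0,
    # slot _PAD_SLOT_ID and context_len 0 are appended so the decode graph
    # compiled for batch_size = 8 can be reused.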
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
t = []
p = []
n = []
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
t.append(sampling_params.temperature)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend "
"due to performance issues.")
p.append(sampling_params.top_p)
if sampling_params.top_k > 0:
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.n > _MAX_NUM_SAMPLES:
raise NotImplementedError(
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
n.append(sampling_params.n)
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")
if sampling_params.prompt_logprobs is not None:
raise NotImplementedError(
"prompt_logprobs is not currently supported by the TPU "
"backend.")
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs = len(seq_group_metadata.seq_data)
t += [t[-1]] * (num_seqs - 1)
p += [p[-1]] * (num_seqs - 1)
n += [n[-1]] * (num_seqs - 1)
num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings
t = torch.tensor(t, dtype=torch.float32, device="cpu")
p = torch.tensor(p, dtype=torch.float32, device="cpu")
return t, p, n
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForTPU:
del finished_requests_ids # Unused.
assert virtual_engine == 0
assert len(seq_group_metadata_list) > 0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, input_lens = inputs
padded_batch_size = input_tokens.shape[0]
t, p, n = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
seq_groups = [
list(metadata.seq_data.keys())
for metadata in seq_group_metadata_list
]
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
input_lens, t, p, num_samples, n, seq_groups)
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=self.attn_backend)
return model_input
@torch.no_grad()
def execute_model(
self,
model_input: ModelInputForTPU,
kv_caches: Optional[List[Any]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> List[SamplerOutput]:
assert intermediate_tensors is None
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
return []
use_async_out_proc = model_input.async_callback is not None
sampler_outputs = []
num_outputs = len(self.cached_step_outputs)
for i in range(num_outputs):
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
sampler_outputs.append(sampler_output)
if i < num_outputs - 1 and use_async_out_proc:
assert model_input.async_callback is not None
ctx = model_input.async_callback.keywords[ # type: ignore
"ctx"]
ctx.append_output(
outputs=[sampler_output],
seq_group_metadata_list=ctx.seq_group_metadata_list,
scheduler_outputs=ctx.scheduler_outputs,
is_async=False,
is_last_step=False,
is_first_step_output=i == 0)
model_input.async_callback()
if use_async_out_proc:
return [sampler_outputs[-1]]
else:
return sampler_outputs
is_prompt = model_input.attn_metadata.num_prefills > 0
if is_prompt:
assert num_steps == 1
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
orig_slot_mapping = model_input.attn_metadata.slot_mapping
orig_block_tables = model_input.attn_metadata.block_tables
orig_context_lens = model_input.attn_metadata.context_lens
orig_effective_query_lens = \
model_input.attn_metadata.effective_query_lens
batch_size = model_input.input_lens.shape[0]
start_idx = 0
next_token_ids = []
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
self.device)
position_ids = model_input.position_ids[None,
start_idx:end_idx].to(
self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.num_prefills = 1
attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx].to(self.device)
if orig_context_lens[i].item() > 0:
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
self.device)
attn_metadata.block_tables = orig_block_tables[
i].unsqueeze(0).to(self.device)
attn_metadata.effective_query_lens = \
orig_effective_query_lens[i:i + 1].to(self.device)
else:
attn_metadata.context_lens = None
attn_metadata.block_tables = None
attn_metadata.effective_query_lens = None
input_lens = model_input.input_lens[i:i + 1].to(self.device)
t = model_input.t[i:i + 1].to(self.device)
p = model_input.p[i:i + 1].to(self.device)
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
input_lens, t, p,
model_input.num_samples,
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx
if model_input.async_callback is not None:
model_input.async_callback()
# Retrieve the outputs to CPU.
next_token_ids = [
output_token_ids.cpu().tolist()
for output_token_ids in next_token_ids
]
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support advanced sampling parameters such as logprobs.
zero_logprob = Logprob(0.0)
sampler_outputs = []
for i, seq_group in enumerate(model_input.seq_groups):
seq_ids = seq_group
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_outputs = []
for j in range(model_input.n[i]):
next_token_id = next_token_ids[i][j]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return [SamplerOutput(sampler_outputs)]
else:
token_ids = model_input.token_ids.to(self.device)
position_ids = model_input.position_ids.to(self.device)
attn_metadata = model_input.attn_metadata
attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
self.device)
attn_metadata.block_tables = attn_metadata.block_tables.to(
self.device)
attn_metadata.context_lens = attn_metadata.context_lens.to(
self.device)
t = model_input.t.to(self.device)
p = model_input.p.to(self.device)
input_lens = model_input.input_lens.to(self.device)
for i in range(num_steps):
slot_mapping = attn_metadata.slot_mapping
with set_forward_context(model_input.attn_metadata,
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
input_lens, t, p,
model_input.num_samples,
kv_caches)
self.cached_step_outputs.append(output_token_ids)
if i < num_steps - 1:
# Prepare the inputs for the next step.
token_ids = output_token_ids.unsqueeze(dim=1).int()
position_ids = position_ids + 1
attn_metadata.context_lens = attn_metadata.context_lens + 1
block_tables = attn_metadata.block_tables
block_number = block_tables.gather(
1,
position_ids.long() // self.block_size)
block_offset = position_ids % self.block_size
is_padding = slot_mapping == _PAD_SLOT_ID
slot_mapping = block_number * self.block_size + block_offset
slot_mapping = slot_mapping.long()
slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
slot_mapping)
attn_metadata.slot_mapping = slot_mapping
if model_input.async_callback is not None:
model_input.async_callback()
if num_steps > 1:
return []
# Retrieve the outputs to CPU.
next_token_ids = self.cached_step_outputs.pop(0)
next_token_ids = next_token_ids.cpu().tolist()
sampler_output = _make_decode_output(next_token_ids,
model_input.seq_groups)
return [sampler_output]
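# A minimal sketch of the prefill splitting described in execute_model above:
# since the kernel cannot handle ragged inputs, each prompt is padded to its
# own power-of-two length and run as a batch of size 1. The prompt lengths
# here are made-up illustration values.
def _example_split_ragged_prefills() -> None:
    prompt_lens = [5, 30, 100]
    padded_lens = [_get_padded_prefill_len(n) for n in prompt_lens]
    assert padded_lens == [16, 32, 128]
    start_idx = 0
    for padded_len in padded_lens:
        end_idx = start_idx + padded_len
        # execute_model slices token_ids[None, start_idx:end_idx] and runs the
        # model on that single padded prompt before moving to the next one.
        start_idx = end_idx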
class ModelWrapper(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
batch_size, seq_len = token_ids.shape
# Calculate the positions to sample from.
start_indices = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = start_indices + input_lens - 1
attn_metadata = get_forward_context().attn_metadata
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
sampling_metadata = SamplingMetadata(
seq_groups=[],
selected_token_indices=logits_indices,
categorized_sample_indices={},
num_prompts=attn_metadata.num_prefills,
)
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
head_indices = torch.arange(0,
num_kv_heads,
device=slot_mapping.device,
dtype=slot_mapping.dtype)
head_indices *= block_size * num_blocks
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indices.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
hidden_states = self.model(token_ids, position_ids)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Argmax sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
# Zero temperature means greedy decoding. Avoid division by zero.
nonzero_t = torch.where(t != 0, t, 1.0)
logits = logits / nonzero_t.unsqueeze(dim=1)
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
# Random sampling.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
sampled_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
if num_samples == 1:
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
next_token_ids = torch.where(t != 0, sampled_token_ids,
argmax_token_ids)
return next_token_ids
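# A small sketch of the two index computations in ModelWrapper.forward above,
# using made-up shapes and the module-level torch import. It shows where the
# last real token of each padded sequence lands, and how slot_mapping is
# expanded across KV heads so index_copy_ can address dimension 0 of the
# flattened cache.
def _example_forward_index_math() -> None:
    batch_size, seq_len = 2, 8
    input_lens = torch.tensor([3, 6], dtype=torch.int32)
    start_indices = torch.arange(batch_size, dtype=torch.int32) * seq_len
    logits_indices = start_indices + input_lens - 1
    assert logits_indices.tolist() == [2, 13]

    num_kv_heads, num_blocks, block_size = 2, 4, 8
    slot_mapping = torch.tensor([5, 9], dtype=torch.int64)
    head_indices = torch.arange(num_kv_heads, dtype=torch.int64)
    head_indices *= block_size * num_blocks
    expanded = slot_mapping.repeat_interleave(num_kv_heads).view(
        -1, num_kv_heads) + head_indices.view(1, -1)
    # Slot 5 maps to rows [5, 37] and slot 9 to rows [9, 41] of the cache
    # flattened to [num_kv_heads * num_blocks * block_size, head_size].
    assert expanded.flatten().tolist() == [5, 37, 9, 41]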
def _get_padded_prefill_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the next
# power of two (with a floor of 16), which satisfies this requirement and
# is also good for performance.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
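# A quick worked example of the rule above: lengths round up to the next
# power of two, with 16 as the floor, e.g.
#   _get_padded_prefill_len(7)   == 16
#   _get_padded_prefill_len(16)  == 16
#   _get_padded_prefill_len(17)  == 32
#   _get_padded_prefill_len(100) == 128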
def _get_padded_batch_size(batch_size: int) -> int:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
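# A quick worked example of the rule above: batch sizes of at most 8 map to
# 8, and larger batches round up to the next multiple of 16, e.g.
#   _get_padded_batch_size(3)  == 8
#   _get_padded_batch_size(9)  == 16
#   _get_padded_batch_size(17) == 32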
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
# Sort the logits in descending order and accumulate the probability mass
# of the sorted distribution.
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
# The cutoff is the smallest logit whose cumulative probability reaches p;
# every logit below it is masked out before sampling.
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
return logits
def _make_decode_output(
next_token_ids: List[int],
seq_groups: List[List[int]],
) -> SamplerOutput:
zero_logprob = Logprob(0.0)
sampler_outputs = []
batch_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group
seq_outputs = []
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(CompletionSequenceGroupOutput(
seq_outputs, None))
return SamplerOutput(sampler_outputs)

337
vllm/worker/tpu_worker.py Normal file
View File

@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import List, Optional, Tuple, Union
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoRANotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
assert self.device_config.device_type == "tpu"
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.model_runner: TPUModelRunner = TPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
if self.model_config.seed is None:
self.model_config.seed = 0
if vllm_config.lora_config is not None:
raise NotImplementedError(
"The V0 TPU backend doesn't support LoRA serving")
def init_device(self) -> None:
os.environ["PJRT_DEVICE"] = "TPU"
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
init_distributed_environment(
world_size=self.parallel_config.world_size,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
# Device initialization should happen after initializing the distributed
# runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended. We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if envs.VLLM_XLA_CACHE_PATH:
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
self.profiler = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)
self.profiler = xp.start_server(9012)
def start_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.start_trace(self.profile_dir)
def stop_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.stop_trace()
def load_model(self):
self.model_runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
num_layers = self.model_config.get_num_layers(self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [(torch.tensor([], dtype=torch.float32,
device=self.device),
torch.tensor([], dtype=torch.float32,
device=self.device))
for _ in range(num_layers)]
bind_kv_cache(self.compilation_config.static_forward_context,
[kv_caches])
self.model_runner._dummy_run(
batch_size=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
kv_caches=kv_caches,
exec_mode=ExecutionMode.PREFILL,
)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
dtype_bytes = get_dtype_size(self.cache_dtype)
block_size_bytes = (dtype_bytes * self.cache_config.block_size *
num_layers * 2 * head_size * num_kv_heads)
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
# Calculate the CPU KV cache size based on the config.
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, num_cpu_blocks
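# A rough numeric sketch of the sizing above, using made-up values: with
# 16 GiB of device memory, gpu_memory_utilization=0.9 and a profiled peak of
# 6 GiB, the cache budget is 16 * 0.9 - 6 = 8.4 GiB. For bfloat16 (2 bytes),
# block_size=16, 32 layers, 8 KV heads and head_size=128, one block takes
# 2 * 16 * 32 * 2 * 128 * 8 = 2 MiB, so roughly 4300 blocks fit before
# rounding down to a multiple of 8.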
def initialize_cache(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self.block_size = self.cache_config.block_size
dtype = self.cache_dtype
num_layers = self.model_config.get_num_layers(self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()
self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_cpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros(cpu_cache_shape,
dtype=dtype,
device="cpu")
cpu_v_cache = torch.zeros_like(cpu_k_cache)
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
bind_kv_cache(self.compilation_config.static_forward_context,
[self.tpu_cache])
self._warmup_model()
def _warmup_model(self) -> None:
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
# for CUDA graphs. We should refactor this part.
if not self.model_config.enforce_eager:
# Warm up the model with all possible input shapes so that
# compilation never happens during the actual execution.
# This may take ~30 mins for the first run and ~20 mins for the
# subsequent runs.
# If `enforce_eager` is True, the ahead-of-time compilation is
# skipped and the compilation happens during the actual execution,
# which is bad for performance but useful for development.
self.model_runner.warmup_model(self.tpu_cache)
def get_cache_block_size_bytes(self) -> int:
head_size = self.model_config.get_head_size()
num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
num_layers = self.model_config.get_num_layers(self.parallel_config)
key_cache_block = self.cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.cache_dtype)
return dtype_size * total
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return [self.tpu_cache]
def prepare_worker_input(
self,
execute_model_req: ExecuteModelRequest,
) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
blocks_to_swap_in = _make_src_to_dst(
execute_model_req.blocks_to_swap_in, "cpu", self.device)
blocks_to_swap_out = _make_src_to_dst(
execute_model_req.blocks_to_swap_out, self.device, "cpu")
blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
self.device, self.device)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
assert virtual_engine == 0
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)
# Issue cache operations.
if worker_input.blocks_to_swap_in is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_in
if src_indices.numel() > 0:
# Swap from CPU to TPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
k = cpu_k_cache[:, src_indices].to(self.device)
v = cpu_v_cache[:, src_indices].to(self.device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
if worker_input.blocks_to_swap_out is not None:
src_indices, dst_indices = worker_input.blocks_to_swap_out
if src_indices.numel() > 0:
# Swap from TPU to CPU.
for i in range(num_layers):
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
if worker_input.blocks_to_copy is not None:
src_indices, dst_indices = worker_input.blocks_to_copy
if src_indices.numel() > 0:
attn_backend.copy_blocks(self.tpu_cache,
(src_indices, dst_indices))
def _make_src_to_dst(
mapping: List[Tuple[int, int]],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
if not mapping:
return None
src_indices = [i for i, _ in mapping]
dst_indices = [i for _, i in mapping]
src_indices = torch.tensor(src_indices,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_indices,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
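# A small usage sketch of the helper above with a made-up mapping: copying
# block 0 to block 3 and block 1 to block 5 yields index tensors [0, 1] and
# [3, 5]. Both devices are "cpu" here so the sketch runs anywhere.
def _example_make_src_to_dst() -> None:
    result = _make_src_to_dst([(0, 3), (1, 5)], "cpu", "cpu")
    assert result is not None
    src_indices, dst_indices = result
    assert src_indices.tolist() == [0, 1]
    assert dst_indices.tolist() == [3, 5]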
@torch.compile(backend="openxla")
def _insert_kv(
k: torch.Tensor,
v: torch.Tensor,
indices: torch.Tensor,
tpu_k_cache: torch.Tensor,
tpu_v_cache: torch.Tensor,
) -> None:
# Mark the destination caches as buffer donors so XLA can reuse their
# buffers for the in-place update instead of allocating fresh copies.
torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
tpu_k_cache[:, indices] = k
tpu_v_cache[:, indices] = v

606
vllm/worker/xpu_model_runner.py Normal file
View File

@ -0,0 +1,606 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import time
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)
import torch
import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")
@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
"""
Used by the XPUModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
async_callback: Optional[Callable] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForXPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForXPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForXPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
def __init__(self,
runner: "XPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
def build(self) -> ModelInputForXPU:
is_prompt = self.seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) = self._prepare_prompt(
self.seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = None
multi_modal_kwargs = None
return self.model_input_cls(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
seq_lens=seq_lens,
query_lens=seq_lens,
)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
seq_len = len(prompt_tokens)
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
positions_range = range(computed_len, seq_len)
input_positions.extend(list(positions_range))
if seq_group_metadata.multi_modal_data:
# NOTE: mm_kwargs only includes the subset of multi-modal items
# that intersect with the current prefill positions.
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
.from_seq_group(seq_group_metadata, positions_range)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device) # type: ignore
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
max_seqlen = max(seq_lens)
tmp = [0]
tmp.extend(seq_lens)
seqlen = torch.tensor(tmp)
seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=seqlen_q,
max_seqlen=max_seqlen,
seq_lens_tensor=torch.tensor([]),
max_decode_seq_len=0,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
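# A small worked example of the slot computation above, with a made-up block
# table: for block_size=16 and block_table=[7, 12], token position i=20 falls
# in block number block_table[20 // 16] == 12 at offset 20 % 16 == 4, so its
# slot is 12 * 16 + 4 == 196.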
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_decode_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
block_tables = make_tensor_with_pad(
block_tables,
pad=0,
dtype=torch.int,
device=self.device,
)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=torch.tensor([]),
max_seqlen=0,
seq_lens_tensor=seq_lens_tensor,
max_decode_seq_len=max_decode_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
num_prefills=0,
block_tables=block_tables,
)
return (
input_tokens,
input_positions,
attn_metadata,
)
class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
ModelInputForXPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
return_hidden_states: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Multi-modal data support
self.input_registry = input_registry
self.mm_registry = mm_registry
# Lazy initialization.
self.model: nn.Module  # Set after load_model()
self.sampler = get_sampler()
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache() \
if self.parallel_config.pipeline_parallel_size == 1 else None
self.builder = self._builder_cls(weakref.proxy(self))
def load_model(self) -> None:
with DeviceMemoryProfiler() as m:
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GiB",
self.model_memory_usage / GiB_bytes)
def get_model(self) -> nn.Module:
return self.model
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for multi-modal encoding, which
# needs to be accounted for when calculating the GPU blocks for
# the vLLM block manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
if max_num_seqs < 1:
expr = (f"min({max_num_seqs_orig}, "
f"{max_num_batched_tokens} // {max_mm_tokens})")
logger.warning(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=None,
multi_modal_data=dummy_data.multi_modal_data,
multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq)
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, None, intermediate_tensors)
torch.xpu.synchronize()
return
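# A rough numeric sketch of the clamping above, with made-up values: for
# max_num_batched_tokens=8192, max_num_seqs=256 and max_mm_tokens=3000, the
# profile batch is limited to min(256, 8192 // 3000) == 2 sequences, each
# with a dummy prompt of 8192 // 2 == 4096 tokens.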
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForXPUWithSamplingMetadata:
return (
ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForXPUWithSamplingMetadata:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder = self.builder
builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForXPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators,
cache=self.sampling_metadata_cache)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
virtual_engine=virtual_engine)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForXPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"XPUModelRunner does not support multi-step execution.")
model_executable = self.model
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()
with set_forward_context(model_input.attn_metadata, self.vllm_config,
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
device=self.device,
),
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end_time = time.time()
# Compute the logits.
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_time = (model_forward_end_time -
model_forward_start_time)
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = model_forward_time
return [output]

186
vllm/worker/xpu_worker.py Normal file
View File

@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
class XPUWorker(LoRANotSupportedWorkerBase, Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single XPU device. The worker is
responsible for maintaining the KV cache and executing the model on the
XPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
device_config = self.device_config
parallel_config = self.parallel_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
self.model_runner = XPUModelRunner( # type: ignore
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine]
self.gpu_cache: Optional[List[List[torch.Tensor]]]
def init_device(self) -> None:
if self.device_config.device.type == "xpu" and current_platform.is_xpu(
):
self.device = torch.device(f"xpu:{self.local_rank}")
torch.xpu.set_device(self.device)
torch.xpu.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
self.init_worker_distributed_environment()
# Initialize the model.
set_random_seed(self.model_config.seed)
# keep this method for `empty_cache` and `synchronize` api
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine first profiles the existing memory usage. Then, it calculates
the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.xpu.synchronize()
used_memory = torch.xpu.memory_allocated()
total_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
free_gpu_memory = total_gpu_memory - used_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
gc.collect()
torch.xpu.empty_cache()
return num_gpu_blocks, num_cpu_blocks
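# A rough numeric sketch of the formula above, with made-up values: for
# 32 GiB of device memory, gpu_memory_utilization=0.9, a profiled peak of
# 10 GiB and a 2 MiB cache block, num_gpu_blocks comes out to
# (32 * 0.9 - 10) GiB // 2 MiB, i.e. roughly 9625 blocks.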
def _warm_up_model(self) -> None:
# IPEX doesn't support graph capture yet
pass
def init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
rank = self.rank
distributed_init_method = self.distributed_init_method
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch "
"world size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
# use sockets as the default Level Zero IPC exchange backend. By
# default, oneccl uses `drmfd` as the mechanism, which needs extra
# dependencies (libdrm and drm headers) on your system.
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
str(parallel_config.world_size))
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ["LOCAL_RANK"] = str(self.local_rank)
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=self.local_rank,
backend="ccl")
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
# global all_reduce needed for overall oneccl warm up
torch.distributed.all_reduce(torch.zeros(1).xpu())
if parallel_config.pipeline_parallel_size > 1:
# Add pp group init to avoid
# p2p communication as the first call
get_pp_group().all_reduce(torch.zeros(1).xpu())